Skip to content

[XPU] Support empty_cache on XPUs #9789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions llm/alignment/ppo/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

# ##### training statistic logging #####
# Number of trainable parameters only account for policy_model
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
Expand All @@ -1224,7 +1225,7 @@ def train(
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/quantization/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle.nn.quant import weight_quantize

from ..utils.log import logger
from ..utils.memory_utils import empty_device_cache

Check warning on line 26 in paddlenlp/quantization/quantization_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/quantization/quantization_utils.py#L26

Added line #L26 was not covered by tests
from .quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
Expand Down Expand Up @@ -150,7 +151,7 @@
state_dict.update(qlora_state_dict)
del target_weight
gc.collect()
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 154 in paddlenlp/quantization/quantization_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/quantization/quantization_utils.py#L154

Added line #L154 was not covered by tests
return state_dict


Expand Down
20 changes: 10 additions & 10 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils import infohub
from paddlenlp.utils import empty_device_cache, infohub
from paddlenlp.utils.env import (
LORA_WEIGHTS_NAME,
MAX_QUANTIZATION_TIMES,
Expand Down Expand Up @@ -158,7 +158,7 @@
if self.args.should_save:
save_model_config(model_to_save, save_directory)

paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 161 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L161

Added line #L161 was not covered by tests

if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
world_size = paddle.distributed.get_world_size()
Expand Down Expand Up @@ -195,7 +195,7 @@
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 198 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L198

Added line #L198 was not covered by tests

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
Expand Down Expand Up @@ -373,7 +373,7 @@
optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]

paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 376 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L376

Added line #L376 was not covered by tests
save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
Expand Down Expand Up @@ -506,7 +506,7 @@
Returns:
tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 509 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L509

Added line #L509 was not covered by tests
assert hasattr(model_to_save, "config")

state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True)
Expand Down Expand Up @@ -558,7 +558,7 @@
elif isinstance(model_to_save, PrefixModelForCausalLM):
sharded_index["type"] = "ptuning"

paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 561 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L561

Added line #L561 was not covered by tests

return state_dict, shard_file, sharded_index

Expand All @@ -576,7 +576,7 @@
optimizer (Optimizer): optimizer to save.
safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 579 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L579

Added line #L579 was not covered by tests

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
Expand Down Expand Up @@ -643,7 +643,7 @@
filter_optim_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 646 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L646

Added line #L646 was not covered by tests

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")
Expand All @@ -653,7 +653,7 @@
filter_master_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 656 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L656

Added line #L656 was not covered by tests

# build index json file
index_optimizer_file, index_master_weight_file = {}, {}
Expand Down Expand Up @@ -704,7 +704,7 @@
else:
sharded_optim_index["master_weights"] = False

paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 707 in paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py#L707

Added line #L707 was not covered by tests
if master_weights is None:
return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)]
else:
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/trl/embedding_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SimpleInfclLoss,
)
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient
from paddlenlp.utils import empty_device_cache

__all__ = ["EmbeddingTrainer"]

Expand Down Expand Up @@ -63,7 +64,7 @@
def clear_memory(self):
self.accum_q_features.clear()
self.accum_p_features.clear()
paddle.device.cuda.empty_cache()
empty_device_cache()

Check warning on line 67 in paddlenlp/trl/embedding_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trl/embedding_trainer.py#L67

Added line #L67 was not covered by tests

def clear_state(self):
self.accum_data.clear()
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .import_utils import *
from .infohub import infohub
from .initializer import to
from .memory_utils import empty_device_cache
from .optimizer import *
from .serialization import load_torch

Expand Down
29 changes: 29 additions & 0 deletions paddlenlp/utils/memory_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# coding:utf-8
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

__all__ = [
"empty_device_cache",
]


def empty_device_cache():
    """Release cached device memory on the active accelerator.

    Dispatches to the backend-specific ``empty_cache`` call depending on
    how Paddle was compiled: CUDA builds use ``paddle.device.cuda``,
    XPU builds use ``paddle.device.xpu``. On CPU-only builds there is no
    device cache, so the call is intentionally a no-op.
    """
    if paddle.device.is_compiled_with_cuda():
        paddle.device.cuda.empty_cache()
    elif paddle.device.is_compiled_with_xpu():
        paddle.device.xpu.empty_cache()
    # CPU-only build: nothing to release.

Check warning on line 29 in paddlenlp/utils/memory_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/utils/memory_utils.py#L29

Added line #L29 was not covered by tests
9 changes: 5 additions & 4 deletions slm/examples/RLHF/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

# ##### training statistic logging #####
# Number of trainable parameters only account for policy_model
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
Expand All @@ -1224,7 +1225,7 @@ def train(
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
Expand Down