From d407728dc1fd968cb6391d69de4db03de8023e9d Mon Sep 17 00:00:00 2001 From: will-jl944 Date: Thu, 16 Jan 2025 15:18:05 +0800 Subject: [PATCH 1/2] [XPU] Support empty_cache on XPUs --- llm/alignment/ppo/ppo_trainer.py | 9 +++--- paddlenlp/quantization/quantization_utils.py | 3 +- .../unified_checkpoint/unified_checkpoint.py | 20 ++++++------- paddlenlp/trl/embedding_trainer.py | 3 +- paddlenlp/utils/__init__.py | 1 + paddlenlp/utils/memory_utils.py | 29 +++++++++++++++++++ slm/examples/RLHF/ppo_trainer.py | 9 +++--- 7 files changed, 54 insertions(+), 20 deletions(-) create mode 100644 paddlenlp/utils/memory_utils.py diff --git a/llm/alignment/ppo/ppo_trainer.py b/llm/alignment/ppo/ppo_trainer.py index c2c72d6c5cd1..bdec462411e0 100644 --- a/llm/alignment/ppo/ppo_trainer.py +++ b/llm/alignment/ppo/ppo_trainer.py @@ -66,6 +66,7 @@ speed_metrics, ) from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer +from paddlenlp.utils import empty_device_cache class StepTrainer(Trainer): @@ -1032,7 +1033,7 @@ def gen_epoch_data(): ptx_batches = [None for _ in range(len(rl_batches))] self.timers and self.timers("ptx-batch").stop() - paddle.device.cuda.empty_cache() + empty_device_cache() self.set_train() for _ in range(self.args.update_iters): @@ -1152,7 +1153,7 @@ def train( # ##### model and optimizer related setting ##### policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint) - paddle.device.cuda.empty_cache() + empty_device_cache() # ##### traing statistic logging ##### # Number of trainable parameters only account for policy_model @@ -1208,7 +1209,7 @@ def train( # with self.enable(self.value_trainer.optimizer): with self.enable(): # put value optimizer guard in rl_step rl_info = self.rl_step(rl_batch) - paddle.device.cuda.empty_cache() + empty_device_cache() self.timers and self.timers("rl_step").stop() if self.use_ptx: @@ -1224,7 +1225,7 @@ def train( ptx_info = self.ptx_step(ptx_batch) rl_info.update(ptx_info) self.timers and self.timers("ptx_step").stop() - paddle.device.cuda.empty_cache() + empty_device_cache() self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch diff --git a/paddlenlp/quantization/quantization_utils.py b/paddlenlp/quantization/quantization_utils.py index fe46efd2a2fa..a12bebd89d36 100644 --- a/paddlenlp/quantization/quantization_utils.py +++ b/paddlenlp/quantization/quantization_utils.py @@ -23,6 +23,7 @@ from paddle.nn.quant import weight_quantize from ..utils.log import logger +from ..utils.memory_utils import empty_device_cache from .quantization_linear import ( ColumnParallelQuantizationLinear, QuantizationLinear, @@ -150,7 +151,7 @@ def convert_to_quantize_state_dict_without_check(state_dict, quantization_linear state_dict.update(qlora_state_dict) del target_weight gc.collect() - paddle.device.cuda.empty_cache() + empty_device_cache() return state_dict diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 0cb38bec94bb..dba33fbe1a7a 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -29,7 +29,7 @@ unwrap_model, ) from paddlenlp.transformers.utils import dtype_byte_size -from paddlenlp.utils import infohub +from paddlenlp.utils import empty_device_cache, infohub from paddlenlp.utils.env import ( LORA_WEIGHTS_NAME, MAX_QUANTIZATION_TIMES, @@ -158,7 +158,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, 
signal_dir=None) if self.args.should_save: save_model_config(model_to_save, save_directory) - paddle.device.cuda.empty_cache() + empty_device_cache() if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save: world_size = paddle.distributed.get_world_size() @@ -195,7 +195,7 @@ def load_unified_checkpoint(self, model, resume_from_checkpoint: str): load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir): - paddle.device.cuda.empty_cache() + empty_device_cache() # gather global master_weights status. global_master_weights = reduce_master_weights_status(master_weights is not None) @@ -373,7 +373,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): optim_state_dict, shard_optim_file, sharded_optim_index = results[0] master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1] - paddle.device.cuda.empty_cache() + empty_device_cache() save_directory = output_dir os.makedirs(save_directory, exist_ok=True) if signal_dir is not None: @@ -506,7 +506,7 @@ def unified_checkpoint_into_shards( Returns: tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name. """ - paddle.device.cuda.empty_cache() + empty_device_cache() assert hasattr(model_to_save, "config") state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True) @@ -558,7 +558,7 @@ def unified_checkpoint_into_shards( elif isinstance(model_to_save, PrefixModelForCausalLM): sharded_index["type"] = "ptuning" - paddle.device.cuda.empty_cache() + empty_device_cache() return state_dict, shard_file, sharded_index @@ -576,7 +576,7 @@ def unified_optimizer_into_shards( optimizer (Optimizer): optimizer to save. safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False. """ - paddle.device.cuda.empty_cache() + empty_device_cache() # gather global master_weights status. 
global_master_weights = reduce_master_weights_status(master_weights is not None) @@ -643,7 +643,7 @@ def unified_optimizer_into_shards( filter_optim_keys, state_dict if args.use_expert_parallel else None, ) - paddle.device.cuda.empty_cache() + empty_device_cache() if master_weights is not None: logger.info("Unified master weight tensor parallel in shards") @@ -653,7 +653,7 @@ def unified_optimizer_into_shards( filter_master_keys, state_dict if args.use_expert_parallel else None, ) - paddle.device.cuda.empty_cache() + empty_device_cache() # build index json file index_optimizer_file, index_master_weight_file = {}, {} @@ -704,7 +704,7 @@ def unified_optimizer_into_shards( else: sharded_optim_index["master_weights"] = False - paddle.device.cuda.empty_cache() + empty_device_cache() if master_weights is None: return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)] else: diff --git a/paddlenlp/trl/embedding_trainer.py b/paddlenlp/trl/embedding_trainer.py index c50f19738bed..610db8151fd3 100644 --- a/paddlenlp/trl/embedding_trainer.py +++ b/paddlenlp/trl/embedding_trainer.py @@ -26,6 +26,7 @@ SimpleInfclLoss, ) from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient +from paddlenlp.utils import empty_device_cache __all__ = ["EmbeddingTrainer"] @@ -63,7 +64,7 @@ def __init__(self, model_args, **kwargs): def clear_memory(self): self.accum_q_features.clear() self.accum_p_features.clear() - paddle.device.cuda.empty_cache() + empty_device_cache() def clear_state(self): self.accum_data.clear() diff --git a/paddlenlp/utils/__init__.py b/paddlenlp/utils/__init__.py index 3b5950b0d701..b4fb779c5abb 100644 --- a/paddlenlp/utils/__init__.py +++ b/paddlenlp/utils/__init__.py @@ -21,6 +21,7 @@ from .import_utils import * from .infohub import infohub from .initializer import to +from .memory_utils import empty_device_cache from .optimizer import * from .serialization import load_torch diff --git a/paddlenlp/utils/memory_utils.py b/paddlenlp/utils/memory_utils.py new file mode 100644 index 000000000000..4e649c1e9e56 --- /dev/null +++ b/paddlenlp/utils/memory_utils.py @@ -0,0 +1,29 @@ +# coding:utf-8 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle + +__all__ = [ + "empty_device_cache", +] + + +def empty_device_cache(): + if paddle.device.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() + elif paddle.device.is_compiled_with_xpu(): + paddle.device.xpu.empty_cache() + else: + pass diff --git a/slm/examples/RLHF/ppo_trainer.py b/slm/examples/RLHF/ppo_trainer.py index c2c72d6c5cd1..bdec462411e0 100644 --- a/slm/examples/RLHF/ppo_trainer.py +++ b/slm/examples/RLHF/ppo_trainer.py @@ -66,6 +66,7 @@ speed_metrics, ) from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer +from paddlenlp.utils import empty_device_cache class StepTrainer(Trainer): @@ -1032,7 +1033,7 @@ def gen_epoch_data(): ptx_batches = [None for _ in range(len(rl_batches))] self.timers and self.timers("ptx-batch").stop() - paddle.device.cuda.empty_cache() + empty_device_cache() self.set_train() for _ in range(self.args.update_iters): @@ -1152,7 +1153,7 @@ def train( # ##### model and optimizer related setting ##### policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint) - paddle.device.cuda.empty_cache() + empty_device_cache() # ##### traing statistic logging ##### # Number of trainable parameters only account for policy_model @@ -1208,7 +1209,7 @@ def train( # with self.enable(self.value_trainer.optimizer): with self.enable(): # put value optimizer guard in rl_step rl_info = self.rl_step(rl_batch) - paddle.device.cuda.empty_cache() + empty_device_cache() self.timers and self.timers("rl_step").stop() if self.use_ptx: @@ -1224,7 +1225,7 @@ def train( ptx_info = self.ptx_step(ptx_batch) rl_info.update(ptx_info) self.timers and self.timers("ptx_step").stop() - paddle.device.cuda.empty_cache() + empty_device_cache() self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch From 633f53669bea3ad13607204bbb970e6a5dcbdc2e Mon Sep 17 00:00:00 2001 From: will-jl944 Date: Tue, 21 Jan 2025 20:14:20 +0800 Subject: [PATCH 2/2] warn if current device doesn't support --- paddlenlp/utils/memory_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/paddlenlp/utils/memory_utils.py b/paddlenlp/utils/memory_utils.py index 4e649c1e9e56..05e9a9fe3fef 100644 --- a/paddlenlp/utils/memory_utils.py +++ b/paddlenlp/utils/memory_utils.py @@ -15,15 +15,25 @@ import paddle +from .log import logger +from .tools import get_env_device + __all__ = [ "empty_device_cache", ] def empty_device_cache(): - if paddle.device.is_compiled_with_cuda(): + device = get_env_device() + if device == "gpu": paddle.device.cuda.empty_cache() - elif paddle.device.is_compiled_with_xpu(): + elif device == "xpu": paddle.device.xpu.empty_cache() else: - pass + if not getattr(empty_device_cache, "has_warned", False): + logger.warning( + "The current device ({}) does not support empty cache, calling empty_device_cache() will have no effect.".format( + device + ) + ) + setattr(empty_device_cache, "has_warned", True)
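
Usage sketch (not part of the patch): `empty_device_cache()` is intended as a device-agnostic drop-in for `paddle.device.cuda.empty_cache()`, as in the call sites updated above. A minimal illustration, assuming a hypothetical helper that releases large intermediate tensors before returning cached blocks to the allocator:

    import gc
    from paddlenlp.utils import empty_device_cache

    def release_step_buffers(buffers):
        # Drop references to large intermediate tensors first so the
        # allocator can actually reclaim them.
        buffers.clear()
        gc.collect()
        # On CUDA devices this dispatches to paddle.device.cuda.empty_cache(),
        # on XPUs to paddle.device.xpu.empty_cache(); on other devices it is a
        # no-op that logs a one-time warning (per PATCH 2/2).
        empty_device_cache()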