diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index be1af93c50fd..8993a276eb56 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -39,6 +39,8 @@
 import paddle.distributed as dist
 import paddle.nn as nn
 from packaging import version
+from paddle import framework
+from paddle.base import core
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
@@ -1256,6 +1258,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
 
+            divisor = 2**30
+            # TODO(@gexiao): replace these codes with unified APIs in Paddle
+            current_device = framework._current_expected_place_()
+            if str(current_device) != "Place(cpu)":
+                device_id = current_device.get_device_id()
+                current_memory_allocated = core.device_memory_stat_current_value("Allocated", device_id)
+                current_memory_reserved = core.device_memory_stat_current_value("Reserved", device_id)
+                max_memory_allocated = core.device_memory_stat_peak_value("Allocated", device_id)
+                max_memory_reserved = core.device_memory_stat_peak_value("Reserved", device_id)
+                logs["current_memory_allocated"] = current_memory_allocated / divisor
+                logs["current_memory_reserved"] = current_memory_reserved / divisor
+                logs["max_memory_allocated"] = max_memory_allocated / divisor
+                logs["max_memory_reserved"] = max_memory_reserved / divisor
+
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
@@ -1586,8 +1602,6 @@ def _load_rng_state(self, checkpoint):
         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])
 
-        core = paddle.framework.core
-
         core.default_cpu_generator().set_state(checkpoint_rng_state["cpu"])
         if core.is_compiled_with_cuda():
            if not len(checkpoint_rng_state["cuda"]) == core.get_cuda_device_count():
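
For reference, the TODO in the logging hunk points at replacing the private `core.device_memory_stat_*` bindings with a unified public API. A minimal sketch of the same GiB-based logging using Paddle's public `paddle.device.cuda` memory functions (the helper name `collect_gpu_memory_stats` is illustrative, and this sketch covers only the CUDA path, not the other device places the private `core` calls handle):

    import paddle

    def collect_gpu_memory_stats(logs: dict, divisor: int = 2**30) -> dict:
        # Hypothetical helper: record current/peak GPU memory in GiB using the
        # public paddle.device.cuda API instead of the private core bindings.
        if not paddle.device.is_compiled_with_cuda():
            return logs
        # All stats below are reported in bytes; divide by 2**30 to get GiB.
        logs["current_memory_allocated"] = paddle.device.cuda.memory_allocated() / divisor
        logs["current_memory_reserved"] = paddle.device.cuda.memory_reserved() / divisor
        logs["max_memory_allocated"] = paddle.device.cuda.max_memory_allocated() / divisor
        logs["max_memory_reserved"] = paddle.device.cuda.max_memory_reserved() / divisor
        return logs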