Skip to content

Commit beb433a

Browse files
authored
[LLM] Add memory stats to the trainer's logger (#8269)
1 parent c3ec984 commit beb433a

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@
3939
import paddle.distributed as dist
4040
import paddle.nn as nn
4141
from packaging import version
42+
from paddle import framework
43+
from paddle.base import core
4244
from paddle.distributed import fleet
4345
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
4446
HybridParallelOptimizer,
@@ -1256,6 +1258,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
12561258
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
12571259
logs["global_step"] = int(self.state.global_step)
12581260

1261+
divisor = 2**30
1262+
# TODO(@gexiao): replace these codes with unified APIs in Paddle
1263+
current_device = framework._current_expected_place_()
1264+
if str(current_device) != "Place(cpu)":
1265+
device_id = current_device.get_device_id()
1266+
current_memory_allocated = core.device_memory_stat_current_value("Allocated", device_id)
1267+
current_memory_reserved = core.device_memory_stat_current_value("Reserved", device_id)
1268+
max_memory_allocated = core.device_memory_stat_peak_value("Allocated", device_id)
1269+
max_memory_reserved = core.device_memory_stat_peak_value("Reserved", device_id)
1270+
logs["current_memory_allocated"] = current_memory_allocated / divisor
1271+
logs["current_memory_reserved"] = current_memory_reserved / divisor
1272+
logs["max_memory_allocated"] = max_memory_allocated / divisor
1273+
logs["max_memory_reserved"] = max_memory_reserved / divisor
1274+
12591275
total_train_batch_size = (
12601276
self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
12611277
)
@@ -1586,8 +1602,6 @@ def _load_rng_state(self, checkpoint):
15861602
random.setstate(checkpoint_rng_state["python"])
15871603
np.random.set_state(checkpoint_rng_state["numpy"])
15881604

1589-
core = paddle.framework.core
1590-
15911605
core.default_cpu_generator().set_state(checkpoint_rng_state["cpu"])
15921606
if core.is_compiled_with_cuda():
15931607
if not len(checkpoint_rng_state["cuda"]) == core.get_cuda_device_count():

0 commit comments

Comments (0)