From f0507f6328e16b3a513787cf7333fef8765f1f56 Mon Sep 17 00:00:00 2001 From: pkhk-1 Date: Wed, 26 Mar 2025 03:23:17 +0000 Subject: [PATCH 1/4] add token ips --- .../llava/language_model/llava_llama.py | 10 + paddlemix/trainer/benchmark_callback.py | 388 ++++++++++++++++++ paddlemix/trainer/llava_trainer.py | 78 +--- 3 files changed, 412 insertions(+), 64 deletions(-) create mode 100644 paddlemix/trainer/benchmark_callback.py diff --git a/paddlemix/models/llava/language_model/llava_llama.py b/paddlemix/models/llava/language_model/llava_llama.py index d0f481b63..366aa3c10 100644 --- a/paddlemix/models/llava/language_model/llava_llama.py +++ b/paddlemix/models/llava/language_model/llava_llama.py @@ -100,6 +100,16 @@ def forward( input_ids, position_ids, attention_mask, past_key_values, labels, images, image_size ) + # 通过attention_mask计算有效token数量 + if attention_mask is not None: + # 统计当前batch的有效token数(排除padding) + current_batch_tokens = attention_mask.sum().item() # shape: (batch_size, seq_len) + else: + # 如果没有padding,直接取inputs_embeds的batch_size*seq_length + current_batch_tokens = inputs_embeds.size(0) * inputs_embeds.size(1) + self.efficient_token_count = current_batch_tokens + self.input_shape = inputs_embeds.shape + return super().forward( input_ids=input_ids, attention_mask=attention_mask, diff --git a/paddlemix/trainer/benchmark_callback.py b/paddlemix/trainer/benchmark_callback.py new file mode 100644 index 000000000..6077480e2 --- /dev/null +++ b/paddlemix/trainer/benchmark_callback.py @@ -0,0 +1,388 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddlenlp.trainer.trainer import TrainerCallback +from paddlenlp.trainer.trainer_utils import ShardingOption +from paddlenlp.utils.log import logger + + +"""utils for log""" +import os +import time +import numpy as np +import paddle +import paddle.distributed.fleet as fleet +from paddlenlp.trainer.plugins.timer import get_timers + +def get_memory_info(): + """get_memory_info""" + divisor = 2**30 + return ( + paddle.device.cuda.memory_allocated() / divisor, + paddle.device.cuda.max_memory_allocated() / divisor, + paddle.device.cuda.memory_reserved() / divisor, + paddle.device.cuda.max_memory_reserved() / divisor, + ) + + +class Statistical(object): + """Statistical + """ + def __init__(self, buffer_size, skip_step, approximate=False): + self.step = 0 + self.local_step = 0 + self.is_first_period = True + print('buffer size is %d and skip step is %d.'%(buffer_size, skip_step)) + + assert skip_step < buffer_size + self.buffer_size = buffer_size + self.skip_step = skip_step + self.approximate = approximate + + self.local_tokens = np.zeros([buffer_size], dtype=np.int64) + self.local_samples = np.zeros([buffer_size], dtype=np.int64) + self.efficient_tokens = np.zeros([buffer_size], dtype=np.int64) + self.tokens = np.zeros([buffer_size], dtype=np.int64) + self.samples = np.zeros([buffer_size], dtype=np.int64) + self.durations = np.zeros([buffer_size], dtype=np.float64) + if int(os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1: + assert ( + paddle.distributed.is_initialized() + ), "please call fleet.init() before" + if hasattr(fleet.fleet, "_hcg"): + hcg = 
fleet.get_hybrid_communicate_group() + + dp_size = hcg.get_data_parallel_world_size() + sharding_size = hcg.get_sharding_parallel_world_size() + self.dp_size = dp_size * sharding_size + self.world_size = paddle.distributed.get_world_size() + self.groups = [] + if dp_size > 1: + self.groups.append(hcg.get_data_parallel_group()) + if sharding_size > 1: + self.groups.append(hcg.get_sharding_parallel_group()) + else: + self.dp_size = paddle.distributed.get_world_size() + self.world_size = self.dp_size + self.groups = [] + + else: + self.dp_size = 1 + self.world_size = 1 + self.groups = [] + + def reset(self): + """reset + """ + if self.local_step > 0: + self.is_first_period = False + self.local_step = 0 + self.local_tokens.fill(0) + self.local_samples.fill(0) + self.efficient_tokens.fill(0) + self.tokens.fill(0) + self.samples.fill(0) + self.durations.fill(0.0) + + def _get_global(self, tokens_num, batch_size): + if self.dp_size == 1 or self.approximate: + return tokens_num * self.dp_size, batch_size * self.dp_size + + timers = get_timers() + if timers: + timers("all-reduce-token-bz").start() + x = paddle.to_tensor([tokens_num, batch_size], dtype=paddle.int64) + for g in self.groups: + paddle.distributed.stream.all_reduce(x, group=g, sync_op=True, use_calc_stream=True) + if timers: + timers("all-reduce-token-bz").stop() + return x.numpy().tolist() + + def add(self, start_time, efficient_tokens_num, tokens_num, batch_size): + """add """ + assert ( + self.local_step < self.buffer_size + ), "the step number exceeds the ckpt saving interval" + global_tokens_num, global_batch_size = self._get_global(tokens_num, batch_size) + global_efficient_tokens_num, _ = self._get_global(efficient_tokens_num, batch_size) + + duration = time.time() - start_time + self.durations[self.local_step] = duration + self.local_tokens[self.local_step] = tokens_num + self.local_samples[self.local_step] = batch_size + self.efficient_tokens[self.local_step] = global_efficient_tokens_num + 
self.tokens[self.local_step] = global_tokens_num + self.samples[self.local_step] = global_batch_size + self.step += 1 + self.local_step += 1 + return global_tokens_num, global_batch_size + + def get_tokens_per_sec_per_card(self): + """get_tokens_per_sec_per_card""" + return ( + self.tokens[self.local_step - 1] + / self.durations[self.local_step - 1] + / self.world_size + ) + + def get_efficient_tokens_per_sec_per_card(self): + """get_efficient_tokens_per_sec_per_card""" + return ( + self.efficient_tokens[self.local_step - 1] + / self.durations[self.local_step - 1] + / self.world_size + ) + + def get_avg_tokens_per_sec_per_card(self): + """get_avg_tokens_per_sec_per_card""" + if self.step <= self.skip_step: + return self.get_tokens_per_sec_per_card() + + start = self.skip_step if self.is_first_period else 0 + return ( + np.sum(self.tokens[start : self.local_step]) + / np.sum(self.durations[start : self.local_step]) + / self.world_size + ) + + def get_avg_efficient_tokens_per_sec_per_card(self): + """get_avg_tokens_per_sec_per_card""" + if self.step <= self.skip_step: + return self.get_efficient_tokens_per_sec_per_card() + + start = self.skip_step if self.is_first_period else 0 + return ( + np.sum(self.efficient_tokens[start : self.local_step]) + / np.sum(self.durations[start : self.local_step]) + / self.world_size + ) + + def get_samples_per_sec_per_card(self): + """get_samples_per_sec_per_card""" + return ( + self.samples[self.local_step - 1] + / self.durations[self.local_step - 1] + / self.world_size + ) + + def get_avg_samples_per_sec_per_card(self): + """get_avg_samples_per_sec_per_card""" + if self.step <= self.skip_step: + return self.get_samples_per_sec_per_card() + + start = self.skip_step if self.is_first_period else 0 + return ( + np.sum(self.samples[start : self.local_step]) + / np.sum(self.durations[start : self.local_step]) + / self.world_size + ) + + def get_tokens(self): + """get_tokens""" + return self.tokens[: self.local_step] + + def 
get_efficient_tokens(self): + """get_efficient_tokens""" + return self.efficient_tokens[: self.local_step] + + def get_samples(self): + """get_samples""" + return self.samples[: self.local_step] + + def get_durations(self): + """get_durations""" + return self.durations[: self.local_step] + + def get_total_tokens_per_card(self): + """get_total_tokens_per_card""" + return np.sum(self.tokens[: self.local_step]) / self.world_size + + def get_total_efficient_tokens_per_card(self): + """get_total_efficient_tokens_per_card""" + return np.sum(self.efficient_tokens[: self.local_step]) / self.world_size + + def get_total_samples_per_card(self): + """get_total_samples_per_card""" + return np.sum(self.samples[: self.local_step]) / self.world_size + + def get_skip_duration(self): + """get_skip_duration""" + if self.is_first_period: + return np.sum(self.durations[: min(self.local_step, self.skip_step)]) + else: + return 0.0 + + + def get_result(self): + """get_result""" + runtime = np.sum(self.durations[: self.local_step]) + local_samples = np.sum(self.local_samples[: self.local_step]) + local_tokens = np.sum(self.local_tokens[: self.local_step]) + global_samples = np.sum(self.samples[: self.local_step]) + global_tokens = np.sum(self.tokens[: self.local_step]) + global_efficient_tokens = np.sum(self.efficient_tokens[: self.local_step]) + + alloc, max_alloc, reserved, max_reserved = get_memory_info() + result = { + f'runtime': runtime, + f"local_samples": local_samples, + f"global_samples": global_samples, + f"local_tokens": local_tokens, + f"global_tokens": global_tokens, + f"global_efficient_tokens": global_efficient_tokens, + f"samples_per_sec_per_card": self.get_samples_per_sec_per_card(), + f"avg_samples_per_sec_per_card": self.get_avg_samples_per_sec_per_card(), + f"efficient_tokens_per_sec_per_card": self.get_efficient_tokens_per_sec_per_card(), + f"tokens_per_sec_per_card": self.get_tokens_per_sec_per_card(), + f"avg_tokens_per_sec_per_card": 
self.get_avg_tokens_per_sec_per_card(), + f"avg_efficient_tokens_per_sec_per_card": self.get_avg_efficient_tokens_per_sec_per_card(), + "memory_allocated_gb": alloc, + "max_memory_allocated_gb": max_alloc, + "memory_reserved_gb": reserved, + "max_memory_reserved_gb": max_reserved, + } + + return result + + +class BenchmarkCallback(TrainerCallback): + """ + used to benchmark the training process. + """ + + ACC_SAMPLES = "acc_global_samples" + ACC_TOKENS = "acc_global_tokens" + + def __init__(self, trainer, save_steps, skip_step, benchmark_mode): + super().__init__() + self.trainer = trainer + self.state = Statistical(save_steps, skip_step) + self.efficient_token_count = 0 + self.cur_tokens = 0 + self.cur_samples = 0 + + def set_save_time(self, save_time): + """ + set the time to save model. + """ + self.save_time = save_time + + def on_train_begin(self, args, state, control, **kwargs): + """ + record the start time of training. + """ + if state.trial_params is None: + state.trial_params = {} + + if self.ACC_SAMPLES not in state.trial_params: + state.trial_params[self.ACC_SAMPLES] = 0 + if self.ACC_TOKENS not in state.trial_params: + state.trial_params[self.ACC_TOKENS] = 0 + + if paddle.distributed.is_initialized(): + if hasattr(fleet.fleet, "_hcg"): + dp_degree = fleet.get_hybrid_communicate_group().get_data_parallel_world_size() + else: + dp_degree = paddle.distributed.get_world_size() + assert dp_degree <= 1, f"data_parallel_degree should be 1 but got {dp_degree}" + paddle.distributed.barrier() + self.end_save_time = time.time() + + def on_epoch_begin(self, args, state, control, **kwargs): + """ + record the start time of epoch. + """ + self.epoch_start = time.time() + self.batch_start = time.time() + + def on_substep_end(self, args, state, control, **kwargs): + """ + record the start time of batch. 
+        """
+        model = kwargs["model"]
+        batch_size, seq_length, _ = model.input_shape
+        self.efficient_token_count = model.efficient_token_count
+        self.cur_tokens += batch_size * seq_length
+        self.cur_samples += batch_size
+
+    def on_step_end(self, args, state, control, **kwargs):
+        """
+        accumulate the statistics of the finished step and reset per-step counters.
+        """
+        self.on_substep_end(args, state, control, **kwargs)
+
+        tokens, batches = self.state.add(
+            start_time=self.batch_start, efficient_tokens_num=self.efficient_token_count, tokens_num=self.cur_tokens, batch_size=self.cur_samples
+        )
+        state.trial_params[self.ACC_SAMPLES] += batches
+        state.trial_params[self.ACC_TOKENS] += tokens
+
+        self.efficient_token_count = 0
+        self.cur_tokens = 0
+        self.cur_samples = 0
+
+        self.batch_start = time.time()
+        if control.should_log:
+            self.maybe_log_save_evaluate_start = time.time()
+
+    def on_save(self, args, state, control, **kwargs):
+        """
+        record the information of saving the model.
+        """
+        pass
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        """
+        record the infomation of a logging step.
+        """
+        if benchmark_mode:
+            logs.update(self.state.get_result())
+            logs[self.ACC_SAMPLES] = state.trial_params[self.ACC_SAMPLES]
+            logs[self.ACC_TOKENS] = state.trial_params[self.ACC_TOKENS]
+
+            max_mem_reserved_msg = (
+                f"max_mem_reserved: {logs['max_memory_reserved_gb']} GB,"
+            )
+            max_mem_allocated_msg = (
+                f"max_mem_allocated: {logs['max_memory_allocated_gb']} GB"
+            )
+
+            logger.info(
+                "global step %d, loss: %.5f, interval_samples_per_second: %.5f, ips: %.5f, %s %s"
+                % (
+                    state.global_step,
+                    logs["loss"],
+                    logs["interval_samples_per_second"],
+                    logs["avg_efficient_tokens_per_sec_per_card"],
+                    max_mem_reserved_msg,
+                    max_mem_allocated_msg,
+                )
+            )
+
+    def _log(self, logs):
+        """
+        record the information accurately and neatly. 
+ """ + logs_str = [] + logs_str = [] + for k, v in logs.items(): + if isinstance(v, float): + if abs(v) < 1e-3: + v = f"{v:e}" + elif abs(v) > 100: + v = f"{v:.04f}" + else: + v = f"{v:.06f}" + logs_str.append(f"{k}: {v}") + logger.info(", ".join(logs_str)) \ No newline at end of file diff --git a/paddlemix/trainer/llava_trainer.py b/paddlemix/trainer/llava_trainer.py index b975a92e2..36793d00e 100644 --- a/paddlemix/trainer/llava_trainer.py +++ b/paddlemix/trainer/llava_trainer.py @@ -24,61 +24,9 @@ ) from paddlenlp.trainer.trainer import Trainer, has_length from paddlenlp.trainer.trainer_utils import ShardingOption -from paddlenlp.trainer.integrations import TrainerCallback from paddlenlp.trainer import PrinterCallback, ProgressCallback from paddlenlp.utils.log import logger - -class BenchmarkCallback(TrainerCallback): - def __init__(self, benchmark=False): - self.benchmark = benchmark - - def on_train_begin(self, args, state, control, **kwargs): - # assert args.gradient_accumulation_steps == 1 and not args.do_eval and not args.do_predict - if self.benchmark: - pass - - def on_epoch_begin(self, args, state, control, **kwargs): - if self.benchmark: - pass - - def on_step_begin(self, args, state, control, **kwargs): - if self.benchmark: - pass - - def on_step_end(self, args, state, control, **kwargs): - if self.benchmark: - pass - def on_log(self, args, state, control, logs=None, **kwargs): - if self.benchmark: - if logs is not None and "interval_samples_per_second" in logs: - max_mem_reserved_msg = "" - max_mem_allocated_msg = "" - - if paddle.device.is_compiled_with_cuda(): - max_mem_reserved_msg = ( - f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved() // (1024 ** 2)} MB," - ) - max_mem_allocated_msg = ( - f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated() // (1024 ** 2)} MB" - ) - - logger.info( - "global step %d, loss: %.5f, ips: %.5f, %s %s" - % ( - state.global_step, - logs["loss"], - logs["interval_samples_per_second"], - 
max_mem_reserved_msg, - max_mem_allocated_msg, - ) - ) - else: - logger.info(logs) - - def on_epoch_end(self, args, state, control, **kwargs): - if self.benchmark: - pass - +from paddlemix.trainer.benchmark_callback import BenchmarkCallback def split_to_even_chunks(indices, lengths, num_chunks): """ @@ -171,19 +119,21 @@ def __iter__(self): class LLaVATrainer(Trainer): - def __init__(self,**kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) if self.args.benchmark: - self.add_callback( - BenchmarkCallback( - benchmark=self.args.benchmark - ) - ) - if self.args.benchmark: - if self.args.disable_tqdm: - self.pop_callback(PrinterCallback) - else: - self.pop_callback(ProgressCallback) + # self.benchmark_callback = BenchmarkCallback(self, self.args.save_steps, skip_step=self.args.benchmark_skip_steps, self.args.benchmark) + self.benchmark_callback = BenchmarkCallback( + self, + self.args.max_steps if self.args.max_steps>1 else 1000, + 1, + self.args.benchmark) + self.add_callback(self.benchmark_callback) + if self.args.disable_tqdm: + self.pop_callback(PrinterCallback) + else: + self.pop_callback(ProgressCallback) + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: if self.train_dataset is None or not has_length(self.train_dataset): return None From f7391719fb279962d64cc7f4ff38f09293c5a9a9 Mon Sep 17 00:00:00 2001 From: pkhk-1 Date: Wed, 26 Mar 2025 06:02:40 +0000 Subject: [PATCH 2/4] add log --- paddlemix/trainer/benchmark_callback.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddlemix/trainer/benchmark_callback.py b/paddlemix/trainer/benchmark_callback.py index 6077480e2..af3bbc0cc 100644 --- a/paddlemix/trainer/benchmark_callback.py +++ b/paddlemix/trainer/benchmark_callback.py @@ -271,6 +271,7 @@ def __init__(self, trainer, save_steps, skip_step, benchmark_mode): self.efficient_token_count = 0 self.cur_tokens = 0 self.cur_samples = 0 + self.benchmark_mode = benchmark_mode def set_save_time(self, save_time): 
""" @@ -346,7 +347,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): """ record the infomation of a logging step. """ - if benchmark_mode: + if self.benchmark_mode: logs.update(self.state.get_result()) logs[self.ACC_SAMPLES] = state.trial_params[self.ACC_SAMPLES] logs[self.ACC_TOKENS] = state.trial_params[self.ACC_TOKENS] @@ -359,12 +360,15 @@ def on_log(self, args, state, control, logs=None, **kwargs): ) logger.info( - "global step %d, loss: %.5f, interval_samples_per_second: %.5f, ips: %.5f, %s %s" + "global step %d, loss: %.5f, interval_samples_per_second: %.5f, ips: %.5f, efficient_tokens_per_sec_per_card: %.5f, avg_tokens_per_sec_per_card: %.5f, tokens_per_sec_per_card: %.5f, %s %s" % ( state.global_step, logs["loss"], logs["interval_samples_per_second"], logs["avg_efficient_tokens_per_sec_per_card"], + logs["efficient_tokens_per_sec_per_card"], + logs["avg_tokens_per_sec_per_card"], + logs["tokens_per_sec_per_card"], max_mem_reserved_msg, max_mem_allocated_msg, ) From 42935f3fca6eee04380674da2d1e18d5cb356057 Mon Sep 17 00:00:00 2001 From: pkhk-1 Date: Wed, 26 Mar 2025 06:40:51 +0000 Subject: [PATCH 3/4] add some log --- paddlemix/trainer/benchmark_callback.py | 19 +------------------ paddlemix/trainer/llava_trainer.py | 2 +- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/paddlemix/trainer/benchmark_callback.py b/paddlemix/trainer/benchmark_callback.py index af3bbc0cc..dce11bbd7 100644 --- a/paddlemix/trainer/benchmark_callback.py +++ b/paddlemix/trainer/benchmark_callback.py @@ -347,7 +347,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): """ record the infomation of a logging step. 
""" - if self.benchmark_mode: + if self.benchmark_mode and "loss" in logs: logs.update(self.state.get_result()) logs[self.ACC_SAMPLES] = state.trial_params[self.ACC_SAMPLES] logs[self.ACC_TOKENS] = state.trial_params[self.ACC_TOKENS] @@ -373,20 +373,3 @@ def on_log(self, args, state, control, logs=None, **kwargs): max_mem_allocated_msg, ) ) - - def _log(self, logs): - """ - record the information accurately and neatly. - """ - logs_str = [] - logs_str = [] - for k, v in logs.items(): - if isinstance(v, float): - if abs(v) < 1e-3: - v = f"{v:e}" - elif abs(v) > 100: - v = f"{v:.04f}" - else: - v = f"{v:.06f}" - logs_str.append(f"{k}: {v}") - logger.info(", ".join(logs_str)) \ No newline at end of file diff --git a/paddlemix/trainer/llava_trainer.py b/paddlemix/trainer/llava_trainer.py index 36793d00e..154035776 100644 --- a/paddlemix/trainer/llava_trainer.py +++ b/paddlemix/trainer/llava_trainer.py @@ -125,7 +125,7 @@ def __init__(self, **kwargs): # self.benchmark_callback = BenchmarkCallback(self, self.args.save_steps, skip_step=self.args.benchmark_skip_steps, self.args.benchmark) self.benchmark_callback = BenchmarkCallback( self, - self.args.max_steps if self.args.max_steps>1 else 1000, + self.args.max_steps if self.args.max_steps>1 else 10000, 1, self.args.benchmark) self.add_callback(self.benchmark_callback) From f328318708168d4e1ad021504683d740bbfcf25f Mon Sep 17 00:00:00 2001 From: pkhk-1 Date: Wed, 26 Mar 2025 07:13:26 +0000 Subject: [PATCH 4/4] fix log output --- paddlemix/examples/llava/pretrain.py | 23 ++--------------- .../examples/llava/supervised_finetune.py | 25 +++---------------- 2 files changed, 5 insertions(+), 43 deletions(-) diff --git a/paddlemix/examples/llava/pretrain.py b/paddlemix/examples/llava/pretrain.py index f0f27f548..58a3aff0b 100755 --- a/paddlemix/examples/llava/pretrain.py +++ b/paddlemix/examples/llava/pretrain.py @@ -165,29 +165,10 @@ def main(): checkpoint = last_checkpoint train_result = 
trainer.train(resume_from_checkpoint=checkpoint) if training_args.benchmark: - - def get_paddle_memory_info(): - """get_memory_info""" - divisor = 2**30 - return ( - paddle.device.cuda.memory_allocated() / divisor, - paddle.device.cuda.max_memory_allocated() / divisor, - paddle.device.cuda.memory_reserved() / divisor, - paddle.device.cuda.max_memory_reserved() / divisor, - ) - - memory_allocated, max_memory_allocated, memory_reserved, max_memory_reserved = get_paddle_memory_info() - - logger.info( - f"memory_allocated:{memory_allocated}GB, max_memory_allocated: {max_memory_allocated}GB, memory_reserved:{memory_reserved}GB, max_memory_reserved: {max_memory_reserved}GB \n" - ) total_effective_samples = total_samples * training_args.num_train_epochs effective_samples_per_second = total_effective_samples / train_result.metrics["train_runtime"] - mem_gpu = ( - train_result.metrics["train_mem_gpu_peaked_delta"] + train_result.metrics["train_mem_gpu_alloc_delta"] - ) - logger.info(f"ips: {effective_samples_per_second} ") - logger.info(f"train_mem_gpu_peaked: {int(mem_gpu/ (2**20))} MB") + logger.info(f"Effective_samples_per_second: {effective_samples_per_second} ") + logger.info(f"Train_runtime: {train_result.metrics['train_runtime']}") logger.info("Benchmark done.") else: trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) diff --git a/paddlemix/examples/llava/supervised_finetune.py b/paddlemix/examples/llava/supervised_finetune.py index c9e016566..d40511c09 100644 --- a/paddlemix/examples/llava/supervised_finetune.py +++ b/paddlemix/examples/llava/supervised_finetune.py @@ -190,30 +190,11 @@ def main(): checkpoint = last_checkpoint train_result = trainer.train(resume_from_checkpoint=checkpoint) if training_args.benchmark: - - def get_paddle_memory_info(): - """get_memory_info""" - divisor = 2**30 - return ( - paddle.device.cuda.memory_allocated() / divisor, - paddle.device.cuda.max_memory_allocated() / divisor, - 
paddle.device.cuda.memory_reserved() / divisor, - paddle.device.cuda.max_memory_reserved() / divisor, - ) - - memory_allocated, max_memory_allocated, memory_reserved, max_memory_reserved = get_paddle_memory_info() - - logger.info( - f"memory_allocated:{memory_allocated}GB, max_memory_allocated: {max_memory_allocated}GB, memory_reserved:{memory_reserved}GB, max_memory_reserved: {max_memory_reserved}GB \n" - ) - total_effective_samples = total_samples * training_args.num_train_epochs effective_samples_per_second = total_effective_samples / train_result.metrics["train_runtime"] - mem_gpu = ( - train_result.metrics["train_mem_gpu_peaked_delta"] + train_result.metrics["train_mem_gpu_alloc_delta"] - ) - logger.info(f"ips: {effective_samples_per_second} ") - logger.info(f"train_mem_gpu_peaked: {int(mem_gpu/ (2**20))} MB") + + logger.info(f"Effective_samples_per_second: {effective_samples_per_second} ") + logger.info(f"Train_runtime: {train_result.metrics['train_runtime']}") logger.info("Benchmark done.") else: trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)