Commit 5c1779c

[Feature] Add hardware flops for pretraining (#9069)
* fix hardware tflops.
* Support mfu for pretraining.

1 parent: e2f4c33

File tree

9 files changed: 240 additions & 1 deletion

paddlenlp/trainer/trainer.py

Lines changed: 7 additions & 0 deletions
@@ -1354,15 +1354,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             )
             num_steps = self.state.global_step - self._globalstep_last_logged
             seq_length = None
+            model_flops = None
             if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
                 seq_length = getattr(self.model.config, "seq_length", None)
+                try:
+                    model_flops = self.model.get_hardware_flops(seq_length=seq_length, recompute=self.args.recompute)
+                except NotImplementedError:
+                    model_flops = None
+
             logs.update(
                 speed_metrics(
                     "interval",
                     self._globalstep_last_start_time,
                     num_samples=total_train_batch_size * num_steps,
                     num_steps=num_steps,
                     seq_length=seq_length,
+                    model_flops=model_flops,
                 )
             )
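
With this change, pretraining runs log an interval_hardware_tflops_per_device entry alongside the existing speed metrics whenever the model implements get_hardware_flops; if it does not, the NotImplementedError is swallowed and the metric is simply omitted. A small post-processing sketch, not part of this commit, for turning that entry into an MFU estimate (the mfu_from_logs helper and the peak-throughput table are illustrative):

# Peak dense BF16 throughput per device, in TFLOPS; adjust for your hardware.
PEAK_TFLOPS = {"A100": 312.0, "H100": 989.0}

def mfu_from_logs(logs: dict, device: str = "A100") -> float:
    """MFU = achieved FLOPs per second on one device / peak FLOPs per second of that device."""
    achieved = logs["interval_hardware_tflops_per_device"]
    # speed_metrics divides by 2**40, so convert back to raw FLOPs/s before comparing
    # against the 1e12-based peak numbers.
    return achieved * 2**40 / 1e12 / PEAK_TFLOPS[device]

# Example: {"interval_hardware_tflops_per_device": 170.0} on an A100 gives roughly 0.60 MFU.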

paddlenlp/trainer/trainer_utils.py

Lines changed: 6 additions & 1 deletion
@@ -344,7 +344,7 @@ def total_processes_number(local_rank):
     return 1


-def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None):
+def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None, model_flops=None):
     """
     Measure and return speed performance metrics.

@@ -365,6 +365,11 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_lengt
     if seq_length is not None:
         tokens_per_second_per_device = samples_per_second * seq_length / paddle.distributed.get_world_size()
         result[f"{split}_tokens_per_second_per_device"] = round(tokens_per_second_per_device, 4)
+        if model_flops is not None:
+            result[f"{split}_hardware_tflops_per_device"] = round(
+                tokens_per_second_per_device * model_flops / seq_length / 2**40, 2
+            )
+
     if num_steps is not None:
         steps_per_second = num_steps / runtime
         result[f"{split}_steps_per_second"] = round(steps_per_second, 4)
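
The new metric is FLOPs per token (model_flops / seq_length) multiplied by the per-device token throughput, then scaled by 2**40. A standalone sketch with made-up numbers, just to make the units concrete:

# model_flops is the full-step (forward + backward) FLOPs for one sequence of seq_length tokens.
model_flops = 4.4e13
seq_length = 4096
tokens_per_second_per_device = 12000.0

flops_per_token = model_flops / seq_length                    # ~1.07e10 FLOPs per token
hardware_flops_per_second = tokens_per_second_per_device * flops_per_token
print(round(hardware_flops_per_second / 2**40, 2))            # ~117.24, logged as *_hardware_tflops_per_device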

paddlenlp/transformers/gemma/modeling.py

Lines changed: 34 additions & 0 deletions
@@ -55,6 +55,7 @@
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
+from ..utils import caculate_llm_flops
 from .configuration import (
     GEMMA_PRETRAINED_INIT_CONFIGURATION,
     GEMMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1074,6 +1075,39 @@ def __init__(self, config: GemmaConfig):

         self.gradient_checkpointing = False

+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embed_tokens

paddlenlp/transformers/gpt/modeling.py

Lines changed: 34 additions & 0 deletions
@@ -53,6 +53,7 @@
     TokenClassifierOutput,
 )
 from ..model_utils import dy2st_nocheck_guard_context
+from ..utils import caculate_llm_flops
 from .configuration import (
     GPT_PRETRAINED_INIT_CONFIGURATION,
     GPT_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1105,6 +1106,39 @@ def __init__(self, config: GPTConfig):
             decoder_layers,
         )

+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings

paddlenlp/transformers/llama/modeling.py

Lines changed: 34 additions & 0 deletions
@@ -70,6 +70,7 @@ def swiglu(x, y=None):
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
+from ..utils import caculate_llm_flops
 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
     LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1468,6 +1469,39 @@ def __init__(self, config: LlamaConfig):

         self.gradient_checkpointing = False

+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embed_tokens
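
Each decoder in this commit gains the same two methods, so both counts can be read directly off a base model instance. A usage sketch with a deliberately tiny, illustrative config; it assumes LlamaConfig and LlamaModel are importable from paddlenlp.transformers and that the config carries its default recompute_granularity:

from paddlenlp.transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=4,
    num_attention_heads=8,
    vocab_size=32000,
)
model = LlamaModel(config)

# Theoretical FLOPs of one forward + backward pass over a single 4096-token sequence.
theoretical = model.get_model_flops(seq_length=4096)
# Hardware FLOPs additionally count the forward work replayed under recomputation,
# so this is never smaller than the theoretical count.
hardware = model.get_hardware_flops(seq_length=4096, recompute=True)
assert hardware >= theoretical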

paddlenlp/transformers/model_utils.py

Lines changed: 14 additions & 0 deletions
@@ -1102,6 +1102,20 @@ def get_memory_footprint(self, return_buffers=True):
             mem = mem + mem_bufs
         return mem

+    def get_model_flops(self, *args, **kwargs):
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_model_flops()
+
+        raise NotImplementedError(f"model of {type(base_model)} has not implemented the `get_model_flops`")
+
+    def get_hardware_flops(self, *args, **kwargs):
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_hardware_flops()
+
+        raise NotImplementedError(f"model of {type(base_model)} has not implemented the `get_hardware_flops`")
+
     def get_input_embeddings(self) -> nn.Embedding:
         """get input embedding of model
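
These base-class defaults give every head model (a ForCausalLM wrapper, for example) the same contract: forward the call to the decoder named by base_model_prefix, or raise NotImplementedError, which the trainer catches and treats as "skip hardware TFLOPS logging". Note that the forwarding call passes no arguments, so an invocation through the wrapper uses the base model's own defaults for seq_length and recompute. A standalone mimic of the two code paths (the Fake* classes are made up for illustration and are not PaddleNLP code):

class FakePretrainedModel:
    base_model_prefix = "decoder"

    def get_model_flops(self, *args, **kwargs):
        # Same dispatch as above: use the decoder if present, otherwise complain.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_model_flops()
        raise NotImplementedError(f"model of {type(base_model)} has not implemented the `get_model_flops`")

class FakeDecoder:
    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
        return 123  # a real decoder calls caculate_llm_flops here

wrapper = FakePretrainedModel()
wrapper.decoder = FakeDecoder()   # head models hold the decoder under base_model_prefix
print(wrapper.get_model_flops())  # -> 123, forwarded to the decoder

bare = FakePretrainedModel()      # no decoder attribute: getattr falls back to self
try:
    bare.get_model_flops()
except NotImplementedError:
    print("the trainer catches this and leaves model_flops as None")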

paddlenlp/transformers/qwen/modeling.py

Lines changed: 34 additions & 0 deletions
@@ -49,6 +49,7 @@ def swiglu(x, y=None):
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..model_outputs import ModelOutput
+from ..utils import caculate_llm_flops
 from .configuration import QWenConfig

 try:
@@ -690,6 +691,39 @@ def __init__(self, config):
         )
         self.ln_f = QWenRMSNorm(config)

+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.wte

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 34 additions & 0 deletions
@@ -44,6 +44,7 @@
     TokenClassifierOutput,
 )
 from ..model_utils import PretrainedModel, register_base_model
+from ..utils import caculate_llm_flops
 from .configuration import Qwen2Config

 try:
@@ -914,6 +915,39 @@ def __init__(self, config: Qwen2Config):
         )
         self.norm = Qwen2RMSNorm(config)

+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embed_tokens

paddlenlp/transformers/utils.py

Lines changed: 43 additions & 0 deletions
@@ -958,3 +958,46 @@ def __repr__(self):
         if self.err_buf:
             msg += f"stderr: {self.err}\n"
         return msg
+
+
+def caculate_llm_flops(
+    hidden_size,
+    intermediate_size,
+    layer_num,
+    vocab_size,
+    batch_size=1,
+    seq_length=None,
+    recompute=False,
+    recompute_granularity=None,
+):
+
+    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
+    flops_per_transformer = 0
+    flops_recompute_transformer = 0
+
+    # qkvo matmul
+    flops_qkvo_matmul = seq_length * hidden_size**2 * 4
+
+    # [b,s,h] [b,h,s] bs^2h
+    # [b,s,s] [b,s,h] bs^2h
+    # q_states * k_states + attn_weight * v_states
+    flops_core_attn = seq_length**2 * hidden_size * 2
+
+    # swiglu, matmul + dot
+    flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
+
+    flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn
+    if recompute:
+        if recompute_granularity == "full":
+            flops_recompute_transformer = flops_per_transformer
+        if recompute_granularity == "full_attn":
+            flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
+        if recompute_granularity == "core_attn":
+            flops_recompute_transformer = flops_core_attn
+
+    # final logits
+    flops_loggits = seq_length * hidden_size * vocab_size
+
+    # 2 for mul + add in matmul
+    # 1 for forward, 2 for backwards since we calculate gradients for input_x and input_y
+    return 2 * batch_size * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits)
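
The returned value is the forward-plus-backward FLOPs for batch_size sequences of seq_length tokens: the leading 2 converts multiply-accumulates into FLOPs, the factor 3 on flops_per_transformer and on flops_loggits covers one forward plus two backward matmul passes, and flops_recompute_transformer adds the extra forward pass replayed under recomputation. A worked example with LLaMA-7B-like dimensions (the dimensions are illustrative):

from paddlenlp.transformers.utils import caculate_llm_flops

flops = caculate_llm_flops(
    hidden_size=4096,
    intermediate_size=11008,
    layer_num=32,
    vocab_size=32000,
    seq_length=4096,
    recompute=False,
)
print(f"{flops / 4096 / 1e9:.1f} GFLOPs per token")  # ~46.1 for these dimensions

# speed_metrics divides this by seq_length to get FLOPs per token, multiplies by the measured
# tokens/s per device, and reports the result divided by 2**40 as hardware TFLOPS.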
