Commit 126bdd3

refactor

committed · 1 parent 361557c · commit 126bdd3

6 files changed: +128 -310 lines changed

paddlenlp/transformers/gemma/modeling.py

Lines changed: 17 additions & 62 deletions

@@ -55,6 +55,7 @@
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
+from ..utils import caculate_llm_flops
 from .configuration import (
     GEMMA_PRETRAINED_INIT_CONFIGURATION,
     GEMMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1075,82 +1076,36 @@ def __init__(self, config: GemmaConfig):
         self.gradient_checkpointing = False

     def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
-        # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-        hidden_size = self.config.hidden_size
-        intermediate_size = self.config.intermediate_size
-        layer_num = self.config.num_hidden_layers
-        vocab_size = self.config.vocab_size
-
         if seq_length is None:
             if hasattr(self.config, "seq_length"):
                 seq_length = self.config.seq_length
             else:
                 seq_length = 2048

-        flops_per_transformer = 0
-
-        # qkvo matmul
-        flops_per_transformer += seq_length * hidden_size**2 * 4
-        # [b,s,h] [b,h,s] bs^2h
-        # [b,s,s] [b,s,h] bs^2h
-        # q_states * k_states + attn_weight * v_states
-        flops_per_transformer += seq_length**2 * hidden_size * 2
-        # swiglu, matmul + dot
-        flops_per_transformer += seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
-
-        # final loggits
-        flops_loggits = seq_length * hidden_size * vocab_size
-
-        # 2 for mul + add in matmul
-        # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y
-        # so, here got 6=2*(1+2)
-        return 6 * batch_size * (layer_num * flops_per_transformer + flops_loggits)
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )

     def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
-        # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-        hidden_size = self.config.hidden_size
-        intermediate_size = self.config.intermediate_size
-        layer_num = self.config.num_hidden_layers
-        vocab_size = self.config.vocab_size
-
         if seq_length is None:
             if hasattr(self.config, "seq_length"):
                 seq_length = self.config.seq_length
             else:
                 seq_length = 2048

-        flops_per_transformer = 0
-        flops_recompute_transformer = 0
-
-        # qkvo matmul
-        flops_qkvo_matmul = seq_length * hidden_size**2 * 4
-
-        # [b,s,h] [b,h,s] bs^2h
-        # [b,s,s] [b,s,h] bs^2h
-        # q_states * k_states + attn_weight * v_states
-        flops_core_attn = seq_length**2 * hidden_size * 2
-
-        # swiglu, matmul + dot
-        flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
-
-        flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn
-        if recompute:
-            if self.config.recompute_granularity == "full":
-                flops_recompute_transformer = flops_per_transformer
-            if self.config.recompute_granularity == "full_attn":
-                flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
-            if self.config.recompute_granularity == "core_attn":
-                flops_recompute_transformer = flops_core_attn
-
-        # final loggits
-        flops_loggits = seq_length * hidden_size * vocab_size
-
-        # 2 for mul + add in matmul
-        # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y
-        return (
-            2
-            * batch_size
-            * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits)
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
         )

     def get_input_embeddings(self):

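The files in this commit all make the same change: the duplicated FLOPs math is lifted into a shared caculate_llm_flops helper imported from paddlenlp/transformers/utils.py, whose body is not part of this diff. The sketch below is only a reconstruction from the two deleted method bodies, so the exact implementation (and the batch_size default, which the new call sites no longer pass) is an assumption. With recompute=False the hardware formula 2 * b * (L * 3f + 3 * logits) collapses to the model formula 6 * b * (L * f + logits), which is why a single helper can back both get_model_flops and get_hardware_flops.

def caculate_llm_flops(
    hidden_size,
    intermediate_size,
    layer_num,
    vocab_size,
    seq_length,
    recompute=False,
    recompute_granularity=None,
    batch_size=1,  # assumption: the new call sites rely on a per-sample default
):
    # TFLOPs formula (Equation 3, Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf)
    # qkvo projections: four [s, h] x [h, h] matmuls
    flops_qkvo_matmul = seq_length * hidden_size**2 * 4
    # core attention: q @ k^T and attn_weights @ v, each s^2 * h
    flops_core_attn = seq_length**2 * hidden_size * 2
    # swiglu FFN: gate/up/down matmuls plus the elementwise product
    flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
    flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn

    # extra forward work re-done under recomputation, depending on granularity
    flops_recompute_transformer = 0
    if recompute:
        if recompute_granularity == "full":
            flops_recompute_transformer = flops_per_transformer
        elif recompute_granularity == "full_attn":
            flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
        elif recompute_granularity == "core_attn":
            flops_recompute_transformer = flops_core_attn

    # final projection onto the vocabulary (logits)
    flops_logits = seq_length * hidden_size * vocab_size

    # 2 for the mul + add in every matmul; the forward pass counts once and the
    # backward pass twice (gradients w.r.t. both matmul inputs), i.e. 3 passes
    # per layer, plus whatever forward work is repeated by recomputation
    return (
        2
        * batch_size
        * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_logits)
    )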
paddlenlp/transformers/gpt/modeling.py

Lines changed: 17 additions & 62 deletions

@@ -53,6 +53,7 @@
     TokenClassifierOutput,
 )
 from ..model_utils import dy2st_nocheck_guard_context
+from ..utils import caculate_llm_flops
 from .configuration import (
     GPT_PRETRAINED_INIT_CONFIGURATION,
     GPT_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1106,82 +1107,36 @@ def __init__(self, config: GPTConfig):
         )

     def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
-        # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-        hidden_size = self.config.hidden_size
-        intermediate_size = self.config.intermediate_size
-        layer_num = self.config.num_hidden_layers
-        vocab_size = self.config.vocab_size
-
         if seq_length is None:
             if hasattr(self.config, "seq_length"):
                 seq_length = self.config.seq_length
             else:
                 seq_length = 2048

-        flops_per_transformer = 0
-
-        # qkvo matmul
-        flops_per_transformer += seq_length * hidden_size**2 * 4
-        # [b,s,h] [b,h,s] bs^2h
-        # [b,s,s] [b,s,h] bs^2h
-        # q_states * k_states + attn_weight * v_states
-        flops_per_transformer += seq_length**2 * hidden_size * 2
-        # swiglu, matmul + dot
-        flops_per_transformer += seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
-
-        # final loggits
-        flops_loggits = seq_length * hidden_size * vocab_size
-
-        # 2 for mul + add in matmul
-        # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y
-        # so, here got 6=2*(1+2)
-        return 6 * batch_size * (layer_num * flops_per_transformer + flops_loggits)
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )

     def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
-        # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-        hidden_size = self.config.hidden_size
-        intermediate_size = self.config.intermediate_size
-        layer_num = self.config.num_hidden_layers
-        vocab_size = self.config.vocab_size
-
         if seq_length is None:
             if hasattr(self.config, "seq_length"):
                 seq_length = self.config.seq_length
             else:
                 seq_length = 2048

-        flops_per_transformer = 0
-        flops_recompute_transformer = 0
-
-        # qkvo matmul
-        flops_qkvo_matmul = seq_length * hidden_size**2 * 4
-
-        # [b,s,h] [b,h,s] bs^2h
-        # [b,s,s] [b,s,h] bs^2h
-        # q_states * k_states + attn_weight * v_states
-        flops_core_attn = seq_length**2 * hidden_size * 2
-
-        # swiglu, matmul + dot
-        flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
-
-        flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn
-        if recompute:
-            if self.config.recompute_granularity == "full":
-                flops_recompute_transformer = flops_per_transformer
-            if self.config.recompute_granularity == "full_attn":
-                flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
-            if self.config.recompute_granularity == "core_attn":
-                flops_recompute_transformer = flops_core_attn
-
-        # final loggits
-        flops_loggits = seq_length * hidden_size * vocab_size
-
-        # 2 for mul + add in matmul
-        # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y
-        return (
-            2
-            * batch_size
-            * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits)
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
         )

     def get_input_embeddings(self):

paddlenlp/transformers/llama/modeling.py

Lines changed: 17 additions & 62 deletions

@@ -70,6 +70,7 @@ def swiglu(x, y=None):
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
+from ..utils import caculate_llm_flops
 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
     LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1469,82 +1470,36 @@ def __init__(self, config: LlamaConfig):
         self.gradient_checkpointing = False

     def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
-        # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-        hidden_size = self.config.hidden_size
-        intermediate_size = self.config.intermediate_size
-        layer_num = self.config.num_hidden_layers
-        vocab_size = self.config.vocab_size
-
         if seq_length is None:
             if hasattr(self.config, "seq_length"):
                 seq_length = self.config.seq_length
             else:
                 seq_length = 2048

-        flops_per_transformer = 0
-
-        # qkvo matmul
-        flops_per_transformer += seq_length * hidden_size**2 * 4
-        # [b,s,h] [b,h,s] bs^2h
-        # [b,s,s] [b,s,h] bs^2h
-        # q_states * k_states + attn_weight * v_states
-        flops_per_transformer += seq_length**2 * hidden_size * 2
-        # swiglu, matmul + dot
-        flops_per_transformer += seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
-
-        # final loggits
-        flops_loggits = seq_length * hidden_size * vocab_size
-
-        # 2 for mul + add in matmul
-        # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y
-        # so, here got 6=2*(1+2)
-        return 6 * batch_size * (layer_num * flops_per_transformer + flops_loggits)
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )

     def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
-        # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-        hidden_size = self.config.hidden_size
-        intermediate_size = self.config.intermediate_size
-        layer_num = self.config.num_hidden_layers
-        vocab_size = self.config.vocab_size
-
         if seq_length is None:
             if hasattr(self.config, "seq_length"):
                 seq_length = self.config.seq_length
             else:
                 seq_length = 2048

-        flops_per_transformer = 0
-        flops_recompute_transformer = 0
-
-        # qkvo matmul
-        flops_qkvo_matmul = seq_length * hidden_size**2 * 4
-
-        # [b,s,h] [b,h,s] bs^2h
-        # [b,s,s] [b,s,h] bs^2h
-        # q_states * k_states + attn_weight * v_states
-        flops_core_attn = seq_length**2 * hidden_size * 2
-
-        # swiglu, matmul + dot
-        flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
-
-        flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn
-        if recompute:
-            if self.config.recompute_granularity == "full":
-                flops_recompute_transformer = flops_per_transformer
-            if self.config.recompute_granularity == "full_attn":
-                flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
-            if self.config.recompute_granularity == "core_attn":
-                flops_recompute_transformer = flops_core_attn
-
-        # final loggits
-        flops_loggits = seq_length * hidden_size * vocab_size
-
-        # 2 for mul + add in matmul
-        # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y
-        return (
-            2
-            * batch_size
-            * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits)
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
         )

     def get_input_embeddings(self):

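For illustration only, and assuming the helper sketched above (the configuration values here are hypothetical and not taken from this commit), the refactored call sites reduce to something like:

# Hypothetical Llama-7B-like configuration, for illustration only.
model_flops = caculate_llm_flops(
    hidden_size=4096,
    intermediate_size=11008,
    layer_num=32,
    vocab_size=32000,
    seq_length=2048,
    recompute=False,
)

# With recomputation enabled, one extra forward pass per layer is counted at
# the configured granularity ("full", "full_attn", or "core_attn").
hardware_flops = caculate_llm_flops(
    hidden_size=4096,
    intermediate_size=11008,
    layer_num=32,
    vocab_size=32000,
    seq_length=2048,
    recompute=True,
    recompute_granularity="full",
)

print(f"model FLOPs:    {model_flops / 1e12:.1f} TFLOPs per sample")
print(f"hardware FLOPs: {hardware_flops / 1e12:.1f} TFLOPs per sample")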