
Commit 64f00f0

[LLM] Reconstruct fused transformer layers (#7186)
* reconstruct fused_transformer_layers
* delete the original class
* refine code
1 parent d0c85df commit 64f00f0
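
The commit changes how the fused transformer block is constructed in the inference models: instead of passing every hyperparameter and per-layer ParamAttr directly to FusedMultiTransformer, the arguments are first collected into a FusedMultiTransformerConfig, which is then handed to FusedMultiTransformerBase. A schematic before/after sketch of the pattern, based only on the hunks shown below (argument lists abridged with "..."):

# Before: the fused block was constructed directly from its arguments.
self.transformer_block = FusedMultiTransformer(
    self.embed_dim,
    self.n_head,
    4 * self.embed_dim,
    ...,  # remaining weight/bias attrs elided; see the full argument list in the hunks below
)

# After: the same arguments populate a config object, and the block is built from it.
transformer_config = FusedMultiTransformerConfig(
    self.embed_dim,
    self.n_head,
    4 * self.embed_dim,
    ...,  # same arguments as before
)
self.transformer_block = FusedMultiTransformerBase(transformer_config)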

File tree: 7 files changed (+545, -349 lines)

paddlenlp/experimental/transformers/bloom/modeling.py

Lines changed: 6 additions & 2 deletions
@@ -21,7 +21,8 @@
 from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
-    FusedMultiTransformer,
+    FusedMultiTransformerBase,
+    FusedMultiTransformerConfig,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
@@ -112,7 +113,8 @@ def __init__(self, config):
         ffn1_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(config.n_layer)]
         ffn2_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(config.n_layer)]
         ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(config.n_layer)]
-        self.transformer_block = FusedMultiTransformer(
+
+        transformer_config = FusedMultiTransformerConfig(
             self.embed_dim,
             self.n_head,
             4 * self.embed_dim,
@@ -133,6 +135,8 @@ def __init__(self, config):
             ffn2_weight_attrs=ffn2_weight_attrs,
             ffn2_bias_attrs=ffn2_bias_attrs,
         )
+
+        self.transformer_block = FusedMultiTransformerBase(transformer_config)
         self.cache_kvs = []
 
         # Final Layer Norm

paddlenlp/experimental/transformers/chatglm/modeling.py

Lines changed: 5 additions & 2 deletions
@@ -20,7 +20,8 @@
 from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
-    FusedMultiTransformer,
+    FusedMultiTransformerBase,
+    FusedMultiTransformerConfig,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
@@ -183,7 +184,8 @@ def __init__(self, config: ChatGLMConfig):
         ]
         ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(config.num_layers)]
         alpha = (2 * self.config.num_hidden_layers) ** 0.5
-        self.transformer_block = FusedMultiTransformer(
+
+        transformer_config = FusedMultiTransformerConfig(
             config.hidden_size,
             config.num_attention_heads,
             4 * config.hidden_size,
@@ -209,6 +211,7 @@ def __init__(self, config: ChatGLMConfig):
             norm_type="layernorm",
             use_neox_rotary_style=True,
         )
+        self.transformer_block = FusedMultiTransformerBase(transformer_config)
 
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
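
Consolidating the ChatGLM hunks above, the new construction path reads roughly as follows; this is a schematic view of the change only, with every argument not visible in this excerpt elided:

# Hyperparameters, per-layer ParamAttrs, and model-specific options such as
# norm_type and use_neox_rotary_style now go into the config object.
transformer_config = FusedMultiTransformerConfig(
    config.hidden_size,
    config.num_attention_heads,
    4 * config.hidden_size,
    ...,  # weight/bias attrs and other options elided; see the hunks above
    norm_type="layernorm",
    use_neox_rotary_style=True,
)
# The fused block is then built from the config instead of from raw arguments.
self.transformer_block = FusedMultiTransformerBase(transformer_config)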
