
Commit ca1a475

fix norm outputs in dynamic and static mode

1 parent 14ed2a2
6 files changed: +10 −28 lines

paddlenlp/experimental/transformers/fused_transformer_layers.py
Lines changed: 3 additions & 5 deletions

@@ -741,7 +741,7 @@ def get_weight_create_dype(self):
 
     def compute_layernorm_before_qkv(self, src, i):
         if i == 0:
-            ln_out = self.norm_func(src, self.ln_scales[i], self.ln_biases[i], self._epsilon, begin_norm_axis=1)
+            ln_out = self.norm_func(src, self.ln_scales[i], self.ln_biases[i], self._epsilon, begin_norm_axis=1)[0]
         else:
             ln_out = src
 
@@ -1918,7 +1918,7 @@ def compute_layernorm_before_qkv(self, src, i):
                 quant_round_type=self.quant_round_type,
                 quant_max_bound=self.quant_max_bound,
                 quant_min_bound=self.quant_min_bound,
-            )
+            )[0]
         else:
             ln_out = src
 
@@ -2617,9 +2617,7 @@ def compute_layernorm_before_qkv(self, src, i):
                 quant_round_type=1,
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
-            )
-            if in_dynamic_mode():
-                ln_out = ln_out[0]
+            )[0]
         else:
             ln_out = src
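
The third hunk above is the crux of the commit: the fused norm ops used to return a bare tensor in dynamic mode and a tuple in static mode, so call sites either branched on in_dynamic_mode() (here) or on isinstance(result, tuple) (the modeling files below). With the op returning a tuple in both modes, an unconditional [0] suffices. A minimal sketch of the before/after pattern, assuming a CUDA build of Paddle where fused_rms_norm returns a tuple whose first element is the normalized tensor (shapes are illustrative):

import paddle
import paddle.incubate.nn.functional as F

x = paddle.randn([8, 64], dtype="float16")   # (tokens, hidden); illustrative shape
weight = paddle.ones([64], dtype="float16")  # RMSNorm scale

# Old call sites had to branch on the return type:
#     result = F.fused_rms_norm(x, weight, None, 1e-6, begin_norm_axis=1)
#     out = result[0] if isinstance(result, tuple) else result
# New call sites index the tuple unconditionally:
out = F.fused_rms_norm(x, weight, None, 1e-6, begin_norm_axis=1)[0]
print(out.shape)  # [8, 64]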

paddlenlp/experimental/transformers/llama/modeling.py
Lines changed: 2 additions & 5 deletions

@@ -99,12 +99,9 @@ def __init__(self, config):
         self.config = config
 
     def forward(self, hidden_states):
-        result = paddle.incubate.nn.functional.fused_rms_norm(
+        return paddle.incubate.nn.functional.fused_rms_norm(
             hidden_states, self.weight, None, self.variance_epsilon, begin_norm_axis=1
-        )
-        if isinstance(result, tuple):
-            return result[0]
-        return result
+        )[0]
 
 
 class LLamaAvxLMHead(nn.Layer):

paddlenlp/experimental/transformers/mixtral/modeling.py
Lines changed: 2 additions & 5 deletions

@@ -77,12 +77,9 @@ def __init__(self, config):
         self.config = config
 
     def forward(self, hidden_states):
-        result = paddle.incubate.nn.functional.fused_rms_norm(
+        return paddle.incubate.nn.functional.fused_rms_norm(
             hidden_states, self.weight, None, self.variance_epsilon, begin_norm_axis=1
-        )
-        if isinstance(result, tuple):
-            return result[0]
-        return result
+        )[0]
 
 
 @register_base_model

paddlenlp/experimental/transformers/qwen/modeling.py
Lines changed: 1 addition & 4 deletions

@@ -53,10 +53,7 @@ def __init__(self, config):
         )
 
     def forward(self, x):
-        result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
-        if isinstance(result, tuple):
-            return result[0]
-        return result
+        return paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]
 
 
 @register_base_model

paddlenlp/experimental/transformers/qwen2/modeling.py
Lines changed: 1 addition & 4 deletions

@@ -78,10 +78,7 @@ def __init__(self, config):
         )
 
     def forward(self, x):
-        result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
-        if isinstance(result, tuple):
-            return result[0]
-        return result
+        return paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]
 
 
 @register_base_model

paddlenlp/experimental/transformers/qwen2_moe/modeling.py
Lines changed: 1 addition & 5 deletions

@@ -63,12 +63,8 @@ def __init__(self, config):
             dtype=paddle.get_default_dtype(),
             default_initializer=nn.initializer.Constant(1.0),
         )
-
     def forward(self, x):
-        result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
-        if isinstance(result, tuple):
-            return result[0]
-        return result
+        return paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]
 
 
 @register_base_model
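
Since the commit title calls out both execution modes, here is a hedged end-to-end check that the unconditional [0] behaves the same in dynamic mode and under paddle.jit.to_static. TinyRMSNorm is a hypothetical stand-in for the layers patched above, and a CUDA build of Paddle is assumed:

import paddle
import paddle.incubate.nn.functional as F

class TinyRMSNorm(paddle.nn.Layer):
    """Hypothetical stand-in for the RMSNorm layers patched in this commit."""

    def __init__(self, hidden=64, eps=1e-6):
        super().__init__()
        self.weight = self.create_parameter(
            shape=[hidden],
            dtype=paddle.get_default_dtype(),
            default_initializer=paddle.nn.initializer.Constant(1.0),
        )
        self.eps = eps

    def forward(self, x):
        # The pattern this commit settles on: index the tuple directly.
        return F.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]

x = paddle.randn([8, 64])
dynamic_out = TinyRMSNorm()(x)                       # dynamic (eager) mode
static_layer = paddle.jit.to_static(TinyRMSNorm())   # static graph mode
static_out = static_layer(x)
assert dynamic_out.shape == static_out.shape == [8, 64]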
