Skip to content

Commit db38937

Browse files
authored
[Inference] Change norm outputs in dynamic and static modes (#9569)
1 parent c96260b commit db38937

File tree

6 files changed

+10
-27
lines changed

6 files changed

+10
-27
lines changed

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def get_weight_create_dype(self):
741741

742742
def compute_layernorm_before_qkv(self, src, i):
743743
if i == 0:
744-
ln_out = self.norm_func(src, self.ln_scales[i], self.ln_biases[i], self._epsilon, begin_norm_axis=1)
744+
ln_out = self.norm_func(src, self.ln_scales[i], self.ln_biases[i], self._epsilon, begin_norm_axis=1)[0]
745745
else:
746746
ln_out = src
747747

@@ -1918,7 +1918,7 @@ def compute_layernorm_before_qkv(self, src, i):
19181918
quant_round_type=self.quant_round_type,
19191919
quant_max_bound=self.quant_max_bound,
19201920
quant_min_bound=self.quant_min_bound,
1921-
)
1921+
)[0]
19221922
else:
19231923
ln_out = src
19241924

@@ -2617,9 +2617,7 @@ def compute_layernorm_before_qkv(self, src, i):
26172617
quant_round_type=1,
26182618
quant_max_bound=self.config.quant_max_bound,
26192619
quant_min_bound=self.config.quant_min_bound,
2620-
)
2621-
if in_dynamic_mode():
2622-
ln_out = ln_out[0]
2620+
)[0]
26232621
else:
26242622
ln_out = src
26252623

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,9 @@ def __init__(self, config):
8989
self.config = config
9090

9191
def forward(self, hidden_states):
92-
result = paddle.incubate.nn.functional.fused_rms_norm(
92+
return paddle.incubate.nn.functional.fused_rms_norm(
9393
hidden_states, self.weight, None, self.variance_epsilon, begin_norm_axis=1
94-
)
95-
if isinstance(result, tuple):
96-
return result[0]
97-
return result
94+
)[0]
9895

9996

10097
class LLamaAvxLMHead(nn.Layer):

paddlenlp/experimental/transformers/mixtral/modeling.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,9 @@ def __init__(self, config):
7777
self.config = config
7878

7979
def forward(self, hidden_states):
80-
result = paddle.incubate.nn.functional.fused_rms_norm(
80+
return paddle.incubate.nn.functional.fused_rms_norm(
8181
hidden_states, self.weight, None, self.variance_epsilon, begin_norm_axis=1
82-
)
83-
if isinstance(result, tuple):
84-
return result[0]
85-
return result
82+
)[0]
8683

8784

8885
@register_base_model

paddlenlp/experimental/transformers/qwen/modeling.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ def __init__(self, config):
5353
)
5454

5555
def forward(self, x):
56-
result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
57-
if isinstance(result, tuple):
58-
return result[0]
59-
return result
56+
return paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]
6057

6158

6259
@register_base_model

paddlenlp/experimental/transformers/qwen2/modeling.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,7 @@ def __init__(self, config):
7878
)
7979

8080
def forward(self, x):
81-
result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
82-
if isinstance(result, tuple):
83-
return result[0]
84-
return result
81+
return paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]
8582

8683

8784
@register_base_model

paddlenlp/experimental/transformers/qwen2_moe/modeling.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@ def __init__(self, config):
6565
)
6666

6767
def forward(self, x):
68-
result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
69-
if isinstance(result, tuple):
70-
return result[0]
71-
return result
68+
return paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)[0]
7269

7370

7471
@register_base_model

0 commit comments

Comments (0)