Skip to content

Commit 45d4ee8

Browse files
authored
Fix bert error in SOT (#6942)
* Fix bert error in SOT * Format code * Format code
1 parent d83032d commit 45d4ee8

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

paddlenlp/transformers/model_outputs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,8 @@ def _transformer_encoder_fwd(
356356
)
357357

358358

359+
_transformer_encoder_fwd.__name__ = "forward"
360+
_transformer_encoder_layer_fwd.__name__ = "forward"
359361
# patches of paddle.nn.Transformer to get all hidden_states and attentions
360362
paddle.nn.TransformerEncoderLayer.forward = _transformer_encoder_layer_fwd
361363
paddle.nn.TransformerDecoderLayer.forward = _transformer_decoder_layer_fwd

paddlenlp/transformers/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def adapt_stale_fwd_patch(self, name, value):
141141
# NOTE(guosheng): In dygraph to static, `layer.forward` would be patched
142142
# by an instance of `StaticFunction`. And use string compare to avoid to
143143
# import fluid.
144-
if type(value).__name__.endswith("StaticFunction"):
144+
if type(value).__name__.endswith("StaticFunction") or self.forward.__class__.__name__.endswith(
145+
"StaticFunction"
146+
):
145147
return value
146148
if hasattr(inspect, "getfullargspec"):
147149
(

0 commit comments

Comments
 (0)