Skip to content

Commit 5c57015

Browse files
Modify PaddleMIX Qwen dynamic-to-static (dy2static) support (#8869)

* modify API for PIR
* modify API for PIR
* pass None for while-loop condition
* modify CI test
1 parent 678843e commit 5c57015

File tree

4 files changed

+16
-6
lines changed

4 files changed

+16
-6
lines changed

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import paddle
1717
import paddle.distributed as dist
18-
from paddle.framework import LayerHelper, core, in_dynamic_mode
18+
from paddle.framework import LayerHelper, core, in_dynamic_mode, in_dynamic_or_pir_mode
1919
from paddle.incubate.nn.functional import (
2020
fused_layer_norm,
2121
fused_rms_norm,
@@ -88,7 +88,8 @@ def fused_act_bias_wrapper(
8888
quant_max_bound=0,
8989
quant_min_bound=0,
9090
):
91-
if in_dynamic_mode():
91+
if in_dynamic_or_pir_mode():
92+
9293
return paddle._C_ops.fused_bias_act(
9394
x,
9495
bias,

paddlenlp/experimental/transformers/generation_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def generate(
183183
inputs_embeds=inputs_embeds,
184184
**model_kwargs,
185185
)
186+
186187
return ret
187188

188189
def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, eos_token_id, model_kwargs):

paddlenlp/generation/logits_process.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,10 @@ def __call__(self, input_ids, scores):
291291

292292

293293
def TopKProcess(probs: paddle.Tensor, top_k: int, min_tokens_to_keep: int):
294-
top_k = min(max(top_k, min_tokens_to_keep), probs.shape[-1])
294+
top_k = paddle.minimum(
295+
paddle.maximum(paddle.to_tensor(top_k), paddle.to_tensor(min_tokens_to_keep)),
296+
paddle.to_tensor(probs.shape[-1]),
297+
)
295298
# Remove all tokens with a probability less than the last token of the top-k
296299
# cast to float16 to support generation & d2s
297300
if probs.dtype == paddle.bfloat16:

tests/transformers/test_modeling_common.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -967,8 +967,10 @@ def test_to_static_use_top_k(self):
967967
use_top_p=False,
968968
),
969969
)
970-
971-
model_path = os.path.join(tempdir, "model.pdmodel")
970+
if paddle.framework.use_pir_api():
971+
model_path = os.path.join(tempdir, "model.json")
972+
else:
973+
model_path = os.path.join(tempdir, "model.pdmodel")
972974
params_path = os.path.join(tempdir, "model.pdiparams")
973975
config = paddle.inference.Config(model_path, params_path)
974976

@@ -1036,7 +1038,10 @@ def test_to_static_use_top_p(self):
10361038
),
10371039
)
10381040

1039-
model_path = os.path.join(tempdir, "model.pdmodel")
1041+
if paddle.framework.use_pir_api():
1042+
model_path = os.path.join(tempdir, "model.json")
1043+
else:
1044+
model_path = os.path.join(tempdir, "model.pdmodel")
10401045
params_path = os.path.join(tempdir, "model.pdiparams")
10411046
config = paddle.inference.Config(model_path, params_path)
10421047

0 commit comments

Comments (0)