Skip to content

Commit 2b55d7a

Browse files
[LLM] Fix Qwen-7B-Chat precision issue
Fix the Qwen-7B-Chat model batch-inference precision issue; add Qwen-7B-Chat to the PaddleNLP unit tests.
1 parent 04142e3 commit 2b55d7a

File tree

4 files changed

+5
-3
lines changed

4 files changed

+5
-3
lines changed

llm/predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,7 @@ def create_predictor(
11871187

11881188
tensor_parallel_rank, tensor_parallel_degree = init_dist_env()
11891189
if not predictor_args.inference_model:
1190+
tokenizer.padding_side = "left"
11901191
if predictor_args.mode == "dynamic":
11911192
if model_args.model_type == "gpt-3":
11921193
sys.path.append("./gpt-3")

paddlenlp/generation/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,10 +1059,8 @@ def greedy_search(
10591059

10601060
# pre-process distribution
10611061
next_token_logits = self.adjust_logits_during_generation(next_token_logits)
1062-
next_tokens_scores = logits_processors(input_ids, next_token_logits)
1062+
probs = logits_processors(input_ids, next_token_logits)
10631063
# greedy
1064-
probs = F.softmax(next_tokens_scores)
1065-
probs = paddle.log(probs)
10661064
next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
10671065
next_scores = paddle.index_sample(probs, next_tokens)
10681066

paddlenlp/transformers/qwen/tokenizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
self,
7171
vocab_file,
7272
errors="replace",
73+
padding_side="left",
7374
**kwargs,
7475
):
7576
super().__init__(**kwargs)

tests/llm/test_predictor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ChatGLMForCausalLM,
2727
ChatGLMv2ForCausalLM,
2828
LlamaForCausalLM,
29+
QWenForCausalLM,
2930
)
3031
from paddlenlp.utils.downloader import (
3132
COMMUNITY_MODEL_PREFIX,
@@ -43,6 +44,7 @@
4344
["__internal_testing__/tiny-fused-bloom", BloomForCausalLM],
4445
["__internal_testing__/tiny-fused-chatglm", ChatGLMForCausalLM],
4546
["__internal_testing__/tiny-fused-chatglm2", ChatGLMv2ForCausalLM],
47+
["__internal_testing__/tiny-fused-qwen", QWenForCausalLM],
4648
],
4749
)
4850
class PredictorTest(LLMTest, unittest.TestCase):

0 commit comments

Comments (0)