Skip to content

Commit e49842c

Browse files
sijunhe and wtmlon authored
chatglm2 beam search fix (#7012)
* chatglm2 beam search fix
* changes

Co-authored-by: 刘汀 <wtmlon@foxmail.com>
1 parent 45d4ee8 commit e49842c

File tree

3 files changed

+17
-23
lines changed

3 files changed

+17
-23
lines changed

paddlenlp/generation/utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,9 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
14951495

14961496
return input_ids[:, origin_len:], scores
14971497

1498+
def reorder_cache(self, cache, beam_idx):
1499+
cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
1500+
14981501
def beam_search(
14991502
self,
15001503
input_ids,
@@ -1623,9 +1626,7 @@ def beam_search(
16231626
cache_name = "cache" if "cache" in model_kwargs else "past_key_values"
16241627
if model_kwargs[cache_name] is not None:
16251628
# reorder the cache
1626-
model_kwargs[cache_name] = map_structure(
1627-
lambda x: paddle.index_select(x, beam_idx), model_kwargs[cache_name]
1628-
)
1629+
self.reorder_cache(model_kwargs[cache_name], beam_idx)
16291630

16301631
pred_ids, scores = beam_scorer.finalize(
16311632
input_ids,
@@ -1773,9 +1774,7 @@ def group_beam_search(
17731774
cache_name = "cache" if "cache" in model_kwargs else "past_key_values"
17741775
if model_kwargs[cache_name] is not None:
17751776
# reorder the cache
1776-
model_kwargs[cache_name] = map_structure(
1777-
lambda x: paddle.index_select(x, reordering_indices), model_kwargs[cache_name]
1778-
)
1777+
self.reorder_cache(model_kwargs[cache_name], beam_idx)
17791778

17801779
pred_ids, scores = beam_scorer.finalize(
17811780
input_ids,

paddlenlp/transformers/chatglm_v2/modeling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import paddle.nn as nn
2222
import paddle.nn.functional as F
2323
from paddle.distributed.fleet.utils import recompute
24+
from paddle.utils import map_structure
2425

2526
from .. import PretrainedModel, register_base_model
2627
from ..model_outputs import (
@@ -765,6 +766,9 @@ def __init__(self, config: ChatGLMv2Config):
765766
self.max_sequence_length = config.max_sequence_length
766767
self.chatglm_v2 = ChatGLMv2Model(config)
767768

769+
def reorder_cache(self, cache: paddle.Tensor, beam_idx):
770+
cache = map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache)
771+
768772
def update_model_kwargs_for_generation(
769773
self,
770774
outputs: ModelOutput,

tests/transformers/chatglm_v2/test_modeling.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import unittest
1616

1717
import paddle
18+
from parameterized import parameterized_class
1819

1920
from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2ForCausalLM, ChatGLMv2Model
2021
from tests.transformers.test_generation_utils import GenerationTesterMixin
@@ -24,8 +25,6 @@
2425
random_attention_mask,
2526
)
2627

27-
# from parameterized import parameterized_class
28-
2928

3029
class ChatGLMv2Tester:
3130
def __init__(
@@ -172,13 +171,13 @@ def create_and_check_model_attention_mask(self, config: ChatGLMv2Config, input_i
172171
self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all())
173172

174173

175-
# @parameterized_class(
176-
# ("return_dict", "use_labels"),
177-
# [
178-
# [False, True],
179-
# [True, False],
180-
# ],
181-
# )
174+
@parameterized_class(
175+
("return_dict", "use_labels"),
176+
[
177+
[False, True],
178+
[True, False],
179+
],
180+
)
182181
class ChatGLMv2Test(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
183182
base_model_class = ChatGLMv2Model
184183
return_dict: bool = True
@@ -220,14 +219,6 @@ def test_model_attention_mask(self):
220219
config_and_inputs = self.model_tester.prepare_config_and_inputs()
221220
self.model_tester.create_and_check_model_attention_mask(*config_and_inputs)
222221

223-
# chatglm_v2 cannot use beam search temporarily
224-
def test_beam_search_generate(self):
225-
pass
226-
227-
# chatglm_v2 cannot use group beam search temporarily
228-
def test_group_beam_search_generate(self):
229-
pass
230-
231222

232223
# class ChatGLMV2GenerationD2STest(GenerationD2STestMixin, unittest.TestCase):
233224
# internal_testing_model = "__internal_testing__/tiny-random-chatglm2"

0 commit comments

Comments (0)