Skip to content

Commit d7023f9

Browse files
committed
improve code & add comments
1 parent 08a2d39 commit d7023f9

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

paddlenlp/experimental/transformers/qwen/modeling.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ def remove_padding(self, input_ids, seq_lens_this_time):
244244
)
245245
return ids_remove_padding, padding_offset, cum_offsets
246246

247-
# This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py
247+
# This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py:
248+
# it is used to generate fake input_ids that match the length of inputs_embeds.
248249
@staticmethod
249250
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
250251
batch_size = 1
@@ -254,7 +255,7 @@ def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
254255
if encoder_output is not None:
255256
batch_size = encoder_output.shape[0]
256257
seq_len = encoder_output.shape[1]
257-
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
258+
return paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
258259

259260
def forward(
260261
self,
@@ -558,15 +559,15 @@ def generate_text_with_image_features(
558559
) -> paddle.Tensor:
559560
inputs_embeds = self.qwen.wte(input_ids)
560561
inputs_embeds_dtype = inputs_embeds.dtype
561-
if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
562+
if inputs_embeds_dtype != paddle.float32:
562563
inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
563564
image_features = paddle.cast(image_features, paddle.float32)
564565

565-
for idx, (i, a, b) in enumerate(img_pos):
566-
index = paddle.arange(a + 1, b).unsqueeze(-1)
566+
for idx, (i, image_start_idx, image_end_idx) in enumerate(img_pos):
567+
index = paddle.arange(image_start_idx + 1, image_end_idx).unsqueeze(-1)
567568
inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])
568569

569-
if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
570+
if inputs_embeds_dtype != paddle.float32:
570571
inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)
571572

572573
outputs = self.generate(

0 commit comments

Comments
 (0)