@@ -244,7 +244,8 @@ def remove_padding(self, input_ids, seq_lens_this_time):
         )
         return ids_remove_padding, padding_offset, cum_offsets

-    # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py
+    # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py;
+    # it is used to generate fake input_ids according to the length of inputs_embeds.
     @staticmethod
     def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
         batch_size = 1
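A minimal sketch of the behavior the new comment describes, with made-up shapes and a made-up `bos_token_id` (neither is taken from this PR): given an `encoder_output`, the helper fabricates `input_ids` whose batch size and sequence length match the embeddings.

```python
import paddle

# Illustrative only: the shape and bos_token_id below are assumptions.
encoder_output = paddle.zeros([2, 5, 768])  # [batch_size, seq_len, hidden]
bos_token_id = 1

batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
# Fake input_ids whose length matches the embeddings' sequence length.
fake_input_ids = paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
print(fake_input_ids.shape)  # [2, 5]
```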
@@ -254,7 +255,7 @@ def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
         if encoder_output is not None:
             batch_size = encoder_output.shape[0]
             seq_len = encoder_output.shape[1]
-        return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
+        return paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")

     def forward(
         self,
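The old and new expressions produce the same tensor; `paddle.full` simply writes `bos_token_id` directly instead of materializing a ones tensor and multiplying. A quick check, using illustrative values:

```python
import paddle

batch_size, seq_len, bos_token_id = 2, 5, 101  # illustrative values
old = paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
new = paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
assert bool((old == new).all())  # same result, one fewer elementwise op
```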
@@ -558,15 +559,15 @@ def generate_text_with_image_features(
     ) -> paddle.Tensor:
         inputs_embeds = self.qwen.wte(input_ids)
         inputs_embeds_dtype = inputs_embeds.dtype
-        if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
+        if inputs_embeds_dtype != paddle.float32:
             inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
             image_features = paddle.cast(image_features, paddle.float32)

-        for idx, (i, a, b) in enumerate(img_pos):
-            index = paddle.arange(a + 1, b).unsqueeze(-1)
+        for idx, (i, image_start_idx, image_end_idx) in enumerate(img_pos):
+            index = paddle.arange(image_start_idx + 1, image_end_idx).unsqueeze(-1)
             inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])

-        if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
+        if inputs_embeds_dtype != paddle.float32:
             inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)

         outputs = self.generate(
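The dtype guard now casts any non-float32 embedding (not just bfloat16/float16) to float32 before the scatter and casts back afterwards. A toy sketch of the scatter step itself, assuming each `img_pos` row is `(batch_index, image_start_idx, image_end_idx)` and that the positions strictly between the two markers receive the image features; the shapes are made up, and a 1-D index is used here for simplicity where the PR builds it with `unsqueeze(-1)`:

```python
import paddle

hidden_size = 4
seq_embeds = paddle.zeros([10, hidden_size])    # embeddings of one sequence
image_features = paddle.ones([3, hidden_size])  # three visual feature vectors

image_start_idx, image_end_idx = 2, 6           # assumed image marker positions
# Rows strictly between the markers: 3, 4, 5.
index = paddle.arange(image_start_idx + 1, image_end_idx)
seq_embeds = paddle.scatter(seq_embeds, index, image_features)
print(float(seq_embeds[3:6].sum()))             # 12.0 -> features written in place
```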