@@ -39,7 +39,7 @@
 )
 from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion

-__all__ = ["QWenForCausalLMInferenceModel"]
+__all__ = ["QWenForCausalLMInferenceModel", "QWenForQWenVLInferenceModel"]


 class FusedQWenRMSNorm(nn.Layer):
@@ -244,6 +244,18 @@ def remove_padding(self, input_ids, seq_lens_this_time):
         )
         return ids_remove_padding, padding_offset, cum_offsets

+    # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py
+    @staticmethod
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
+        seq_len = 1
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+            seq_len = encoder_output.shape[1]
+        return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
+
     def forward(
         self,
         input_ids=None,
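Because the new helper is a plain `@staticmethod` with no model state, its effect is easy to check in isolation. A minimal sketch under illustrative values (the `[2, 5, 4096]` encoder-output shape and `bos_token_id` of 1 are assumptions for the example, not values from the patch):

import paddle

# Mirror of the helper's logic: derive (batch_size, seq_len) from the encoder
# output and fill that shape with the BOS token id. All values illustrative.
bos_token_id = 1
encoder_output = paddle.zeros([2, 5, 4096], dtype="float32")
batch_size, seq_len = encoder_output.shape[0], encoder_output.shape[1]
fake_input_ids = paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
print(fake_input_ids.shape)  # [2, 5] -- one BOS id per embedded position

The difference the comment alludes to appears to be that `seq_len` here follows the encoder output's second dimension rather than being fixed at 1, so the fake ids line up position-for-position with precomputed multimodal embeddings.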
@@ -270,17 +282,21 @@ def forward(
         elif input_ids is None and inputs_embeds is None:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        # Generate fake input_ids from inputs_embeds; this usually happens in img2txt
+        # multimodal models when this forward function is first entered.
+        if input_ids is None and inputs_embeds is not None:
+            input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
+        if inputs_embeds is not None:
+            batch, seq_len, hidden_dim = inputs_embeds.shape
+            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if inputs_embeds is not None:
-            batch, seq_len, hidden_dim = inputs_embeds.shape
-            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
-
         if past_key_values is None:
             past_key_values = tuple([None] * self.config.num_hidden_layers)
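Note that this hunk also moves the flattening of inputs_embeds so it runs only after the fake input_ids are derived: the helper needs the still three-dimensional [batch, seq_len, hidden] tensor to read both batch size and sequence length. A minimal sketch of the reshape itself, with illustrative shapes:

import paddle

# The [batch, seq_len, hidden] embeddings are flattened to
# [batch * seq_len, hidden] before entering the model body, as in the hunk above.
batch, seq_len, hidden_dim = 2, 5, 4096
inputs_embeds = paddle.randn([batch, seq_len, hidden_dim])
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
print(inputs_embeds.shape)  # [10, 4096]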
@@ -502,3 +518,122 @@ def set_state_dict(self, state_dict):
         lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype)
         self.lm_head.weight.set_value(lm_head_weight)
         self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+
+
+class QWenForQWenVLInferenceModel(QWenForCausalLMInferenceModel):
+    """
+    This class is almost identical to QWenForCausalLMInferenceModel.
+    It is used only for the second (language-model) part of QWenVL.
+    """
+
+    # This function corresponds to QWenVL's second part; it is used only for QWenVL.
+    @paddle.no_grad()
+    def generate_text_with_image_features(
+        self,
+        input_ids: paddle.Tensor,
+        image_features: paddle.Tensor,
+        img_pos: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+        position_ids=None,
+        penalty_score=None,
+        frequency_score=None,
+        presence_score=None,
+        min_length=None,
+        max_length=None,
+        temperature=None,
+        top_p=None,
+        eos_token_id=None,
+        seq_len_encoder=None,
+        seq_len_decoder=None,
+        step_idx=None,
+        stop_flags=None,
+        tgt_ids=None,
+        tgt_pos=None,
+        tgt_generation_mask=None,
+        pre_ids=None,
+        stop_nums=None,
+        cache_kvs=[],
+        inputs_embeds=None,
+        **generate_kwargs
+    ) -> paddle.Tensor:
+        inputs_embeds = self.qwen.wte(input_ids)
+        inputs_embeds_dtype = inputs_embeds.dtype
+        if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
+            inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
+            image_features = paddle.cast(image_features, paddle.float32)
+
+        for idx, (i, a, b) in enumerate(img_pos):
+            index = paddle.arange(a + 1, b).unsqueeze(-1)
+            inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])
+
+        if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
+            inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)
+
+        outputs = self.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            penalty_score=penalty_score,
+            frequency_score=frequency_score,
+            presence_score=presence_score,
+            min_length=min_length,
+            max_length=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            seq_len_encoder=seq_len_encoder,
+            seq_len_decoder=seq_len_decoder,
+            step_idx=step_idx,
+            stop_flags=stop_flags,
+            tgt_ids=tgt_ids,
+            tgt_pos=tgt_pos,
+            tgt_generation_mask=tgt_generation_mask,
+            pre_ids=pre_ids,
+            stop_nums=stop_nums,
+            cache_kvs=cache_kvs,
+        )
+        return outputs
+
+    # Overrides the to_static function in generation_utils.py.
+    def to_static(self, output_path: str, config: dict):
+        dtype = config.get("dtype", paddle.get_default_dtype())
+        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),  # input_ids
+            paddle.static.InputSpec(
+                shape=[None, None, None], dtype="float32", name="image_features"
+            ),  # image_features
+            paddle.static.InputSpec(shape=[None, 3], dtype="int64", name="img_pos"),  # img_pos
+            paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"),  # attention_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"),  # position_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"),  # penalty_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"),  # frequency_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"),  # presence_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"),  # min_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"),  # max_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"),  # temperature
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"),  # top_p
+            paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"),  # eos_token_id
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"),  # seq_len_encoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"),  # seq_len_decoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"),  # step_idx
+            paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"),  # stop_flags
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"),  # tgt_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"),  # tgt_pos
+            paddle.static.InputSpec(
+                shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
+            ),  # tgt_generation_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"),  # pre_ids
+            paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"),  # stop_nums
+            [
+                paddle.static.InputSpec(
+                    shape=shape,
+                    dtype=dtype,
+                    name="cache_kvs_{}".format(i),
+                )
+                for i, shape in enumerate(cache_kvs_shapes)
+            ],  # cache_kvs
+        ]
+
+        model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
+        paddle.jit.save(model, output_path, skip_prune_program=True)
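The core of generate_text_with_image_features is the paddle.scatter splice: each row of img_pos is read as (batch_index, start, end), and the embedding rows strictly between start and end are overwritten with that image's features. A minimal sketch under assumed shapes (the 16-token sequence, 8-dim hidden size, and span values are illustrative; the index is unsqueezed to [N, 1] exactly as in the loop above):

import paddle

# One sequence's embeddings plus 5 visual tokens; a placeholder span (a, b)
# means rows a+1 .. b-1 receive the image features.
hidden = 8
seq_embeds = paddle.zeros([16, hidden])
image_features = paddle.ones([5, hidden])
a, b = 3, 9
index = paddle.arange(a + 1, b).unsqueeze(-1)  # [[4], [5], [6], [7], [8]]
seq_embeds = paddle.scatter(seq_embeds, index, image_features)
print(seq_embeds.sum(axis=-1))  # rows 4..8 sum to 8.0, the rest stay 0.0

The surrounding float32 casts suggest the scatter is deliberately not exercised in bfloat16/float16 here; the result is cast back to the original dtype before generate runs.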
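Finally, a sketch of how the new export entry point might be invoked (the checkpoint name, output path, and config values are hypothetical; to_static reads only "dtype" and "max_length" from its config dict):

# Hypothetical usage; assumes this module's QWenForQWenVLInferenceModel is importable.
model = QWenForQWenVLInferenceModel.from_pretrained("qwen-vl")
model.to_static("./qwen_vl_static/model", {"dtype": "float16", "max_length": 1024})

Because input_spec is built around generate_text_with_image_features rather than generate, the saved static graph exposes image_features and img_pos as first-class inputs, which is what QWenVL's two-part deployment requires.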