Commit 08a2d39

add qwenvl second part
Parent: 04142e3

3 files changed: 160 additions, 11 deletions

llm/predictor.py

Lines changed: 10 additions & 5 deletions
@@ -1348,11 +1348,16 @@ def create_predictor(
             )
             model.eval()
         elif "qwen" in config.architectures[0].lower():
-            from paddlenlp.experimental.transformers import (
-                QWenForCausalLMInferenceModel,
-            )
-
-            model = QWenForCausalLMInferenceModel.from_pretrained(
+            if model_args.model_type == "qwen-img2txt":
+                # we use qwen for img2txt.
+                from paddlenlp.experimental.transformers import (
+                    QWenForQWenVLInferenceModel as QWenInferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    QWenForCausalLMInferenceModel as QWenInferenceModel,
+                )
+            model = QWenInferenceModel.from_pretrained(
                 predictor_args.model_name_or_path,
                 config=config,
                 dtype=predictor_args.dtype,
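Both branches bind whichever class is imported to the same local name, so the from_pretrained call above remains a single code path. A minimal standalone sketch of the same dispatch (the class names and the "qwen-img2txt" string come from the diff; the helper function itself is illustrative):

    def pick_qwen_inference_class(model_type: str):
        """Return the QWen inference class matching the requested model type."""
        if model_type == "qwen-img2txt":
            # QWen-VL: image features get spliced into the token embeddings,
            # so the VL-specific inference model is required.
            from paddlenlp.experimental.transformers import (
                QWenForQWenVLInferenceModel as QWenInferenceModel,
            )
        else:
            # Plain text-only QWen.
            from paddlenlp.experimental.transformers import (
                QWenForCausalLMInferenceModel as QWenInferenceModel,
            )
        return QWenInferenceModel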

paddlenlp/experimental/transformers/qwen/modeling.py

Lines changed: 140 additions & 5 deletions
@@ -39,7 +39,7 @@
     )
 from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion
 
-__all__ = ["QWenForCausalLMInferenceModel"]
+__all__ = ["QWenForCausalLMInferenceModel", "QWenForQWenVLInferenceModel"]
 
 
 class FusedQWenRMSNorm(nn.Layer):
@@ -244,6 +244,18 @@ def remove_padding(self, input_ids, seq_lens_this_time):
         )
         return ids_remove_padding, padding_offset, cum_offsets
 
+    # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py
+    @staticmethod
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
+        seq_len = 1
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+            seq_len = encoder_output.shape[1]
+        return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
+
     def forward(
         self,
         input_ids=None,
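The new static helper synthesizes a placeholder input_ids tensor filled with the bos token, sized to match the encoder output (in the VL case, the image-conditioned embeddings). A toy run of the same logic, assuming Paddle is installed:

    import paddle

    # Re-statement of the static method above, for illustration only.
    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
        batch_size, seq_len = 1, 1
        if bos_token_id is None:
            raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
        if encoder_output is not None:
            batch_size, seq_len = encoder_output.shape[0], encoder_output.shape[1]
        return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

    # A [2, 5, 8] embeddings tensor with bos_token_id=1 yields a [2, 5]
    # tensor of ones: one bos id per embedded position.
    fake_ids = prepare_input_ids_for_generation(1, paddle.zeros([2, 5, 8]))
    print(fake_ids.shape)  # [2, 5]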
@@ -270,17 +282,21 @@ def forward(
         elif input_ids is None and inputs_embeds is None:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        # generate a fake input_ids according to inputs_embeds
+        # this usually occurs in an img2txt multimodal model on the first entry into this forward function.
+        if input_ids is None and inputs_embeds is not None:
+            input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
+        if inputs_embeds is not None:
+            batch, seq_len, hidden_dim = inputs_embeds.shape
+            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if inputs_embeds is not None:
-            batch, seq_len, hidden_dim = inputs_embeds.shape
-            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
-
         if past_key_values is None:
             past_key_values = tuple([None] * self.config.num_hidden_layers)
 
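The reshape is moved above the config-default boilerplate so that it runs only after the fake input_ids are generated: prepare_input_ids_for_generation reads the batch and sequence dims from the still-3D tensor, and the embeddings are then flattened from [batch, seq_len, hidden_dim] to the packed [batch * seq_len, hidden_dim] layout that the surrounding remove_padding machinery suggests the fused kernels consume. A toy check of the reshape:

    import paddle

    inputs_embeds = paddle.randn([2, 5, 8])  # [batch, seq_len, hidden_dim]
    batch, seq_len, hidden_dim = inputs_embeds.shape
    flat = inputs_embeds.reshape([batch * seq_len, hidden_dim])
    print(flat.shape)  # [10, 8]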
@@ -502,3 +518,122 @@ def set_state_dict(self, state_dict):
         lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype)
         self.lm_head.weight.set_value(lm_head_weight)
         self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+
+
+class QWenForQWenVLInferenceModel(QWenForCausalLMInferenceModel):
+    """
+    This class is 99% like QWenForCausalLMInferenceModel.
+    Used only for QWenVL's second part.
+    """
+
+    # This function corresponds to QWenVL's second part, only used for QWenVL.
+    @paddle.no_grad()
+    def generate_text_with_image_features(
+        self,
+        input_ids: paddle.Tensor,
+        image_features: paddle.Tensor,
+        img_pos: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+        position_ids=None,
+        penalty_score=None,
+        frequency_score=None,
+        presence_score=None,
+        min_length=None,
+        max_length=None,
+        temperature=None,
+        top_p=None,
+        eos_token_id=None,
+        seq_len_encoder=None,
+        seq_len_decoder=None,
+        step_idx=None,
+        stop_flags=None,
+        tgt_ids=None,
+        tgt_pos=None,
+        tgt_generation_mask=None,
+        pre_ids=None,
+        stop_nums=None,
+        cache_kvs=[],
+        inputs_embeds=None,
+        **generate_kwargs
+    ) -> paddle.Tensor:
+        inputs_embeds = self.qwen.wte(input_ids)
+        inputs_embeds_dtype = inputs_embeds.dtype
+        if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
+            inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
+            image_features = paddle.cast(image_features, paddle.float32)
+
+        for idx, (i, a, b) in enumerate(img_pos):
+            index = paddle.arange(a + 1, b).unsqueeze(-1)
+            inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])
+
+        if inputs_embeds_dtype == paddle.bfloat16 or inputs_embeds_dtype == paddle.float16:
+            inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)
+
+        outputs = self.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            penalty_score=penalty_score,
+            frequency_score=frequency_score,
+            presence_score=presence_score,
+            min_length=min_length,
+            max_length=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            seq_len_encoder=seq_len_encoder,
+            seq_len_decoder=seq_len_decoder,
+            step_idx=step_idx,
+            stop_flags=stop_flags,
+            tgt_ids=tgt_ids,
+            tgt_pos=tgt_pos,
+            tgt_generation_mask=tgt_generation_mask,
+            pre_ids=pre_ids,
+            stop_nums=stop_nums,
+            cache_kvs=cache_kvs,
+        )
+        return outputs
+
+    # rewrite the to_static function in generation_utils.py
+    def to_static(self, output_path: str, config: dict):
+        dtype = config.get("dtype", paddle.get_default_dtype())
+        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),  # input_ids
+            paddle.static.InputSpec(
+                shape=[None, None, None], dtype="float32", name="image_features"
+            ),  # image_features
+            paddle.static.InputSpec(shape=[None, 3], dtype="int64", name="img_pos"),  # img_pos
+            paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"),  # attention_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"),  # position_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"),  # penalty_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"),  # frequency_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"),  # presence_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"),  # min_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"),  # max_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"),  # temperature
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"),  # top_p
+            paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"),  # eos_token_id
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"),  # seq_len_encoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"),  # seq_len_decoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"),  # step_idx
+            paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"),  # stop_flags
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"),  # tgt_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"),  # tgt_pos
+            paddle.static.InputSpec(
+                shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
+            ),  # tgt_generation_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"),  # pre_ids
+            paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"),  # stop_nums
+            [
+                paddle.static.InputSpec(
+                    shape=shape,
+                    dtype=dtype,
+                    name="cache_kvs_{}".format(i),
+                )
+                for i, shape in enumerate(cache_kvs_shapes)
+            ],  # cache_kvs
+        ]
+
+        model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
+        paddle.jit.save(model, output_path, skip_prune_program=True)
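The core of generate_text_with_image_features is the splice loop: each img_pos row (i, a, b) names a batch index and an image span, and rows a+1 through b-1 of that sample's embeddings are overwritten with the image features (positions a and b themselves appear to hold the image start/end marker tokens, so they are left intact). The temporary cast to float32 around the loop presumably works around missing half-precision support in paddle.scatter. A toy run of the same loop:

    import paddle

    seq_len, hidden = 10, 4
    inputs_embeds = paddle.zeros([2, seq_len, hidden])
    image_features = paddle.ones([1, 5, hidden])             # one image, 5 feature rows
    img_pos = paddle.to_tensor([[0, 2, 8]], dtype="int64")   # sample 0, span (2, 8)

    for idx, (i, a, b) in enumerate(img_pos):
        index = paddle.arange(a + 1, b).unsqueeze(-1)        # rows 3..7
        inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])

    print(float(inputs_embeds[0, 3:8].sum()))  # 20.0 -> five rows of ones landed in the span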

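The overridden to_static exports generate_text_with_image_features instead of the stock generate, so the resulting static graph accepts the extra image_features and img_pos inputs (img_pos is declared as [None, 3], matching the (batch_idx, start, end) rows used by the splice loop). A hypothetical export call, assuming model is a loaded QWenForQWenVLInferenceModel and the output path is a placeholder:

    # Hypothetical values; "dtype" and "max_length" are the two config keys
    # the method actually reads.
    model.to_static(
        output_path="./qwen_vl_static/model",
        config={"dtype": "float16", "max_length": 1024},
    )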
paddlenlp/transformers/qwen/configuration.py

Lines changed: 10 additions & 1 deletion
@@ -47,6 +47,9 @@ def __init__(
         tensor_parallel_output=True,
         no_bias=True,
         tie_word_embeddings=False,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -72,4 +75,10 @@ def __init__(
         self.use_fused_rope = use_fused_rope
         self.no_bias = no_bias
 
-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
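Forwarding the special-token ids into the base class means config.bos_token_id is now populated on every QWen config, which the fake-input_ids path in modeling.py relies on. A quick check, assuming the class defined in this file is QWenConfig and the remaining parameters keep their defaults:

    from paddlenlp.transformers.qwen.configuration import QWenConfig

    config = QWenConfig()
    print(config.pad_token_id, config.bos_token_id, config.eos_token_id)  # 0 1 2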
