
[LLM] support QWenVL second part #7808


Merged 5 commits on Jan 23, 2024
15 changes: 10 additions & 5 deletions llm/predictor.py
@@ -1386,11 +1386,16 @@ def create_predictor(
)
model.eval()
elif "qwen" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
QWenForCausalLMInferenceModel,
)

model = QWenForCausalLMInferenceModel.from_pretrained(
if model_args.model_type == "qwen-img2txt":
# we use qwen for img2txt.
from paddlenlp.experimental.transformers import (
QWenForQWenVLInferenceModel as QWenInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
QWenForCausalLMInferenceModel as QWenInferenceModel,
)
model = QWenInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
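
For illustration, a hedged sketch of how a caller ends up in the new "qwen-img2txt" branch by loading the QWenVL inference class directly (the checkpoint path is illustrative; the config wiring mirrors the test added below):

from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
from paddlenlp.transformers import AutoConfig

# Load the QWenVL inference model as the new branch above does.
config = AutoConfig.from_pretrained("./qwen-vl-checkpoint")  # illustrative path
model = QWenForQWenVLInferenceModel.from_pretrained(
    "./qwen-vl-checkpoint", config=config, dtype="float16"
)
model.eval()
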
146 changes: 141 additions & 5 deletions paddlenlp/experimental/transformers/qwen/modeling.py
@@ -39,7 +39,7 @@
)
from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion

__all__ = ["QWenForCausalLMInferenceModel"]
__all__ = ["QWenForCausalLMInferenceModel", "QWenForQWenVLInferenceModel"]


class FusedQWenRMSNorm(nn.Layer):
@@ -244,6 +244,19 @@ def remove_padding(self, input_ids, seq_lens_this_time):
)
return ids_remove_padding, padding_offset, cum_offsets

# This function differs slightly from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py:
# it generates fake input_ids whose length matches inputs_embeds.
@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.")
Contributor

input_ids is not passed in here, so there is no need to describe a relationship to input_ids; this error message will leave developers confused.

Contributor Author @DanGuge Jan 16, 2024

This is in the img2txt model: when the forward of QWenInferenceModel is entered for the first time, generate does not pass input_ids, so a fake input_ids has to be generated from inputs_embeds; the input_ids are filled with bos_token_id.

if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
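
A quick illustration of what this helper produces (shapes are illustrative, and the bos id 1 matches the new default added in configuration.py below):

import paddle

# Hedged illustration of the helper above: given inputs_embeds as
# encoder_output, it returns a [batch, seq_len] int64 tensor filled
# with bos_token_id.
inputs_embeds = paddle.randn([2, 7, 4096])  # illustrative shapes
batch_size, seq_len = inputs_embeds.shape[0], inputs_embeds.shape[1]
fake_input_ids = paddle.full([batch_size, seq_len], 1, dtype="int64")
print(fake_input_ids.shape)  # [2, 7]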

def forward(
self,
input_ids=None,
@@ -270,17 +283,21 @@ def forward(
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

# generate fake input_ids from inputs_embeds
# this usually occurs in img2txt multimodal models on the first entry into this forward function
if input_ids is None and inputs_embeds is not None:
input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])

if past_key_values is None:
past_key_values = tuple([None] * self.config.num_hidden_layers)

@@ -502,3 +519,122 @@ def set_state_dict(self, state_dict):
lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype)
self.lm_head.weight.set_value(lm_head_weight)
self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


class QWenForQWenVLInferenceModel(QWenForCausalLMInferenceModel):
"""
This class is almost identical to QWenForCausalLMInferenceModel.
It is used only for QWenVL's second part.
"""

# This function implements QWenVL's second part and is used only for QWenVL.
@paddle.no_grad()
def generate_text_with_image_features(
self,
input_ids: paddle.Tensor,
image_features: paddle.Tensor,
img_pos: paddle.Tensor,
attention_mask: paddle.Tensor,
position_ids=None,
penalty_score=None,
frequency_score=None,
presence_score=None,
min_length=None,
max_length=None,
temperature=None,
top_p=None,
eos_token_id=None,
seq_len_encoder=None,
seq_len_decoder=None,
step_idx=None,
stop_flags=None,
tgt_ids=None,
tgt_pos=None,
tgt_generation_mask=None,
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**generate_kwargs
) -> paddle.Tensor:
inputs_embeds = self.qwen.wte(input_ids)
inputs_embeds_dtype = inputs_embeds.dtype
if inputs_embeds_dtype != paddle.float32:
inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
image_features = paddle.cast(image_features, paddle.float32)

for idx, (i, image_start_idx, image_end_idx) in enumerate(img_pos):
index = paddle.arange(image_start_idx + 1, image_end_idx).unsqueeze(-1)
inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])

if inputs_embeds_dtype != paddle.float32:
inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)

outputs = self.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
min_length=min_length,
max_length=max_length,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
seq_len_encoder=seq_len_encoder,
seq_len_decoder=seq_len_decoder,
step_idx=step_idx,
stop_flags=stop_flags,
tgt_ids=tgt_ids,
tgt_pos=tgt_pos,
tgt_generation_mask=tgt_generation_mask,
pre_ids=pre_ids,
stop_nums=stop_nums,
cache_kvs=cache_kvs,
)
return outputs
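
A minimal sketch of the image-feature splice performed by the loop above (all shapes illustrative): each img_pos row (batch_index, start, end) marks an image placeholder span, and positions start + 1 through end - 1 are overwritten with the visual features.

import paddle

hidden = 8
inputs_embeds = paddle.zeros([1, 10, hidden])           # text token embeddings
image_features = paddle.ones([1, 4, hidden])            # 4 visual tokens
img_pos = paddle.to_tensor([[0, 2, 7]], dtype="int64")  # (batch_idx, start, end)

for idx, (i, start, end) in enumerate(img_pos):
    index = paddle.arange(start + 1, end).unsqueeze(-1)  # positions 3, 4, 5, 6
    inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])
# rows 3..6 of inputs_embeds[0] now hold the image features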

# overrides the to_static function in generation_utils.py
def to_static(self, output_path: str, config: dict):
dtype = config.get("dtype", paddle.get_default_dtype())
cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), # input_ids
paddle.static.InputSpec(
shape=[None, None, None], dtype="float32", name="image_features"
), # image_features
paddle.static.InputSpec(shape=[None, 3], dtype="int64", name="img_pos"), # img_pos
paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"), # attention_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
paddle.static.InputSpec(
shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
), # tgt_generation_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
[
paddle.static.InputSpec(
shape=shape,
dtype=dtype,
name="cache_kvs_{}".format(i),
)
for i, shape in enumerate(cache_kvs_shapes)
], # cache_kvs
]

model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
paddle.jit.save(model, output_path, skip_prune_program=True)
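
A hedged usage sketch of this export path, given a loaded QWenForQWenVLInferenceModel instance named model (the output prefix and max_length value are assumptions):

export_config = {"dtype": "float16", "max_length": 1024}
model.to_static("./inference/qwen_vl", export_config)
# paddle.jit.save writes ./inference/qwen_vl.pdmodel and .pdiparams
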
11 changes: 10 additions & 1 deletion paddlenlp/transformers/qwen/configuration.py
@@ -47,6 +47,9 @@ def __init__(
tensor_parallel_output=True,
no_bias=True,
tie_word_embeddings=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
@@ -72,4 +75,10 @@
self.use_fused_rope = use_fused_rope
self.no_bias = no_bias

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
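
With these defaults in place, a freshly constructed config exposes the special-token ids without extra kwargs; a quick check (the import path is an assumption based on this file's location):

from paddlenlp.transformers.qwen.configuration import QWenConfig

config = QWenConfig()
print(config.pad_token_id, config.bos_token_id, config.eos_token_id)  # 0 1 2
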
107 changes: 107 additions & 0 deletions tests/llm/test_predictor.py
@@ -20,7 +20,9 @@
import pytest
from parameterized import parameterized_class

from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
from paddlenlp.transformers import ( # ChatGLMForCausalLM,
AutoConfig,
AutoTokenizer,
BloomForCausalLM,
ChatGLMForCausalLM,
@@ -325,3 +327,108 @@ def test_predictor(self):

self.assertGreaterEqual(full_match / len(result_0), 0.25)
self.assertGreaterEqual(count / len(result_0), 0.4)


class QWenVLTest(LLMTest, unittest.TestCase):
config_path: str = "./tests/fixtures/llm/predictor.yaml"
model_name_or_path: str = "__internal_testing__/tiny-fused-qwen"
model_class = QWenForCausalLM

def setUp(self) -> None:
super().setUp()
paddle.set_default_dtype("float32")
self.model_class.from_pretrained(self.model_name_or_path, dtype="float16").save_pretrained(self.output_dir)
AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir)

def test_forward(self):
self.disable_static()
config = AutoConfig.from_pretrained(self.output_dir)
config.quant_type = None
config.weight_only_quant_bits = None

paddle.set_default_dtype("float16")
model = QWenForQWenVLInferenceModel.from_pretrained(self.output_dir, config=config, dtype="float16")

batch = 1
seq = 31
max_len = 50
dtype = "float16"
input_ids = paddle.randint(0, 100, [batch, seq], dtype="int64")
image_features = paddle.randn([batch, 16, config.hidden_size], dtype="float16")
tgt_generation_mask = paddle.full([batch, 1, 1, max_len], 1, dtype=dtype)
img_pos = paddle.to_tensor([[0, 4, 21]], dtype="int64")
attention_mask = paddle.full([batch, 1, max_len, max_len], 0, dtype=dtype)
attention_mask[:, 0, :seq, :seq] = paddle.tril(paddle.ones(shape=(seq, seq), dtype=dtype))
position_ids = paddle.full([batch, seq], 0, dtype="int64")
for i in range(batch):
position_ids[i, :] = paddle.to_tensor([i for i in range(seq)], dtype="int64")

inputs = [
input_ids, # input_ids
image_features, # image_features
img_pos, # img_pos
attention_mask, # attention_mask
position_ids, # position_ids
paddle.full([batch, 1], 1.0, dtype="float32"), # penalty_score
paddle.full([batch, 1], 0.0, dtype="float32"), # frequency_score,
paddle.full([batch, 1], 0.0, dtype="float32"), # presence_score,
paddle.full([batch, 1], 1, dtype="int64"), # min_length,
paddle.full([batch, 1], max_len - seq, dtype="int64"), # max_length,
paddle.full([batch, 1], 1.0, dtype="float32"), # temperature,
paddle.full([batch, 1], 0.0, dtype="float32"), # top_p,
paddle.full([1], 151643, dtype="int64"), # eos_token_id,
paddle.full([batch, 1], seq, dtype="int32"), # seq_len_encoder,
paddle.full([batch, 1], seq, dtype="int32"), # seq_len_decoder,
paddle.full([batch, 1], 0, dtype="int64"), # step_idx,
paddle.full([batch, 1], False, dtype="bool"), # stop_flags,
paddle.full([batch, 1], -123, dtype="int64"), # tgt_ids, can be initialized arbitrarily
paddle.full([batch, 1], seq - 1, dtype="int64"), # tgt_pos,
tgt_generation_mask, # tgt_generation_mask,
paddle.full([batch, max_len], -100, dtype="int64"), # pre_ids, can be initialized arbitrarily
paddle.full([1], batch, dtype="int64"), # stop_nums, should equal batch
]
for i in range(config.num_hidden_layers):
tmp = paddle.rand(shape=[2, batch, 1, max_len, 64], dtype=dtype)
inputs.append(tmp)

model.eval()
model.generate_text_with_image_features(
input_ids=inputs[0],
image_features=inputs[1],
img_pos=inputs[2],
attention_mask=inputs[3],
position_ids=inputs[4],
penalty_score=inputs[5],
frequency_score=inputs[6],
presence_score=inputs[7],
min_length=inputs[8],
max_length=inputs[9],
temperature=inputs[10],
top_p=inputs[11],
eos_token_id=inputs[12],
seq_len_encoder=inputs[13],
seq_len_decoder=inputs[14],
step_idx=inputs[15],
stop_flags=inputs[16],
tgt_ids=inputs[17],
tgt_pos=inputs[18],
tgt_generation_mask=inputs[19],
pre_ids=inputs[20],
stop_nums=inputs[21],
cache_kvs=inputs[22:],
)

def test_export(self):
self.disable_static()
config = load_test_config(self.config_path, "inference-to-static")
config["model_name_or_path"] = self.model_name_or_path
config["output_path"] = self.output_dir
config["dtype"] = "float16"
config["inference_model"] = True
config["model_prefix"] = "qwen"
config["model_type"] = "qwen-img2txt"

with argv_context_guard(config):
from export_model import main

main()