
Commit ab5388d

add qwenvl test

1 parent d7023f9 · commit ab5388d

File tree

1 file changed: +140 -0 lines

tests/llm/test_qwenvl.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import unittest

import paddle

from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
from paddlenlp.transformers import AutoConfig, AutoTokenizer, QWenForCausalLM

from .testing_utils import LLMTest, argv_context_guard, load_test_config

paddle.seed(1234)


class QWenVLTest(LLMTest, unittest.TestCase):
    config_path: str = "./tests/fixtures/llm/predictor.yaml"
    model_name_or_path: str = "__internal_testing__/tiny-fused-qwen"
    model_class = QWenForCausalLM

    def setUp(self) -> None:
        super().setUp()
        paddle.set_default_dtype("float32")
        self.model_class.from_pretrained(self.model_name_or_path, dtype="float16").save_pretrained(self.output_dir)
        AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir)

    def test_forward(self):
        self.disable_static()
        config = AutoConfig.from_pretrained(self.output_dir)
        config["quant_type"] = None
        config["weight_only_quant_bits"] = None

        paddle.set_default_dtype("float16")
        model = QWenForQWenVLInferenceModel.from_pretrained(self.output_dir, config=config, dtype="float16")

        batch = 1
        seq = 271
        max_len = 1024
        dtype = "float16"
        input_ids = paddle.randint(0, 100, [batch, seq], dtype="int64")
        image_features = paddle.randn([batch, 256, 4096], dtype="float16")
        tgt_generation_mask = paddle.full([batch, 1, 1, max_len], 1, dtype=dtype)
        # positions delimiting the image-feature span inside input_ids
        img_pos = paddle.to_tensor([0, 4, 261], dtype="int64")
        # causal (lower-triangular) mask over the prompt tokens
        attention_mask = paddle.full([batch, 1, max_len, max_len], 0, dtype=dtype)
        attention_mask[:, 0, :seq, :seq] = paddle.tril(paddle.ones(shape=(seq, seq), dtype=dtype))
        position_ids = paddle.full([batch, seq], 0, dtype="int64")
        for i in range(batch):
            position_ids[i, :] = paddle.arange(seq, dtype="int64")

        inputs = [
            input_ids,  # input_ids
            image_features,  # image_features
            img_pos,  # img_pos
            attention_mask,  # attention_mask
            position_ids,  # position_ids
            paddle.full([batch, 1], 1.0, dtype="float32"),  # penalty_score
            paddle.full([batch, 1], 0.0, dtype="float32"),  # frequency_score
            paddle.full([batch, 1], 0.0, dtype="float32"),  # presence_score
            paddle.full([batch, 1], 1, dtype="int64"),  # min_length
            paddle.full([batch, 1], max_len - seq, dtype="int64"),  # max_length
            paddle.full([batch, 1], 1.0, dtype="float32"),  # temperature
            paddle.full([batch, 1], 0.0, dtype="float32"),  # top_p
            paddle.full([1], 151643, dtype="int64"),  # eos_token_id
            paddle.full([batch, 1], seq, dtype="int32"),  # seq_len_encoder
            paddle.full([batch, 1], seq, dtype="int32"),  # seq_len_decoder
            paddle.full([batch, 1], 0, dtype="int64"),  # step_idx
            paddle.full([batch, 1], False, dtype="bool"),  # stop_flags
            paddle.full([batch, 1], -123, dtype="int64"),  # tgt_ids, can be initialized arbitrarily
            paddle.full([batch, 1], seq - 1, dtype="int64"),  # tgt_pos
            tgt_generation_mask,  # tgt_generation_mask
            paddle.full([batch, max_len], -100, dtype="int64"),  # pre_ids, can be initialized arbitrarily
            paddle.full([1], batch, dtype="int64"),  # stop_nums, should equal batch
        ]
        # one KV cache per layer: [2 (key/value), batch, num_heads, max_len, head_dim]
        for i in range(config.num_hidden_layers):
            tmp = paddle.rand(shape=[2, batch, 32, max_len, 128], dtype=dtype)
            inputs.append(tmp)

        model.eval()
        model.generate_text_with_image_features(
            input_ids=inputs[0],
            image_features=inputs[1],
            img_pos=inputs[2],
            attention_mask=inputs[3],
            position_ids=inputs[4],
            penalty_score=inputs[5],
            frequency_score=inputs[6],
            presence_score=inputs[7],
            min_length=inputs[8],
            max_length=inputs[9],
            temperature=inputs[10],
            top_p=inputs[11],
            eos_token_id=inputs[12],
            seq_len_encoder=inputs[13],
            seq_len_decoder=inputs[14],
            step_idx=inputs[15],
            stop_flags=inputs[16],
            tgt_ids=inputs[17],
            tgt_pos=inputs[18],
            tgt_generation_mask=inputs[19],
            pre_ids=inputs[20],
            stop_nums=inputs[21],
            cache_kvs=inputs[22:],
        )

    def test_export(self):
        self.disable_static()
        config = load_test_config(self.config_path, "inference-to-static")
        config["model_name_or_path"] = self.model_name_or_path
        config["output_path"] = self.output_dir
        config["dtype"] = "float16"
        config["inference_model"] = True
        config["model_prefix"] = "qwen"
        config["model_type"] = "qwen-img2txt"

        with argv_context_guard(config):
            from export_model import main

            main()
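
For reference, the new case can be exercised on its own with a standard unittest runner. The snippet below is a hypothetical invocation, not part of this commit; it assumes a float16-capable GPU and a repository checkout where tests.llm resolves as a package (which the relative testing_utils import suggests).

    # Hypothetical standalone runner for the new test; not part of this commit.
    # Run from the repository root so that `tests.llm` resolves as a package.
    import unittest

    from tests.llm.test_qwenvl import QWenVLTest

    suite = unittest.TestLoader().loadTestsFromTestCase(QWenVLTest)
    unittest.TextTestRunner(verbosity=2).run(suite)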
