Skip to content

Commit 00d842b

Browse files
authored
Fix taskflow custom_model download issue (#4984)
* changes
* fix unit test
* hardcode plato
* hardcode ernievil
* fix text similarity
1 parent 1c01e6f commit 00d842b

File tree

7 files changed

+92
-40
lines changed

7 files changed

+92
-40
lines changed

paddlenlp/taskflow/dialogue.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,33 @@ class DialogueTask(Task):
7474
"https://bj.bcebos.com/paddlenlp/taskflow/dialogue/plato-mini/model_config.json",
7575
"5e853fda9a9b573815ad112e494a65af",
7676
],
77-
}
77+
},
78+
"__internal_testing__/tiny-random-plato": {
79+
"model_state": [
80+
"https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-plato/model_state.pdparams",
81+
"fda5d068908505cf0c3a46125eb4d39e",
82+
],
83+
"model_config": [
84+
"https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-plato/config.json",
85+
"3664e658d5273a132f2e7345a8cafa53",
86+
],
87+
},
7888
}
7989

8090
def __init__(self, task, model, batch_size=1, max_seq_len=512, **kwargs):
8191
super().__init__(task=task, model=model, **kwargs)
8292
self._static_mode = False
8393
self._usage = usage
84-
if not self.from_hf_hub:
94+
if not self._custom_model:
8595
self._check_task_files()
86-
self._construct_tokenizer(self._task_path if self.from_hf_hub else model)
96+
self._construct_tokenizer(self._task_path if self._custom_model else model)
8797
self._batch_size = batch_size
8898
self._max_seq_len = max_seq_len
8999
self._interactive_mode = False
90100
if self._static_mode:
91101
self._get_inference_model()
92102
else:
93-
self._construct_model(self._task_path if self.from_hf_hub else model)
103+
self._construct_model(self._task_path if self._custom_model else model)
94104

95105
def _construct_input_spec(self):
96106
"""

paddlenlp/taskflow/feature_extraction.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,19 +180,44 @@ class MultimodalFeatureExtractionTask(Task):
180180
"573ba0466e15cdb5bd423ff7010735ce",
181181
],
182182
},
183+
"__internal_testing__/tiny-random-ernievil2": {
184+
"model_state": [
185+
"https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/tiny-random-ernievil2/model_state.pdparams",
186+
"771c844e7b75f61123d9606c8c17b1d6",
187+
],
188+
"config": [
189+
"https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/tiny-random-ernievil2/config.json",
190+
"ae27a68336ccec6d3ffd14b48a6d1f25",
191+
],
192+
"vocab_file": [
193+
"https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/tiny-random-ernievil2/vocab.txt",
194+
"1c1c1f4fd93c5bed3b4eebec4de976a8",
195+
],
196+
"preprocessor_config": [
197+
"https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/tiny-random-ernievil2/preprocessor_config.json",
198+
"9a2e8da9f41896fedb86756b79355ee2",
199+
],
200+
"special_tokens_map": [
201+
"https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/tiny-random-ernievil2/special_tokens_map.json",
202+
"8b3fb1023167bb4ab9d70708eb05f6ec",
203+
],
204+
"tokenizer_config": [
205+
"https://paddlenlp.bj.bcebos.com/models/community/__internal_testing__/tiny-random-ernievil2/tokenizer_config.json",
206+
"2333f189cad8dd559de61bbff4d4a789",
207+
],
208+
},
183209
}
184210

185-
def __init__(
186-
self, task, model, batch_size=1, is_static_model=True, max_seq_len=128, return_tensors="pd", **kwargs
187-
):
211+
def __init__(self, task, model, batch_size=1, is_static_model=True, max_length=128, return_tensors="pd", **kwargs):
188212
super().__init__(task=task, model=model, **kwargs)
189213
self._seed = None
190214
# we do not use batch
191215
self.export_type = "text"
192216
self._batch_size = batch_size
193217
self.return_tensors = return_tensors
194-
self._max_seq_len = max_seq_len
195-
self._check_task_files()
218+
if not self.from_hf_hub:
219+
self._check_task_files()
220+
self._max_length = max_length
196221
self._construct_tokenizer()
197222
self.is_static_model = is_static_model
198223
self._config_map = {}
@@ -217,7 +242,7 @@ def _construct_tokenizer(self):
217242
"""
218243
Construct the tokenizer for the predictor.
219244
"""
220-
self._processor = AutoProcessor.from_pretrained(self.model)
245+
self._processor = AutoProcessor.from_pretrained(self._task_path)
221246

222247
def _batchify(self, data, batch_size):
223248
"""
@@ -238,7 +263,7 @@ def _parse_batch(batch_examples):
238263
images=batch_images,
239264
return_tensors="np",
240265
padding="max_length",
241-
max_seq_len=self._max_seq_len,
266+
max_length=self._max_length,
242267
truncation=True,
243268
)
244269
else:
@@ -248,7 +273,7 @@ def _parse_batch(batch_examples):
248273
images=batch_images,
249274
return_tensors="pd",
250275
padding="max_length",
251-
max_seq_len=self._max_seq_len,
276+
max_length=self._max_length,
252277
truncation=True,
253278
)
254279
return tokenized_inputs

paddlenlp/taskflow/taskflow.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@
7171
"dialogue": {
7272
"models": {
7373
"plato-mini": {"task_class": DialogueTask, "task_flag": "dialogue-plato-mini"},
74+
"__internal_testing__/tiny-random-plato": {
75+
"task_class": DialogueTask,
76+
"task_flag": "dialogue-tiny-random-plato",
77+
},
7478
},
7579
"default": {
7680
"model": "plato-mini",
@@ -235,6 +239,10 @@
235239
"task_class": TextSimilarityTask,
236240
"task_flag": "text_similarity-rocketqa-nano-cross-encoder",
237241
},
242+
"__internal_testing__/tiny-random-bert": {
243+
"task_class": TextSimilarityTask,
244+
"task_flag": "text_similarity-tiny-random-bert",
245+
},
238246
},
239247
"default": {"model": "simbert-base-chinese"},
240248
},
@@ -597,6 +605,11 @@
597605
"task_flag": "feature_extraction-openai/clip-rn50x4",
598606
"task_priority_path": "openai/clip-rn50x4",
599607
},
608+
"__internal_testing__/tiny-random-ernievil2": {
609+
"task_class": MultimodalFeatureExtractionTask,
610+
"task_flag": "feature_extraction-tiny-random-ernievil2",
611+
"task_priority_path": "__internal_testing__/tiny-random-ernievil2",
612+
},
600613
},
601614
"default": {"model": "PaddlePaddle/ernie_vil-2.0-base-zh"},
602615
},

paddlenlp/taskflow/text_similarity.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,30 @@ class TextSimilarityTask(Task):
122122
"dcff14cd671e1064be2c5d63734098bb",
123123
],
124124
},
125+
"__internal_testing__/tiny-random-bert": {
126+
"model_state": [
127+
"https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-bert/model_state.pdparams",
128+
"a7a54deee08235fc6ae454f5def2d663",
129+
],
130+
"model_config": [
131+
"https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-bert/config.json",
132+
"bfaa763f77da7cc796de4e0ad4b389e9",
133+
],
134+
},
125135
}
126136

127-
def __init__(self, task, model, batch_size=1, max_seq_len=384, **kwargs):
137+
def __init__(self, task, model, batch_size=1, max_length=384, **kwargs):
128138
super().__init__(task=task, model=model, **kwargs)
129139
self._static_mode = True
130-
self._check_task_files()
131-
self._get_inference_model()
140+
if not self.from_hf_hub:
141+
self._check_task_files()
132142
if self._static_mode:
133143
self._get_inference_model()
134144
else:
135145
self._construct_model(model)
136146
self._construct_tokenizer(model)
137147
self._batch_size = batch_size
138-
self._max_seq_len = max_seq_len
148+
self._max_length = max_length
139149
self._usage = usage
140150
self.model_name = model
141151

@@ -185,16 +195,16 @@ def _preprocess(self, inputs):
185195
for data in inputs:
186196
text1, text2 = data[0], data[1]
187197
if "rocketqa" in self.model_name:
188-
encoded_inputs = self._tokenizer(text=text1, text_pair=text2, max_seq_len=self._max_seq_len)
198+
encoded_inputs = self._tokenizer(text=text1, text_pair=text2, max_length=self._max_length)
189199
ids = encoded_inputs["input_ids"]
190200
segment_ids = encoded_inputs["token_type_ids"]
191201
examples.append((ids, segment_ids))
192202
else:
193-
text1_encoded_inputs = self._tokenizer(text=text1, max_seq_len=self._max_seq_len)
203+
text1_encoded_inputs = self._tokenizer(text=text1, max_length=self._max_length)
194204
text1_input_ids = text1_encoded_inputs["input_ids"]
195205
text1_token_type_ids = text1_encoded_inputs["token_type_ids"]
196206

197-
text2_encoded_inputs = self._tokenizer(text=text2, max_seq_len=self._max_seq_len)
207+
text2_encoded_inputs = self._tokenizer(text=text2, max_length=self._max_length)
198208
text2_input_ids = text2_encoded_inputs["input_ids"]
199209
text2_token_type_ids = text2_encoded_inputs["token_type_ids"]
200210

tests/taskflow/test_dialogue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TestDialogueTask(unittest.TestCase):
2222
def setUpClass(cls):
2323
cls.dialogue = Taskflow(
2424
task="dialogue",
25-
task_path="__internal_testing__/tiny-random-plato",
25+
model="__internal_testing__/tiny-random-plato",
2626
)
2727
cls.max_turn = 3
2828

tests/taskflow/test_feature_extraction.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,28 @@ def setUpClass(cls):
3232
cls.max_resolution = 40
3333
cls.min_resolution = 30
3434
cls.num_channels = 3
35-
cls.max_seq_len = 30
36-
cls.model = "__internal_testing__/tiny-random-ernievil2"
35+
cls.max_length = 30
3736

3837
@classmethod
3938
def tearDownClass(cls):
4039
cls.temp_dir.cleanup()
4140

4241
def test_model_np(self):
4342
feature_extractor = Taskflow(
44-
model="PaddlePaddle/ernie_vil-2.0-base-zh",
43+
model="__internal_testing__/tiny-random-ernievil2",
4544
task="feature_extraction",
46-
task_path=self.model,
4745
return_tensors="np",
48-
max_seq_len=self.max_seq_len,
46+
max_length=self.max_length,
4947
)
5048
outputs = feature_extractor("This is a test")
5149
self.assertEqual(outputs["features"].shape, (1, 32))
5250

5351
def test_return_tensors(self):
5452
feature_extractor = Taskflow(
55-
model="PaddlePaddle/ernie_vil-2.0-base-zh",
53+
model="__internal_testing__/tiny-random-ernievil2",
5654
task="feature_extraction",
57-
task_path=self.model,
5855
return_tensors="pd",
59-
max_seq_len=self.max_seq_len,
56+
max_length=self.max_length,
6057
)
6158
outputs = feature_extractor(
6259
"This is a test",
@@ -97,25 +94,23 @@ def test_feature_extraction_task(self):
9794
input_text = (["这是一只猫", "这是一只狗"],)
9895
# dygraph text test
9996
dygraph_taskflow = MultimodalFeatureExtractionTask(
100-
model="PaddlePaddle/ernie_vil-2.0-base-zh",
97+
model="__internal_testing__/tiny-random-ernievil2",
10198
task="feature_extraction",
102-
task_path=self.model,
10399
is_static_model=False,
104100
return_tensors="np",
105-
max_seq_len=self.max_seq_len,
101+
max_length=self.max_length,
106102
)
107103
dygraph_results = dygraph_taskflow(input_text)
108104
shape = dygraph_results["features"].shape
109105
self.assertEqual(shape[0], 2)
110106
# static text test
111107
static_taskflow = MultimodalFeatureExtractionTask(
112-
model="PaddlePaddle/ernie_vil-2.0-base-zh",
108+
model="__internal_testing__/tiny-random-ernievil2",
113109
task="feature_extraction",
114-
task_path=self.model,
115110
is_static_model=True,
116111
return_tensors="np",
117112
device_id=0,
118-
max_seq_len=self.max_seq_len,
113+
max_length=self.max_length,
119114
)
120115
static_results = static_taskflow(input_text)
121116
self.assertEqual(static_results["features"].shape[0], 2)
@@ -142,23 +137,23 @@ def test_taskflow_task(self):
142137

143138
# dygraph test
144139
dygraph_taskflow = Taskflow(
140+
model="__internal_testing__/tiny-random-ernievil2",
145141
task="feature_extraction",
146-
task_path=self.model,
147142
is_static_model=False,
148143
return_tensors="np",
149-
max_seq_len=self.max_seq_len,
144+
max_length=self.max_length,
150145
)
151146
dygraph_results = dygraph_taskflow(input_text)
152147
shape = dygraph_results["features"].shape
153148

154149
self.assertEqual(shape[0], 2)
155150
# static test
156151
static_taskflow = Taskflow(
152+
model="__internal_testing__/tiny-random-ernievil2",
157153
task="feature_extraction",
158-
task_path=self.model,
159154
is_static_model=True,
160155
return_tensors="np",
161-
max_seq_len=self.max_seq_len,
156+
max_length=self.max_length,
162157
)
163158
static_results = static_taskflow(input_text)
164159
self.assertEqual(static_results["features"].shape[0], 2)

tests/taskflow/test_text_similarity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ class TestTextSimilarityTask(unittest.TestCase):
2121
def test_bert_model(self):
2222
similarity = Taskflow(
2323
task="text_similarity",
24-
model="simbert-base-chinese",
25-
task_path="__internal_testing__/tiny-random-bert",
24+
model="__internal_testing__/tiny-random-bert",
2625
)
2726
results = similarity([["世界上什么东西最小", "世界上什么东西最小?"]])
2827
self.assertTrue(len(results) == 1)

0 commit comments

Comments (0)