
Commit 3812224 ("fix")
1 parent: e054a1b

2 files changed: +20, -10 lines

paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 8 additions & 5 deletions
@@ -135,7 +135,7 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
         chinese_finetune_models = hp.choice(
             "finetune_models",
             [
-                "ernie-1.0-large-zh-cw"  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
+                "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
                 "ernie-3.0-xbase-zh",  # 20-layer, 1024-hidden, 16-heads, 296M parameters.
                 "ernie-3.0-tiny-base-v2-zh",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
                 "ernie-3.0-tiny-medium-v2-zh",  # 6-layer, 768-hidden, 12-heads, 75M parameters.
@@ -152,7 +152,6 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "roberta-large",  # 24-layer, 1024-hidden, 16-heads, 334M parameters. Case-sensitive
                 "roberta-base",  # 12-layer, 768-hidden, 12-heads, 110M parameters. Case-sensitive
                 "distilroberta-base",  # 6-layer, 768-hidden, 12-heads, 66M parameters. Case-sensitive
-                "ernie-3.0-tiny-mini-v2-en",  # 6-layer, 384-hidden, 12-heads, 27M parameters
                 "ernie-2.0-base-en",  # 12-layer, 768-hidden, 12-heads, 103M parameters. Trained on lower-cased English text.
                 "ernie-2.0-large-en",  # 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text.
             ],
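For context on the first hunk: the added comma is not cosmetic. Python concatenates adjacent string literals, so without it the first two Chinese model names silently fused into one invalid entry in the search space. A minimal stand-alone sketch (the list below is illustrative, not the full candidate list):

# Illustration only: why the missing comma after "ernie-1.0-large-zh-cw" mattered.
candidates = [
    "ernie-1.0-large-zh-cw"  # no trailing comma here ...
    "ernie-3.0-xbase-zh",    # ... so Python fuses this literal with the previous one
    "ernie-3.0-tiny-base-v2-zh",
]
print(len(candidates))  # 2, not 3
print(candidates[0])    # 'ernie-1.0-large-zh-cwernie-3.0-xbase-zh' -- one bogus model name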
@@ -523,7 +522,9 @@ def export(self, export_path, trial_id=None):
             trainer.export_model(export_path)
             trainer.model.plm.save_pretrained(os.path.join(export_path, "plm"))
             mode = "prompt"
-            max_length = model_config.get("PreprocessArguments.max_length", 128)
+            max_length = model_config.get(
+                "PreprocessArguments.max_length", trainer.model.plm.config.max_position_embeddings
+            )
         else:
             if trainer.model.init_config["init_class"] in ["ErnieMForSequenceClassification"]:
                 input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
@@ -534,7 +535,9 @@ def export(self, export_path, trial_id=None):
             ]
             export_model(model=trainer.model, input_spec=input_spec, path=export_path)
             mode = "finetune"
-            max_length = trainer.model.config.max_position_embeddings
+            max_length = model_config.get(
+                "PreprocessArguments.max_length", trainer.model.config.max_position_embeddings
+            )
 
         # save tokenizer
         trainer.tokenizer.save_pretrained(export_path)
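Both branches now resolve max_length the same way: prefer an explicit PreprocessArguments.max_length from the trial's model_config, and fall back to the model's max_position_embeddings (rather than the hard-coded 128 previously used in the prompt branch). A rough stand-alone sketch of that lookup, with hypothetical values:

# Hypothetical values for illustration; the real dict comes from the winning trial's config.
model_config = {"TrainingArguments.max_steps": 100}   # no explicit max_length set
max_position_embeddings = 2048                        # e.g. trainer.model.config.max_position_embeddings

max_length = model_config.get("PreprocessArguments.max_length", max_position_embeddings)
print(max_length)  # 2048 -- the model limit is used as the fallback

model_config["PreprocessArguments.max_length"] = 256  # an explicit setting wins when present
print(model_config.get("PreprocessArguments.max_length", max_position_embeddings))  # 256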
@@ -553,7 +556,7 @@ def export(self, export_path, trial_id=None):
         with open(os.path.join(export_path, "taskflow_config.json"), "w", encoding="utf-8") as f:
             json.dump(taskflow_config, f, ensure_ascii=False)
         logger.info(
-            f"taskflow config saved to {export_path}. You can use the taskflow config to create a Taskflow instance for inference"
+            f"Taskflow config saved to {export_path}. You can use the Taskflow config to create a Taskflow instance for inference"
         )
 
         if os.path.exists(self.training_path):
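As the (now capitalized) log message says, the exported taskflow_config.json can be passed to paddlenlp.Taskflow for inference. A minimal sketch, assuming the JSON holds keyword arguments accepted by Taskflow and that a "task_path" entry should point at the export directory; the path and sample text below are hypothetical:

import json
import os

from paddlenlp import Taskflow

export_path = "./exported_model"  # hypothetical export directory

# Load the config written by AutoTrainerForTextClassification.export(...)
with open(os.path.join(export_path, "taskflow_config.json"), "r", encoding="utf-8") as f:
    taskflow_config = json.load(f)

# Point the Taskflow at the exported weights/tokenizer (assumption: key name "task_path").
taskflow_config["task_path"] = export_path

classifier = Taskflow(**taskflow_config)
print(classifier(["这是一段待分类的文本"]))  # "A piece of text to classify"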

tests/experimental/autonlp/test_text_classification.py

Lines changed: 12 additions & 5 deletions
@@ -263,11 +263,18 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
             # test training_path
             self.assertFalse(os.path.exists(os.path.join(auto_trainer.training_path)))
 
-    @slow
     @parameterized.expand(
         [
             (
-                None,
+                "Chinese",
+                {
+                    "TrainingArguments.max_steps": 2,
+                    "TrainingArguments.per_device_train_batch_size": 1,
+                    "TrainingArguments.per_device_eval_batch_size": 1,
+                },
+            ),
+            (
+                "English",
                 {
                     "TrainingArguments.max_steps": 2,
                     "TrainingArguments.per_device_train_batch_size": 1,
@@ -276,7 +283,8 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
             ),
         ]
     )
-    def test_default_model_candidate(self, custom_model_candidate, hp_overrides):
+    @slow
+    def test_default_model_candidate(self, language, hp_overrides):
         with TemporaryDirectory() as temp_dir_path:
             train_ds = copy.deepcopy(self.multi_class_train_ds)
             dev_ds = copy.deepcopy(self.multi_class_dev_ds)
@@ -287,7 +295,7 @@ def test_default_model_candidate(self, custom_model_candidate, hp_overrides):
                 eval_dataset=dev_ds,
                 label_column="label_desc",
                 text_column="sentence",
-                language="Chinese",
+                language=language,
                 output_dir=temp_dir_path,
                 problem_type="multi_class",
             )
@@ -296,7 +304,6 @@ def test_default_model_candidate(self, custom_model_candidate, hp_overrides):
                 num_gpus=1,
                 max_concurrent_trials=1,
                 num_models=num_models,
-                custom_model_candidates=custom_model_candidate,
                 hp_overrides=hp_overrides,
             )
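The test now iterates over languages instead of taking a custom_model_candidates argument, so each tuple supplied to parameterized.expand provides (language, hp_overrides). A simplified stand-alone analogue of how those tuples reach the test method (not the real test body):

import unittest

from parameterized import parameterized


class DemoTest(unittest.TestCase):
    @parameterized.expand(
        [
            ("Chinese", {"TrainingArguments.max_steps": 2}),
            ("English", {"TrainingArguments.max_steps": 2}),
        ]
    )
    def test_default_model_candidate(self, language, hp_overrides):
        # Each tuple becomes one test case; `language` picks the Chinese or English
        # default model candidates and `hp_overrides` keeps the training run tiny.
        self.assertIn(language, ("Chinese", "English"))
        self.assertEqual(hp_overrides["TrainingArguments.max_steps"], 2)


if __name__ == "__main__":
    unittest.main()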
