[autonlp] text classification fix & add taskflow config file #4896
Conversation
Thanks for your contribution!
```diff
@@ -132,7 +133,7 @@ def _default_prompt_tuning_arguments(self) -> PromptTuningArguments:
     def _model_candidates(self) -> List[Dict[str, Any]]:
         train_batch_size = hp.choice("batch_size", [2, 4, 8, 16, 32])
         chinese_models = hp.choice(
-            "models",
+            "chinese_models",
```
Reviewer: Even if we add one here, let's not end up with four different names. The English and Chinese model lists never appear at the same time, so just name them uniformly finetune_models and prompt_models.
Also add a unit test here (modeled on model_candidates, containing both a prompt-model hp.choice and a finetune-model hp.choice) so that unit tests will catch similar problems in the future.
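A minimal sketch of the unified naming the reviewer proposes, assuming hyperopt's hp.choice (which matches the call signature in the diff); the candidate lists below are illustrative placeholders, not the real defaults:

```python
from hyperopt import hp

# The Chinese and English search spaces are never active at the same time,
# so both can share the same two labels instead of four distinct ones.
chinese_finetune_models = hp.choice(
    "finetune_models",
    ["ernie-3.0-base-zh"],  # illustrative candidates
)
chinese_prompt_models = hp.choice(
    "prompt_models",
    ["ernie-3.0-base-zh"],  # illustrative candidates
)
```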
Author: Added the unit test test_default_model_candidate.
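A sketch of what such a test might look like; self.auto_trainer, the model_name_or_path key, and KNOWN_MODEL_NAMES are hypothetical stand-ins for whatever the real test fixture exposes:

```python
from hyperopt.pyll.stochastic import sample

def test_default_model_candidate(self):
    # Sample every default candidate search space and validate the sampled
    # model name; a missing comma would silently concatenate two adjacent
    # string literals into one bogus name, which this check would flag.
    for candidate in self.auto_trainer._model_candidates():  # hypothetical fixture
        config = sample(candidate)
        self.assertIn(config["model_name_or_path"], KNOWN_MODEL_NAMES)  # hypothetical
```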
Codecov Report
```
@@            Coverage Diff             @@
##           develop    #4896      +/-   ##
===========================================
+ Coverage    44.65%   46.36%    +1.71%
===========================================
  Files          446      448        +2
  Lines        64375    64619      +244
===========================================
+ Hits         28744    29960     +1216
+ Misses       35631    34659      -972
```
```diff
@@ -538,6 +521,8 @@ def export(self, export_path, trial_id=None):
         if model_config["trainer_type"] == "PromptTrainer":
             trainer.export_model(export_path)
             trainer.model.plm.save_pretrained(os.path.join(export_path, "plm"))
+            mode = "prompt"
+            max_length = model_config.get("PreprocessArguments.max_length", 128)
```
Reviewer: Don't hardcode 128; try trainer.model.config.max_position_embeddings.
```diff
+            mode = "finetune"
+            max_length = trainer.model.config.max_position_embeddings
```
Reviewer: The logic here should be consistent with the branch above, right? I.e. model_config.get("PreprocessArguments.max_length", trainer.model.config.max_position_embeddings)
Author: Unified both branches to max_length = config.get("PreprocessArguments.max_length", model.config.max_position_embeddings).
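That resolution as a one-liner; config and model stand for the objects in scope at each call site:

```python
# Prefer the user-configured max_length; otherwise fall back to the model's
# own positional limit rather than a hardcoded 128.
max_length = config.get("PreprocessArguments.max_length", model.config.max_position_embeddings)
```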
```python
        }

        with open(os.path.join(export_path, "taskflow_config.json"), "w", encoding="utf-8") as f:
            json.dump(taskflow_config, f, ensure_ascii=False)

        if os.path.exists(self.training_path):
            logger.info("Removing training checkpoints to conserve disk space")
            shutil.rmtree(self.training_path)

        logger.info(f"Exported {trial_id} to {export_path}")
```
Reviewer: Could log one more line here: taskflow config saved to {export_path}. You can use the taskflow config to create a Taskflow instance for inference.
Author: Added.
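For context, a hypothetical sketch of the inference flow that log line points at, assuming PaddleNLP's Taskflow can load the export directory via task_path (the task name and arguments beyond that are assumptions):

```python
from paddlenlp import Taskflow

# Point Taskflow at the export directory, which now contains
# taskflow_config.json alongside the exported model.
classifier = Taskflow("text_classification", task_path=export_path)
print(classifier(["这家餐厅的菜品很不错"]))
```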
```diff
@@ -298,16 +298,18 @@ def _construct_trainer(self, config, eval_dataset=None) -> Trainer:
             ]
         else:
             callbacks = None
+        max_length = config.get("PreprocessArguments.max_length", 128)
```
Reviewer: Isn't this line left over? It looks redundant.
```diff
@@ -551,6 +553,9 @@ def export(self, export_path, trial_id=None):

         with open(os.path.join(export_path, "taskflow_config.json"), "w", encoding="utf-8") as f:
             json.dump(taskflow_config, f, ensure_ascii=False)
+        logger.info(
+            f"taskflow config saved to {export_path}. You can use the taskflow config to create a Taskflow instance for inference"
```
f"taskflow config saved to {export_path}. You can use the taskflow config to create a Taskflow instance for inference" | |
f"Taskflow config saved to {export_path}. You can use the Taskflow config to create a Taskflow instance for inference" |
Author: Fixed.
```diff
-        chinese_models = hp.choice(
-            "models",
+        chinese_finetune_models = hp.choice(
+            "finetune_models",
             [
                 "ernie-1.0-large-zh-cw"  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
```
Reviewer: Missing comma? Good thing the unit test caught it.
Author: Fixed.
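For context, the bug class behind that comma: Python silently concatenates adjacent string literals, so a missing comma merges two model names into one (the second name below is illustrative):

```python
models = [
    "ernie-1.0-large-zh-cw"  # missing comma here...
    "ernie-3.0-base-zh",
]
print(models)  # ['ernie-1.0-large-zh-cwernie-3.0-base-zh'], a single bogus entry
```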
```diff
@@ -264,6 +263,96 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
         # test training_path
         self.assertFalse(os.path.exists(os.path.join(auto_trainer.training_path)))

+    @slow
```
Reviewer: This probably has to be moved below @parameterized to take effect; the test still ran in CI. Please verify locally with pytest tests/experimental/autonlp/test_text_classification.py.
Author: Verified locally; it does indeed need to move below.
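A sketch of the ordering fix, assuming parameterized.expand and a skip-style slow marker (the import path for slow is a hypothetical stand-in). Decorators apply bottom-up, and expand replaces the template method with generated test methods, so a marker placed above expand wraps only the discarded template; placed below expand, it wraps the template before expansion and every generated test inherits it:

```python
import unittest

from parameterized import parameterized
from tests.testing_utils import slow  # hypothetical import path for the marker


class TestTextClassification(unittest.TestCase):
    @parameterized.expand([("finetune",), ("prompt",)])  # illustrative cases
    @slow  # below expand, so each generated test carries the slow skip
    def test_default_model_candidate(self, mode):
        ...
```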
```diff
@@ -152,7 +152,6 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "roberta-large",  # 24-layer, 1024-hidden, 16-heads, 334M parameters. Case-sensitive
                 "roberta-base",  # 12-layer, 768-hidden, 12-heads, 110M parameters. Case-sensitive
                 "distilroberta-base",  # 6-layer, 768-hidden, 12-heads, 66M parameters. Case-sensitive
-                "ernie-3.0-tiny-mini-v2-en",  # 6-layer, 384-hidden, 12-heads, 27M parameters
```
Author: Testing found that loading the tokenizer for "ernie-3.0-tiny-mini-v2-en" raises an error, so it is removed from the candidate list for now.
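A minimal sketch of the check behind that removal, assuming PaddleNLP's AutoTokenizer:

```python
from paddlenlp.transformers import AutoTokenizer

# Reported to raise during loading at the time of this PR, which is why
# the model was dropped from the default candidate list.
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-tiny-mini-v2-en")
```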
lgtm!
PR types
Bug fixes

PR changes
APIs

Description
Fixes the text classification bugs above & adds support for loading Taskflow from a config file.