Skip to content

Commit c19d9aa

Browse files
authored
[AutoNLP]add visualdl (#4990)
* add visualdl
1 parent a345ca7 commit c19d9aa

File tree

4 files changed

+35
-2
lines changed

4 files changed

+35
-2
lines changed

paddlenlp/experimental/autonlp/auto_trainer_base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class AutoTrainerBase(metaclass=ABCMeta):
5353
export_path = "exported_model" # filepath for the exported static model
5454
results_filename = "experiment_results.csv" # filepath for storing experiment results
5555
experiment_path = None # filepath for the experiment results
56+
visualdl_path = "visualdl" # filepath for the visualdl
5657

5758
def __init__(
5859
self,
@@ -98,6 +99,14 @@ def _default_training_argument(self) -> TrainingArguments:
9899
"""
99100
Default TrainingArguments for the Trainer
100101
"""
102+
return TrainingArguments(
103+
output_dir=self.training_path,
104+
disable_tqdm=True,
105+
load_best_model_at_end=True,
106+
save_total_limit=1,
107+
report_to=["visualdl", "autonlp"],
108+
logging_dir=self.visualdl_path, # if logging_dir is redefined, the function visualdl() should be redefined as well.
109+
)
101110

102111
@property
103112
@abstractmethod
@@ -189,7 +198,10 @@ def _override_hp(self, config: Dict[str, Any], default_hp: Any) -> Any:
189198
new_hp = copy.deepcopy(default_hp)
190199
for key, value in config.items():
191200
if key in new_hp.to_dict():
192-
setattr(new_hp, key, value)
201+
if key in ["output_dir", "logging_dir"]:
202+
logger.warning(f"{key} cannot be overridden")
203+
else:
204+
setattr(new_hp, key, value)
193205
return new_hp
194206

195207
def _filter_model_candidates(
@@ -324,3 +336,10 @@ def train(
324336
)
325337

326338
return self.training_results
339+
340+
def visualdl(self, trial_id: Optional[str] = None):
341+
"""
342+
Return visualdl path to represent the results of the taskflow training.
343+
"""
344+
model_result = self._get_model_result(trial_id=trial_id)
345+
return os.path.join(model_result.log_dir, self.visualdl_path)

paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def supported_languages(self) -> List[str]:
105105

106106
@property
107107
def _default_training_argument(self) -> TrainingArguments:
108+
"""
109+
Default TrainingArguments for the Trainer
110+
"""
108111
return TrainingArguments(
109112
output_dir=self.training_path,
110113
disable_tqdm=True,
@@ -115,6 +118,7 @@ def _default_training_argument(self) -> TrainingArguments:
115118
save_strategy="epoch",
116119
save_total_limit=1,
117120
report_to=["visualdl", "autonlp"],
121+
logging_dir=self.visualdl_path,
118122
)
119123

120124
@property
@@ -129,6 +133,7 @@ def _default_prompt_tuning_arguments(self) -> PromptTuningArguments:
129133
save_strategy="epoch",
130134
save_total_limit=1,
131135
report_to=["visualdl", "autonlp"],
136+
logging_dir=self.visualdl_path,
132137
)
133138

134139
@property

paddlenlp/trainer/integrations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
112112
return
113113

114114
if self.vdl_writer is None:
115-
self._init_summary_writer(args)
115+
return
116116

117117
if self.vdl_writer is not None:
118118
logs = rewrite_logs(logs)

tests/experimental/autonlp/test_text_classification.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def test_multiclass(self, custom_model_candidate, hp_overrides):
150150
self.assertTrue(os.path.exists(os.path.join(save_path, "template_config.json")))
151151
self.assertTrue(os.path.exists(os.path.join(save_path, "verbalizer_config.json")))
152152

153+
# test visualdl
154+
self.assertTrue(os.path.isdir(auto_trainer.visualdl()))
155+
153156
# test evaluate
154157
copy_dev_ds = copy.deepcopy(self.multi_class_dev_ds)
155158
eval_metrics1 = auto_trainer.evaluate()
@@ -259,6 +262,9 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
259262
self.assertTrue(os.path.exists(os.path.join(save_path, "template_config.json")))
260263
self.assertTrue(os.path.exists(os.path.join(save_path, "verbalizer_config.json")))
261264

265+
# test visualdl
266+
self.assertTrue(os.path.isdir(auto_trainer.visualdl()))
267+
262268
# test evaluate
263269
copy_dev_ds = copy.deepcopy(self.multi_label_dev_ds)
264270
eval_metrics1 = auto_trainer.evaluate()
@@ -367,6 +373,9 @@ def test_default_model_candidate(self, language, hp_overrides):
367373
self.assertTrue(os.path.exists(os.path.join(save_path, "template_config.json")))
368374
self.assertTrue(os.path.exists(os.path.join(save_path, "verbalizer_config.json")))
369375

376+
# test visualdl
377+
self.assertTrue(os.path.isdir(auto_trainer.visualdl()))
378+
370379
# test evaluate
371380
copy_dev_ds = copy.deepcopy(self.multi_class_dev_ds)
372381
eval_metrics1 = auto_trainer.evaluate()

0 commit comments

Comments
 (0)