From caade9ef196374ba280c816151f2d90f9df7a7ec Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Mon, 19 Sep 2022 16:32:46 +0800
Subject: [PATCH 1/2] fix eval of amp usage.

---
 model_zoo/ernie-1.0/run_pretrain.py | 64 +++++++++++++++++------------
 1 file changed, 38 insertions(+), 26 deletions(-)

diff --git a/model_zoo/ernie-1.0/run_pretrain.py b/model_zoo/ernie-1.0/run_pretrain.py
index e6df62998e3a..400f2ab115bc 100644
--- a/model_zoo/ernie-1.0/run_pretrain.py
+++ b/model_zoo/ernie-1.0/run_pretrain.py
@@ -23,6 +23,7 @@
 import time
 import yaml
 import shutil
+from functools import partial
 
 import numpy as np
 import paddle
@@ -199,6 +200,7 @@ def run_evaluate(data_loader,
                  args,
                  task_name="valid"):
     model.eval()
+
     all_loss, all_lm_loss, all_sop_loss = [], [], []
 
     if args.binary_head:
@@ -217,33 +219,43 @@ def run_evaluate(data_loader,
     for eval_step, batch in enumerate(data_loader):
         input_ids, segment_ids, input_mask, masked_lm_positions, \
             masked_lm_labels, next_sentence_labels = batch
+        with paddle.amp.auto_cast(args.use_amp,
+                                  custom_white_list=[
+                                      'softmax',
+                                      'layer_norm',
+                                      'gelu',
+                                  ],
+                                  custom_black_list=[
+                                      "c_softmax_with_cross_entropy",
+                                  ],
+                                  level=args.fp16_opt_level):
+
+            if args.binary_head:
+                prediction_scores, seq_relationship_score = model(
+                    input_ids=input_ids,
+                    token_type_ids=segment_ids,
+                    position_ids=None,
+                    attention_mask=input_mask,
+                    masked_positions=masked_lm_positions)
+
+                lm_loss, sop_loss = criterion(prediction_scores,
+                                              seq_relationship_score,
+                                              masked_lm_labels,
+                                              next_sentence_labels)
+                loss = lm_loss + sop_loss
+            else:
+                prediction_scores = model(input_ids=input_ids,
+                                          token_type_ids=segment_ids,
+                                          position_ids=None,
+                                          attention_mask=input_mask,
+                                          masked_positions=masked_lm_positions)
+
+                loss = criterion(prediction_scores, None, masked_lm_labels)
 
-        if args.binary_head:
-            prediction_scores, seq_relationship_score = model(
-                input_ids=input_ids,
-                token_type_ids=segment_ids,
-                position_ids=None,
-                attention_mask=input_mask,
-                masked_positions=masked_lm_positions)
-
-            lm_loss, sop_loss = criterion(prediction_scores,
-                                          seq_relationship_score,
-                                          masked_lm_labels,
-                                          next_sentence_labels)
-            loss = lm_loss + sop_loss
-        else:
-            prediction_scores = model(input_ids=input_ids,
-                                      token_type_ids=segment_ids,
-                                      position_ids=None,
-                                      attention_mask=input_mask,
-                                      masked_positions=masked_lm_positions)
-
-            loss = criterion(prediction_scores, None, masked_lm_labels)
-
-        loss_global["loss"] += loss.detach()
-        if args.binary_head:
-            loss_global["lm_loss"] += lm_loss.detach()
-            loss_global["sop_loss"] += sop_loss.detach()
+            loss_global["loss"] += loss.detach()
+            if args.binary_head:
+                loss_global["lm_loss"] += lm_loss.detach()
+                loss_global["sop_loss"] += sop_loss.detach()
 
         if eval_step >= iter_steps - 1:
             log_info_dict = dict()

From fa97159529241b133307ea48ecf24c809e067943 Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Mon, 19 Sep 2022 16:35:48 +0800
Subject: [PATCH 2/2] fix

---
 model_zoo/ernie-1.0/run_pretrain.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/model_zoo/ernie-1.0/run_pretrain.py b/model_zoo/ernie-1.0/run_pretrain.py
index 400f2ab115bc..f46f09d7a2d7 100644
--- a/model_zoo/ernie-1.0/run_pretrain.py
+++ b/model_zoo/ernie-1.0/run_pretrain.py
@@ -23,7 +23,6 @@
 import time
 import yaml
 import shutil
-from functools import partial
 
 import numpy as np
 import paddle
@@ -200,7 +199,6 @@ def run_evaluate(data_loader,
                  args,
                  task_name="valid"):
     model.eval()
-
     all_loss, all_lm_loss, all_sop_loss = [], [], []
 
     if args.binary_head:
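
Editor's note (not part of the patches above): the first commit wraps the evaluation forward pass in the same paddle.amp.auto_cast context the training loop uses, so that with --use_amp the eval pass also runs mixed precision with matching op white/black lists; the second commit only drops a leftover import and blank line. Below is a minimal sketch of that pattern under stated assumptions — the helper name evaluate_with_amp, its arguments, and the simplified (input_ids, labels) batch layout are illustrative, not code from run_pretrain.py.

import paddle


@paddle.no_grad()
def evaluate_with_amp(model, criterion, data_loader, use_amp, amp_level="O1"):
    # Illustrative helper: run the eval forward pass under the same auto_cast
    # settings as training, so fp16 kernels and op lists match between phases.
    model.eval()
    total_loss, num_batches = paddle.zeros([1]), 0
    for input_ids, labels in data_loader:  # simplified batch layout
        with paddle.amp.auto_cast(
                use_amp,
                custom_white_list=["softmax", "layer_norm", "gelu"],
                custom_black_list=["c_softmax_with_cross_entropy"],
                level=amp_level):
            logits = model(input_ids)
            loss = criterion(logits, labels)
        total_loss += loss.detach()
        num_batches += 1
    model.train()
    return float(total_loss / max(num_batches, 1))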