Skip to content

[BugFix] Fix amp usage for evaluation. #3303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 24, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions model_zoo/ernie-1.0/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,33 +217,43 @@ def run_evaluate(data_loader,
for eval_step, batch in enumerate(data_loader):
input_ids, segment_ids, input_mask, masked_lm_positions, \
masked_lm_labels, next_sentence_labels = batch
with paddle.amp.auto_cast(args.use_amp,
custom_white_list=[
'softmax',
'layer_norm',
'gelu',
],
custom_black_list=[
"c_softmax_with_cross_entropy",
],
level=args.fp16_opt_level):

if args.binary_head:
prediction_scores, seq_relationship_score = model(
input_ids=input_ids,
token_type_ids=segment_ids,
position_ids=None,
attention_mask=input_mask,
masked_positions=masked_lm_positions)

lm_loss, sop_loss = criterion(prediction_scores,
seq_relationship_score,
masked_lm_labels,
next_sentence_labels)
loss = lm_loss + sop_loss
else:
prediction_scores = model(input_ids=input_ids,
token_type_ids=segment_ids,
position_ids=None,
attention_mask=input_mask,
masked_positions=masked_lm_positions)

loss = criterion(prediction_scores, None, masked_lm_labels)

loss_global["loss"] += loss.detach()
if args.binary_head:
loss_global["lm_loss"] += lm_loss.detach()
loss_global["sop_loss"] += sop_loss.detach()
if args.binary_head:
prediction_scores, seq_relationship_score = model(
input_ids=input_ids,
token_type_ids=segment_ids,
position_ids=None,
attention_mask=input_mask,
masked_positions=masked_lm_positions)

lm_loss, sop_loss = criterion(prediction_scores,
seq_relationship_score,
masked_lm_labels,
next_sentence_labels)
loss = lm_loss + sop_loss
else:
prediction_scores = model(input_ids=input_ids,
token_type_ids=segment_ids,
position_ids=None,
attention_mask=input_mask,
masked_positions=masked_lm_positions)

loss = criterion(prediction_scores, None, masked_lm_labels)

loss_global["loss"] += loss.detach()
if args.binary_head:
loss_global["lm_loss"] += lm_loss.detach()
loss_global["sop_loss"] += sop_loss.detach()

if eval_step >= iter_steps - 1:
log_info_dict = dict()
Expand Down