Commit f8cdd3d

fix eval. (#8529)
* fix eval.
* fix
1 parent 0087c4a commit f8cdd3d

File tree

  • examples/benchmark/wiki_lambada

1 file changed: +24 −33 lines

examples/benchmark/wiki_lambada/eval.py

Lines changed: 24 additions & 33 deletions
@@ -73,7 +73,6 @@ def get_parser():
         default=False,
         help="Whether to use flash attention",
     )
-
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
 

@@ -250,7 +249,8 @@ def get_tokens(tokenizer, text, strict=True):
     last_token = text.split()[-1]
     start_idx = text.rfind(last_token)
     beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
-    last_token = tokenizer(" " + last_token)["input_ids"]
+    all_tokens = tokenizer(text.strip())["input_ids"]
+    last_token = all_tokens[len(beginning_tokens) :]
     return beginning_tokens, last_token
 
 
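For context, a short sketch of why the new slicing approach helps, assuming a PaddleNLP AutoTokenizer and the "gpt2-en" checkpoint purely for illustration (the real tokenizer is whatever create_eval_dataset builds): encoding the last word by itself can produce subword ids that differ from the ids that word receives inside the full passage, so slicing the full tokenization keeps the label ids consistent with what the model actually sees.

# Hypothetical illustration of the get_tokens change; checkpoint name is an assumption.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-en")
text = "In the end she decided to feed the stray cat"

last_word = text.split()[-1]
start_idx = text.rfind(last_word)
beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]

# Old approach: tokenize the last word on its own; these ids may not match the in-context ids.
separate_ids = tokenizer(" " + last_word)["input_ids"]

# New approach: tokenize the whole passage once and slice off the prefix,
# so the label ids are exactly the ids the model is asked to predict.
all_tokens = tokenizer(text.strip())["input_ids"]
in_context_ids = all_tokens[len(beginning_tokens):]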

@@ -277,7 +277,7 @@ def create_eval_dataset(args):
         with open(args.eval_path, "r") as f:
             for line in f.readlines():
                 text = json.loads(line)["text"]
-                tokens, labels = get_tokens(tokenizer, text, strict=False)
+                tokens, labels = get_tokens(tokenizer, text, strict=True)
                 tokenized_data.append(tokens)
                 tokenized_label.append(labels)
         val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label, seq_len, tokenizer.pad_token_id)
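
As a usage note, the loop above expects args.eval_path to be a JSONL file with one example per line under a "text" key. A minimal sketch of reading such a file, with the file name and sample sentences invented for illustration:

import json

# lambada_test.jsonl (hypothetical file name), one JSON object per line, e.g.:
#   {"text": "She put the key back where she had found it , under the mat"}
#   {"text": "He counted the coins twice before handing them to the clerk"}
texts = []
with open("lambada_test.jsonl", "r") as f:
    for line in f.readlines():
        texts.append(json.loads(line)["text"])
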
@@ -327,44 +327,35 @@ def do_generation():
     )
 
     model.eval()
-    args.use_pure_fp16 = False
-
     total_score = 0
     score_name = "loss" if not args.cloze_eval else "number correct"
-    args.use_pure_fp16 = False
     eval_data_loader = create_eval_dataset(args)
     with paddle.no_grad():
         for step, batch in enumerate(eval_data_loader):
 
             tokens, loss_mask = batch[:2]
             labels = batch[-1]
-            with paddle.amp.auto_cast(args.use_pure_fp16):
-                if args.model_type == "bloom":
-                    preds = model(tokens).detach()
-                else:
-                    preds = model(tokens)[0].detach()
-                # print(preds)
-
-                # cast preds to float32 to keep high-precision
-                preds = preds.astype(paddle.float32)
-
-                if not args.cloze_eval:
-                    masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
-                    loss = paddle.sum(masked_lm_loss * loss_mask)
-                    total_score += float(loss) / (args.num_tokenized_tokens - 1)
-                else:
-                    outputs = paddle.argmax(preds, -1)
-                    acc = paddle.cast(outputs == labels, "float32")
-                    acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
-                    acc = paddle.sum(paddle.prod(acc, -1))
-                    total_score += float(acc)
-
-                if step % args.logging_steps == 0:
-                    logger.info(
-                        "step %d, batch: %d, %s: %f, speed: %.2f step/s"
-                        % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
-                    )
-                    tic_eval = time.time()
+            preds = model(tokens, return_dict=True).logits.detach()
+            # cast preds to float32 to keep high-precision
+            preds = preds.astype(paddle.float32)
+
+            if not args.cloze_eval:
+                masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
+                loss = paddle.sum(masked_lm_loss * loss_mask)
+                total_score += float(loss) / (args.num_tokenized_tokens - 1)
+            else:
+                outputs = paddle.argmax(preds, -1)
+                acc = paddle.cast(outputs == labels, "float32")
+                acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
+                acc = paddle.sum(paddle.prod(acc, -1))
+                total_score += float(acc)
+
+            if step % args.logging_steps == 0:
+                logger.info(
+                    "step %d, batch: %d, %s: %f, speed: %.2f step/s"
+                    % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
+                )
+                tic_eval = time.time()
 
     if not args.cloze_eval:
         total_loss = float(total_score)
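
To unpack the rewritten loop: the model-type branching is replaced by a single model(tokens, return_dict=True).logits call, and in cloze mode an example only counts as correct when every masked position of the target word is predicted correctly, which is what the paddle.prod over the masked-and-defaulted accuracy vector enforces. A minimal, self-contained sketch with toy tensors (shapes and values are invented for illustration, not data from the script):

import paddle

# Toy logits for one sequence of three positions over a two-token vocabulary.
preds = paddle.to_tensor([[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]])  # shape [1, 3, 2]
labels = paddle.to_tensor([[1, 0, 0]])           # gold token ids
loss_mask = paddle.to_tensor([[0.0, 1.0, 1.0]])  # only the last-word positions are scored

outputs = paddle.argmax(preds, -1)               # predicted ids: [[1, 0, 1]]
acc = paddle.cast(outputs == labels, "float32")  # per-position correctness: [[1., 1., 0.]]
# Unmasked positions are forced to 1 so they cannot break the product.
acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
# The product over the sequence is 1 only if every masked position is correct.
acc = paddle.sum(paddle.prod(acc, -1))
print(float(acc))  # 0.0: the second masked position predicts id 1 but the label is 0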

Comments (0)