
Commit f7357ba

Author: tianyu.zhou (committed)
Cherry pick PR PaddlePaddle#8529.
1 parent 428c762 commit f7357ba

File tree

1 file changed (+34, -37 lines)

  • examples/benchmark/wiki_lambada


examples/benchmark/wiki_lambada/eval.py

Lines changed: 34 additions & 37 deletions
@@ -57,8 +57,8 @@ def get_parser():
         "--device",
         type=str,
         default="gpu",
-        choices=["cpu", "eval_pathgpu", "xpu", "npu"],
-        help="select cpu, gpu, xpu devices.",
+        choices=["cpu", "gpu", "xpu", "npu", "gcu"],
+        help="select cpu, gpu, xpu, gcu devices.",
     )
     parser.add_argument(
         "--dtype",
@@ -67,7 +67,12 @@ def get_parser():
         choices=["bfloat16", "float16", "float32"],
         help="set the dtype of model",
     )
-
+    parser.add_argument(
+        "--use_flash_attention",
+        type=bool,
+        default=False,
+        help="Whether to use flash attention",
+    )
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
 
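Worth noting about the new flag: argparse's type=bool simply calls bool() on the raw string, so any non-empty value, including "False", parses as True; only omitting the flag (or passing an empty string) yields False. A minimal standalone sketch of that behaviour, using the argument exactly as defined in the hunk above:

# Minimal sketch (standalone, not part of eval.py): argparse's type=bool
# treats the string "False" as truthy, because bool("False") is True.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_flash_attention", type=bool, default=False)

print(parser.parse_args([]).use_flash_attention)                                   # False (default)
print(parser.parse_args(["--use_flash_attention", "False"]).use_flash_attention)   # True
print(parser.parse_args(["--use_flash_attention", ""]).use_flash_attention)        # False

If strict parsing were wanted, a store_true action (as already used for --load_autodist) would avoid the surprise; with the definition above, the flag should simply be passed with a non-empty value to enable flash attention.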
@@ -244,7 +249,8 @@ def get_tokens(tokenizer, text, strict=True):
     last_token = text.split()[-1]
     start_idx = text.rfind(last_token)
     beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
-    last_token = tokenizer(" " + last_token)["input_ids"]
+    all_tokens = tokenizer(text.strip())["input_ids"]
+    last_token = all_tokens[len(beginning_tokens) :]
     return beginning_tokens, last_token
 
 
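Why the label ids are now sliced from the full-text encoding: with subword tokenizers, encoding " " + last_token on its own does not necessarily reproduce the tail of the full-sentence encoding, so the old labels could misalign with the model's predictions. Taking all_tokens[len(beginning_tokens):] keeps context and label ids on a single tokenization. A toy sketch of the slicing only (the whitespace tokenize helper below is a hypothetical stand-in, not the real tokenizer, so it cannot show the subword mismatch itself):

# Toy stand-in for the new get_tokens slicing: label ids are the tail of the
# full-text encoding rather than a separately encoded last word.
def tokenize(text):
    # hypothetical tokenizer: one toy id (the token length) per whitespace token
    return {"input_ids": [len(tok) for tok in text.split()]}

def split_context_and_label(text):
    last_token = text.split()[-1]
    start_idx = text.rfind(last_token)
    beginning_tokens = tokenize(text[:start_idx].strip())["input_ids"]
    all_tokens = tokenize(text.strip())["input_ids"]
    # the label is whatever the full-text encoding appends after the context ids
    return beginning_tokens, all_tokens[len(beginning_tokens):]

context_ids, label_ids = split_context_and_label("the cat sat on the mat")
print(context_ids, label_ids)  # [3, 3, 3, 2, 3] [3]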
@@ -271,7 +277,7 @@ def create_eval_dataset(args):
     with open(args.eval_path, "r") as f:
         for line in f.readlines():
             text = json.loads(line)["text"]
-            tokens, labels = get_tokens(tokenizer, text, strict=False)
+            tokens, labels = get_tokens(tokenizer, text, strict=True)
             tokenized_data.append(tokens)
             tokenized_label.append(labels)
     val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label, seq_len, tokenizer.pad_token_id)
@@ -316,49 +322,40 @@ def do_generation():
         tensor_parallel_output=False,
         tensor_parallel_degree=args.tensor_parallel_degree,
         tensor_parallel_rank=paddle.distributed.get_rank(),
-        use_flash_attention=False,
+        use_flash_attention=args.use_flash_attention,
         dtype=args.dtype,  # todo enable set dtype to avoid additional mem usage
     )
 
     model.eval()
-    args.use_pure_fp16 = False
-
     total_score = 0
     score_name = "loss" if not args.cloze_eval else "number correct"
-    args.use_pure_fp16 = False
     eval_data_loader = create_eval_dataset(args)
     with paddle.no_grad():
         for step, batch in enumerate(eval_data_loader):
 
             tokens, loss_mask = batch[:2]
             labels = batch[-1]
-            with paddle.amp.auto_cast(args.use_pure_fp16):
-                if args.model_type == "bloom":
-                    preds = model(tokens).detach()
-                else:
-                    preds = model(tokens)[0].detach()
-                # print(preds)
-
-                # cast preds to float32 to keep high-precision
-                preds = preds.astype(paddle.float32)
-
-                if not args.cloze_eval:
-                    masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
-                    loss = paddle.sum(masked_lm_loss * loss_mask)
-                    total_score += float(loss) / (args.num_tokenized_tokens - 1)
-                else:
-                    outputs = paddle.argmax(preds, -1)
-                    acc = paddle.cast(outputs == labels, "float32")
-                    acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
-                    acc = paddle.sum(paddle.prod(acc, -1))
-                    total_score += float(acc)
-
-                if step % args.logging_steps == 0:
-                    logger.info(
-                        "step %d, batch: %d, %s: %f, speed: %.2f step/s"
-                        % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
-                    )
-                    tic_eval = time.time()
+            preds = model(tokens, return_dict=True).logits.detach()
+            # cast preds to float32 to keep high-precision
+            preds = preds.astype(paddle.float32)
+
+            if not args.cloze_eval:
+                masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
+                loss = paddle.sum(masked_lm_loss * loss_mask)
+                total_score += float(loss) / (args.num_tokenized_tokens - 1)
+            else:
+                outputs = paddle.argmax(preds, -1)
+                acc = paddle.cast(outputs == labels, "float32")
+                acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
+                acc = paddle.sum(paddle.prod(acc, -1))
+                total_score += float(acc)
+
+            if step % args.logging_steps == 0:
+                logger.info(
+                    "step %d, batch: %d, %s: %f, speed: %.2f step/s"
+                    % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
+                )
+                tic_eval = time.time()
 
     if not args.cloze_eval:
         total_loss = float(total_score)
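The rewritten loop replaces the bloom/other-model branching with a single model(tokens, return_dict=True).logits call and drops the paddle.amp.auto_cast wrapper, relying instead on the explicit float32 cast before scoring. For reference, a self-contained sketch of the cloze-accuracy branch on dummy tensors (shapes and values are made up; a sample counts as correct only if every masked position matches):

# Standalone sketch of the cloze-eval accuracy computed in the loop above, on
# dummy data: a sample counts as correct only if every position covered by
# loss_mask matches the argmax of the logits.
import paddle

batch, seq_len, vocab = 2, 4, 10
preds = paddle.rand([batch, seq_len, vocab], dtype="float32")   # stand-in logits
labels = paddle.randint(0, vocab, shape=[batch, seq_len])       # stand-in labels
# score only the last position of every sample
loss_mask = paddle.concat(
    [paddle.zeros([batch, seq_len - 1], dtype="int64"), paddle.ones([batch, 1], dtype="int64")],
    axis=1,
)

outputs = paddle.argmax(preds, -1)
acc = paddle.cast(outputs == labels, "float32")
# unmasked positions are forced to 1 so they do not break the per-sample product
acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
acc = paddle.sum(paddle.prod(acc, -1))  # number of fully correct samples
print(float(acc))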
@@ -381,4 +378,4 @@ def do_generation():
 
 
 if __name__ == "__main__":
-    do_generation()
+    do_generation()
