@@ -73,7 +73,6 @@ def get_parser():
         default=False,
         help="Whether to use flash attention",
     )
-
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
@@ -250,7 +249,8 @@ def get_tokens(tokenizer, text, strict=True):
     last_token = text.split()[-1]
     start_idx = text.rfind(last_token)
     beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
-    last_token = tokenizer(" " + last_token)["input_ids"]
+    all_tokens = tokenizer(text.strip())["input_ids"]
+    last_token = all_tokens[len(beginning_tokens) :]
     return beginning_tokens, last_token
 
 
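To illustrate what the changed strict branch of get_tokens now returns, here is a minimal standalone sketch; the helper name get_last_word_tokens and the toy whitespace "tokenizer" are made up for illustration (the script builds its real tokenizer in create_eval_dataset). The full text is tokenized once and the ids past the prefix length become the target, so the last word's ids match how it is tokenized in context instead of coming from a separate call on " " + last_token.

def get_last_word_tokens(tokenizer, text):
    # same logic as the new strict branch above
    last_word = text.split()[-1]
    start_idx = text.rfind(last_word)
    beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
    all_tokens = tokenizer(text.strip())["input_ids"]
    # everything past the prefix is the last word, tokenized in context
    return beginning_tokens, all_tokens[len(beginning_tokens):]

# toy check with a whitespace "tokenizer" standing in for the real one
fake_tok = lambda s: {"input_ids": s.split()}
print(get_last_word_tokens(fake_tok, "the cat sat on the mat"))
# (['the', 'cat', 'sat', 'on', 'the'], ['mat'])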
@@ -277,7 +277,7 @@ def create_eval_dataset(args):
         with open(args.eval_path, "r") as f:
             for line in f.readlines():
                 text = json.loads(line)["text"]
-                tokens, labels = get_tokens(tokenizer, text, strict=False)
+                tokens, labels = get_tokens(tokenizer, text, strict=True)
                 tokenized_data.append(tokens)
                 tokenized_label.append(labels)
         val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label, seq_len, tokenizer.pad_token_id)
@@ -327,44 +327,35 @@ def do_generation():
     )
 
     model.eval()
-    args.use_pure_fp16 = False
-
     total_score = 0
     score_name = "loss" if not args.cloze_eval else "number correct"
-    args.use_pure_fp16 = False
     eval_data_loader = create_eval_dataset(args)
     with paddle.no_grad():
         for step, batch in enumerate(eval_data_loader):
 
             tokens, loss_mask = batch[:2]
             labels = batch[-1]
-            with paddle.amp.auto_cast(args.use_pure_fp16):
-                if args.model_type == "bloom":
-                    preds = model(tokens).detach()
-                else:
-                    preds = model(tokens)[0].detach()
-                # print(preds)
-
-                # cast preds to float32 to keep high-precision
-                preds = preds.astype(paddle.float32)
-
-                if not args.cloze_eval:
-                    masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
-                    loss = paddle.sum(masked_lm_loss * loss_mask)
-                    total_score += float(loss) / (args.num_tokenized_tokens - 1)
-                else:
-                    outputs = paddle.argmax(preds, -1)
-                    acc = paddle.cast(outputs == labels, "float32")
-                    acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
-                    acc = paddle.sum(paddle.prod(acc, -1))
-                    total_score += float(acc)
-
-                if step % args.logging_steps == 0:
-                    logger.info(
-                        "step %d, batch: %d, %s: %f, speed: %.2f step/s"
-                        % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
-                    )
-                    tic_eval = time.time()
+            preds = model(tokens, return_dict=True).logits.detach()
+            # cast preds to float32 to keep high-precision
+            preds = preds.astype(paddle.float32)
+
+            if not args.cloze_eval:
+                masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
+                loss = paddle.sum(masked_lm_loss * loss_mask)
+                total_score += float(loss) / (args.num_tokenized_tokens - 1)
+            else:
+                outputs = paddle.argmax(preds, -1)
+                acc = paddle.cast(outputs == labels, "float32")
+                acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
+                acc = paddle.sum(paddle.prod(acc, -1))
+                total_score += float(acc)
+
+            if step % args.logging_steps == 0:
+                logger.info(
+                    "step %d, batch: %d, %s: %f, speed: %.2f step/s"
+                    % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
+                )
+                tic_eval = time.time()
 
     if not args.cloze_eval:
         total_loss = float(total_score)
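As a reference for the cloze branch kept above (now de-indented out of the removed auto_cast block), a toy run of the accuracy reduction; the numbers are made up and only paddle is required. Non-target positions are forced to 1 via the loss mask, so the product over the sequence is 1 only when every target token is predicted correctly, and the sum then counts fully correct samples in the batch.

import paddle

loss_mask = paddle.to_tensor([[0, 0, 1, 1], [0, 0, 1, 1]], dtype="float32")  # 1 marks target positions
outputs = paddle.to_tensor([[5, 5, 7, 9], [5, 5, 7, 2]])  # argmax predictions
labels = paddle.to_tensor([[1, 1, 7, 9], [1, 1, 7, 9]])   # gold token ids

acc = paddle.cast(outputs == labels, "float32")
acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))  # non-target positions cannot break the product
print(float(paddle.sum(paddle.prod(acc, -1))))  # 1.0: only the first sample gets every target token right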