@@ -57,8 +57,8 @@ def get_parser():
         "--device",
         type=str,
         default="gpu",
-        choices=["cpu", "gpu", "xpu", "npu"],
-        help="select cpu, gpu, xpu devices.",
+        choices=["cpu", "gpu", "xpu", "npu", "gcu"],
+        help="select cpu, gpu, xpu, gcu devices.",
     )
     parser.add_argument(
         "--dtype",
@@ -67,7 +67,12 @@ def get_parser():
         choices=["bfloat16", "float16", "float32"],
         help="set the dtype of model",
     )
-
+    parser.add_argument(
+        "--use_flash_attention",
+        type=bool,
+        default=False,
+        help="Whether to use flash attention",
+    )
 
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
 
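Note on the new flag: argparse with `type=bool` calls `bool()` on the raw command-line string, so any non-empty value (including `--use_flash_attention False`) parses as `True`; only the default is reliably `False`. If the flag is meant to accept explicit true/false values, a converter along these lines is one option (a minimal sketch; `str2bool` is an illustrative name, not part of this patch):

```python
import argparse


def str2bool(value):
    """Convert an explicit true/false style string into a bool."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use_flash_attention", type=str2bool, default=False, help="Whether to use flash attention")

# "--use_flash_attention false" now parses to False instead of True.
print(parser.parse_args(["--use_flash_attention", "false"]).use_flash_attention)
```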
@@ -244,7 +249,8 @@ def get_tokens(tokenizer, text, strict=True):
     last_token = text.split()[-1]
     start_idx = text.rfind(last_token)
     beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
-    last_token = tokenizer(" " + last_token)["input_ids"]
+    all_tokens = tokenizer(text.strip())["input_ids"]
+    last_token = all_tokens[len(beginning_tokens) :]
     return beginning_tokens, last_token
 
 
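The `get_tokens` change above tokenizes the full text once and slices off the prefix tokens, instead of encoding the last word in isolation; this keeps the prefix/target split consistent with how the subword tokenizer actually segments the full string. A minimal sketch of the same pattern, assuming a Hugging-Face-style tokenizer callable that returns a dict with `input_ids` (the function name is illustrative, not part of this patch):

```python
def split_last_word_tokens(tokenizer, text):
    """Return (prefix_tokens, last_word_tokens) for a LAMBADA-style example.

    The last word's tokens are taken by slicing the full-text encoding, so they
    match what the tokenizer produces in context; encoding " " + last_word on
    its own can yield different subword pieces.
    """
    last_word = text.split()[-1]
    start_idx = text.rfind(last_word)
    prefix_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
    all_tokens = tokenizer(text.strip())["input_ids"]
    return prefix_tokens, all_tokens[len(prefix_tokens):]
```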
@@ -271,7 +277,7 @@ def create_eval_dataset(args):
         with open(args.eval_path, "r") as f:
             for line in f.readlines():
                 text = json.loads(line)["text"]
-                tokens, labels = get_tokens(tokenizer, text, strict=False)
+                tokens, labels = get_tokens(tokenizer, text, strict=True)
                 tokenized_data.append(tokens)
                 tokenized_label.append(labels)
         val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label, seq_len, tokenizer.pad_token_id)
@@ -316,49 +322,40 @@ def do_generation():
         tensor_parallel_output=False,
         tensor_parallel_degree=args.tensor_parallel_degree,
         tensor_parallel_rank=paddle.distributed.get_rank(),
-        use_flash_attention=False,
+        use_flash_attention=args.use_flash_attention,
         dtype=args.dtype,  # todo enable set dtype to avoid additional mem usage
     )
 
     model.eval()
-    args.use_pure_fp16 = False
-
     total_score = 0
     score_name = "loss" if not args.cloze_eval else "number correct"
-    args.use_pure_fp16 = False
     eval_data_loader = create_eval_dataset(args)
     with paddle.no_grad():
         for step, batch in enumerate(eval_data_loader):
 
             tokens, loss_mask = batch[:2]
             labels = batch[-1]
-            with paddle.amp.auto_cast(args.use_pure_fp16):
-                if args.model_type == "bloom":
-                    preds = model(tokens).detach()
-                else:
-                    preds = model(tokens)[0].detach()
-                # print(preds)
-
-                # cast preds to float32 to keep high-precision
-                preds = preds.astype(paddle.float32)
-
-                if not args.cloze_eval:
-                    masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
-                    loss = paddle.sum(masked_lm_loss * loss_mask)
-                    total_score += float(loss) / (args.num_tokenized_tokens - 1)
-                else:
-                    outputs = paddle.argmax(preds, -1)
-                    acc = paddle.cast(outputs == labels, "float32")
-                    acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
-                    acc = paddle.sum(paddle.prod(acc, -1))
-                    total_score += float(acc)
-
-                if step % args.logging_steps == 0:
-                    logger.info(
-                        "step %d, batch: %d, %s: %f, speed: %.2f step/s"
-                        % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
-                    )
-                    tic_eval = time.time()
+            preds = model(tokens, return_dict=True).logits.detach()
+            # cast preds to float32 to keep high-precision
+            preds = preds.astype(paddle.float32)
+
+            if not args.cloze_eval:
+                masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
+                loss = paddle.sum(masked_lm_loss * loss_mask)
+                total_score += float(loss) / (args.num_tokenized_tokens - 1)
+            else:
+                outputs = paddle.argmax(preds, -1)
+                acc = paddle.cast(outputs == labels, "float32")
+                acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
+                acc = paddle.sum(paddle.prod(acc, -1))
+                total_score += float(acc)
+
+            if step % args.logging_steps == 0:
+                logger.info(
+                    "step %d, batch: %d, %s: %f, speed: %.2f step/s"
+                    % (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
+                )
+                tic_eval = time.time()
 
     if not args.cloze_eval:
         total_loss = float(total_score)
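In the cloze branch above, a sample only scores 1 when every target position is predicted correctly: non-target positions are overwritten with 1 via the mask, and the product over the sequence collapses to 0 as soon as any target token is wrong. A tiny NumPy stand-in for the Paddle ops (the data values are made up for illustration):

```python
import numpy as np

preds = np.array([[3, 7, 2, 9]])            # argmax token ids, shape [batch, seq_len]
labels = np.array([[3, 7, 5, 9]])           # gold token ids
loss_mask = np.array([[0, 1, 1, 0]], bool)  # positions belonging to the target word

acc = (preds == labels).astype(np.float32)
acc = np.where(loss_mask, acc, np.ones_like(acc))  # ignore non-target positions
score = np.prod(acc, axis=-1).sum()                # 0.0: target position 2 is wrong

print(score)
```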
@@ -381,4 +378,4 @@ def do_generation():
 
 
 if __name__ == "__main__":
-    do_generation()
+    do_generation()