@@ -248,16 +248,19 @@ def _preprocess(self, source):
     def _infer(self, inputs):
         raise NotImplementedError

-    def _postprocess(self, predictions):
+    def _postprocess(self, predictions, return_tokens=False):
         decoded_predictions = self.tokenizer.batch_decode(
             predictions, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-        return decoded_predictions
+        if return_tokens:
+            return decoded_predictions, predictions
+        else:
+            return decoded_predictions

-    def predict(self, input_texts: str | list[str]):
+    def predict(self, input_texts: str | list[str], return_tokens=False):
         tokenized_source = self._preprocess(input_texts)
         predictions = self._infer(tokenized_source)
-        decoded_predictions = self._postprocess(predictions)
+        decoded_predictions = self._postprocess(predictions, return_tokens=return_tokens)
         return decoded_predictions

@@ -470,13 +473,16 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
         )
         self.generation_config = None

-    def _postprocess(self, predictions):
+    def _postprocess(self, predictions, return_tokens=False):
         if paddle.distributed.get_rank() == 0:
             tokens: np.ndarray = load_real_time_tokens()
             decoded_predictions = self.tokenizer.batch_decode(
                 tokens.tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False
             )
-            return decoded_predictions
+            if return_tokens:
+                return decoded_predictions, tokens.tolist()
+            else:
+                return decoded_predictions
         else:
             return None

@@ -1034,7 +1040,7 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
         )

     @paddle.no_grad()
-    def predict(self, input_texts: str | list[str]):
+    def predict(self, input_texts: str | list[str], return_tokens=False):
         self._preprocess(input_texts)

         result_queue = mp.Queue()
@@ -1055,9 +1061,15 @@ def predict(self, input_texts: str | list[str]):
             self.used_list[i] = []

         outputs = []
+        output_tokens = []
         while len(outputs) < self.batch_size:
-            outputs.append(result_queue.get(timeout=1)[-1])
-        return outputs
+            result = result_queue.get(timeout=1)
+            outputs.append(result[-1])
+            output_tokens.append(result[-2])
+        if return_tokens:
+            return outputs, output_tokens
+        else:
+            return outputs


 class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor):
@@ -1180,7 +1192,7 @@ def _share_data(self):
     def _infer(self):
         self.predictor.run()

-    def predict(self, input_texts: str | list[str]):
+    def predict(self, input_texts: str | list[str], return_tokens=False):

         s_time = time.time()
         self._preprocess(input_texts)
@@ -1213,9 +1225,15 @@ def predict(self, input_texts: str | list[str]):
             self.used_list[i] = []

         outputs = []
+        output_tokens = []
         while len(outputs) < self.batch_size:
-            outputs.append(result_queue.get(timeout=1)[-1])
-        return outputs
+            result = result_queue.get(timeout=1)
+            outputs.append(result[-1])
+            output_tokens.append(result[-2])
+        if return_tokens:
+            return outputs, output_tokens
+        else:
+            return outputs

     def _preprocess(self, source):
         BlockInferencePredictorMixin._preprocess(self, source)
@@ -1681,8 +1699,8 @@ def benchmark(predictor, predictor_args, model_args):
     output_tokens = 0
     for _ in range(test_time):
         for bs, batch_source_text in enumerate(batch_benchmark_texts):
-            outputs = predictor.predict(batch_source_text)
-            output_tokens += sum([len(output) for output in outputs])
+            outputs, batch_tokens = predictor.predict(batch_source_text, return_tokens=True)
+            output_tokens += sum([len(tokens) for tokens in batch_tokens])
     end = time.perf_counter()
     print("Avg Elapse time is: ", (end - start) / test_time)
     print("Output tokens is: ", output_tokens)