@@ -268,6 +268,24 @@ def _dynabert_init(self, model, eval_dataloader):
     return ofa_model, teacher_model
 
 
+def check_dynabert_config(net_config, width_mult):
+    '''
+    Corrects net_config for the OFA model if necessary.
+    '''
+    if 'electra.embeddings_project' in net_config:
+        net_config["electra.embeddings_project"]['expand_ratio'] = 1.0
+    for key in net_config:
+        # Makes sure the last dim of these Linear weights is expanded by
+        # `width_mult`.
+        if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key or 'linear1' in key:
+            net_config[key]['expand_ratio'] = width_mult
+        # Keeps the last dim of these Linear weights the same as before.
+        elif 'out_proj' in key or 'linear2' in key:
+            net_config[key]['expand_ratio'] = 1.0
+    return net_config
+
+
 def _dynabert_training(self, ofa_model, model, teacher_model, train_dataloader,
                        eval_dataloader, num_train_epochs):
@@ -388,6 +406,7 @@ def evaluate_token_cls(model, data_loader):
                 # Step8: Broadcast supernet config from width_mult,
                 # and use this config in supernet training.
                 net_config = utils.dynabert_config(ofa_model, width_mult)
+                net_config = check_dynabert_config(net_config, width_mult)
                 ofa_model.set_net_config(net_config)
                 if "token_type_ids" in batch:
                     logits, teacher_logits = ofa_model(
@@ -424,6 +443,7 @@ def evaluate_token_cls(model, data_loader):
             if global_step % self.args.save_steps == 0:
                 for idx, width_mult in enumerate(self.args.width_mult_list):
                     net_config = utils.dynabert_config(ofa_model, width_mult)
+                    net_config = check_dynabert_config(net_config, width_mult)
                     ofa_model.set_net_config(net_config)
                     tic_eval = time.time()
                     logger.info("width_mult %s:" % round(width_mult, 2))
@@ -479,6 +499,7 @@ def _dynabert_export(self, ofa_model):
         origin_model = self.model.__class__.from_pretrained(model_dir)
         ofa_model.model.set_state_dict(state_dict)
         best_config = utils.dynabert_config(ofa_model, width_mult)
+        best_config = check_dynabert_config(best_config, width_mult)
         origin_model_new = ofa_model.export(best_config,
                                             input_shapes=[[1, 1], [1, 1]],
                                             input_dtypes=['int64', 'int64'],
@@ -561,7 +582,9 @@ def _batch_generator_func():
             optimize_model=False)
         post_training_quantization.quantize()
         post_training_quantization.save_quantized_model(
-            save_model_path=os.path.join(model_dir, algo + str(batch_size)),
+            save_model_path=os.path.join(
+                model_dir, algo +
+                "_".join([str(batch_size), str(batch_nums)])),
             model_filename=args.output_filename_prefix + ".pdmodel",
             params_filename=args.output_filename_prefix + ".pdiparams")
@@ -632,6 +655,8 @@ def auto_model_forward(self,
         embedding_kwargs["input_ids"] = input_ids
 
     embedding_output = self.embeddings(**embedding_kwargs)
+    if hasattr(self, "embeddings_project"):
+        embedding_output = self.embeddings_project(embedding_output)
 
     self.encoder._use_cache = use_cache  # To be consistent with HF
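For reference, a minimal standalone sketch of what the new `check_dynabert_config` helper does to a supernet config: QKV projections and the first FFN linear follow `width_mult`, while output projections, the second FFN linear, and Electra's `embeddings_project` keep `expand_ratio` 1.0 so the last dim of their weights stays at full size. The layer names in the toy dict below are hypothetical stand-ins for the keys that `utils.dynabert_config` would actually produce, chosen only to exercise each branch.

```python
# Standalone sketch of the config fix-up added in this commit. The toy config
# keys are hypothetical; real keys come from paddleslim's utils.dynabert_config.

def check_dynabert_config(net_config, width_mult):
    """Corrects net_config for the OFA model if necessary (same logic as the patch)."""
    if 'electra.embeddings_project' in net_config:
        net_config["electra.embeddings_project"]['expand_ratio'] = 1.0
    for key in net_config:
        # QKV projections and the first FFN linear shrink with width_mult.
        if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key or 'linear1' in key:
            net_config[key]['expand_ratio'] = width_mult
        # Output projections and the second FFN linear keep their last dim.
        elif 'out_proj' in key or 'linear2' in key:
            net_config[key]['expand_ratio'] = 1.0
    return net_config


if __name__ == "__main__":
    width_mult = 0.75
    # Hypothetical supernet config covering each branch of the helper.
    toy_config = {
        'electra.embeddings_project': {'expand_ratio': width_mult},
        'electra.encoder.layers.0.self_attn.q_proj': {'expand_ratio': 1.0},
        'electra.encoder.layers.0.self_attn.out_proj': {'expand_ratio': width_mult},
        'electra.encoder.layers.0.linear1': {'expand_ratio': 1.0},
        'electra.encoder.layers.0.linear2': {'expand_ratio': width_mult},
    }
    for name, cfg in check_dynabert_config(toy_config, width_mult).items():
        print(name, cfg['expand_ratio'])
    # Expected: embeddings_project, out_proj and linear2 -> 1.0;
    # q_proj and linear1 -> 0.75
```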