Commit f7c336e

supports electra
1 parent 938edda commit f7c336e

File tree: docs/compression.md · paddlenlp/trainer/trainer_compress.py

2 files changed: +27 -2 lines changed


docs/compression.md

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ compression_args = parser.parse_args_into_dataclasses()
 
 #### Trainer instantiation parameters
 
-- **--model** The model to be compressed. Models with similar structures, such as ERNIE, BERT, RoBERTa, ERNIE-M, ERNIE-Gram, PP-MiniLM and TinyBERT, are currently supported. It is the model fine-tuned on the downstream task; when ERNIE is chosen as the pretrained model, it must inherit from `ErniePretrainedModel`. Taking classification as an example, it can be obtained via methods such as `AutoModelForSequenceClassification.from_pretrained(model_name_or_path)`; in this case the `model_name_or_path` directory must contain the model_config.json and model_state.pdparams files;
+- **--model** The model to be compressed. Models with similar structures, such as ERNIE, BERT, RoBERTa, ERNIE-M, ELECTRA, ERNIE-Gram, PP-MiniLM and TinyBERT, are currently supported. It is the model fine-tuned on the downstream task; when ERNIE is chosen as the pretrained model, it must inherit from `ErniePretrainedModel`. Taking classification as an example, it can be obtained via methods such as `AutoModelForSequenceClassification.from_pretrained(model_name_or_path)`; in this case the `model_name_or_path` directory must contain the model_config.json and model_state.pdparams files;
 - **--data_collator** All three types of tasks can use PaddleNLP's predefined [DataCollator classes](../../paddlenlp/data/data_collator.py). `data_collator` can perform operations such as `Pad` on the data; see the [example code](../model_zoo/ernie-3.0/compress_seq_cls.py) for usage;
 - **--train_dataset** The training set required for pruning training; it is task-specific data. For loading custom datasets, refer to the [documentation](https://huggingface.co/docs/datasets/loading). It can be None when pruning is not enabled;
 - **--eval_dataset** The evaluation set used for pruning training, which is also the calibration data used for quantization; it is task-specific data. For loading custom datasets, refer to the [documentation](https://huggingface.co/docs/datasets/loading). It is a required parameter of the Trainer;
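
For reference, a minimal sketch of obtaining the `--model` argument described above for a classification task; the checkpoint directory name is hypothetical, and it must contain model_config.json and model_state.pdparams:

```python
from paddlenlp.transformers import AutoModelForSequenceClassification

# Hypothetical directory holding a model already fine-tuned on the downstream
# task; it must contain model_config.json and model_state.pdparams.
model = AutoModelForSequenceClassification.from_pretrained(
    "./checkpoints/electra_seq_cls")
```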

paddlenlp/trainer/trainer_compress.py

Lines changed: 26 additions & 1 deletion
@@ -268,6 +268,24 @@ def _dynabert_init(self, model, eval_dataloader):
     return ofa_model, teacher_model
 
 
+def check_dynabert_config(net_config, width_mult):
+    '''
+    Corrects net_config for OFA model if necessary.
+    '''
+    if 'electra.embeddings_project' in net_config:
+        net_config["electra.embeddings_project"]['expand_ratio'] = 1.0
+    for key in net_config:
+        # Makes sure to expand the size of the last dim to `width_mult` for
+        # these Linear weights.
+        if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key or 'linear1' in key:
+            net_config[key]['expand_ratio'] = width_mult
+        # Keeps the size of the last dim of these Linear weights the same as
+        # before.
+        elif 'out_proj' in key or 'linear2' in key:
+            net_config[key]['expand_ratio'] = 1.0
+    return net_config
+
+
 def _dynabert_training(self, ofa_model, model, teacher_model, train_dataloader,
                        eval_dataloader, num_train_epochs):
 
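
A minimal, self-contained illustration of what the new helper does; the layer names here are hypothetical, and `check_dynabert_config` is the function added in the hunk above:

```python
# Toy supernet config; key names are hypothetical but contain the substrings
# the helper looks for.
net_config = {
    "electra.embeddings_project": {"expand_ratio": 0.75},
    "electra.encoder.layers.0.self_attn.q_proj": {"expand_ratio": 0.75},
    "electra.encoder.layers.0.self_attn.out_proj": {"expand_ratio": 0.75},
    "electra.encoder.layers.0.linear1": {"expand_ratio": 0.75},
    "electra.encoder.layers.0.linear2": {"expand_ratio": 0.75},
}

net_config = check_dynabert_config(net_config, width_mult=0.75)
# q_proj and linear1 keep expand_ratio == 0.75 (their last dim is shrunk to the
# target width); embeddings_project, out_proj and linear2 are pinned to 1.0 so
# their last dim keeps its original size.
```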

@@ -388,6 +406,7 @@ def evaluate_token_cls(model, data_loader):
                 # Step8: Broadcast supernet config from width_mult,
                 # and use this config in supernet training.
                 net_config = utils.dynabert_config(ofa_model, width_mult)
+                net_config = check_dynabert_config(net_config, width_mult)
                 ofa_model.set_net_config(net_config)
                 if "token_type_ids" in batch:
                     logits, teacher_logits = ofa_model(
@@ -424,6 +443,7 @@ def evaluate_token_cls(model, data_loader):
             if global_step % self.args.save_steps == 0:
                 for idx, width_mult in enumerate(self.args.width_mult_list):
                     net_config = utils.dynabert_config(ofa_model, width_mult)
+                    net_config = check_dynabert_config(net_config, width_mult)
                     ofa_model.set_net_config(net_config)
                     tic_eval = time.time()
                     logger.info("width_mult %s:" % round(width_mult, 2))
@@ -479,6 +499,7 @@ def _dynabert_export(self, ofa_model):
         origin_model = self.model.__class__.from_pretrained(model_dir)
         ofa_model.model.set_state_dict(state_dict)
         best_config = utils.dynabert_config(ofa_model, width_mult)
+        best_config = check_dynabert_config(best_config, width_mult)
         origin_model_new = ofa_model.export(best_config,
                                             input_shapes=[[1, 1], [1, 1]],
                                             input_dtypes=['int64', 'int64'],
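
The three hunks above (supernet training, periodic evaluation, and export) all apply the same one-line change: the config produced by `utils.dynabert_config` is passed through `check_dynabert_config` before it reaches the OFA model. A small, self-contained sketch of that pattern, with `fake_dynabert_config` standing in for paddleslim's `utils.dynabert_config` and hypothetical layer names:

```python
def fake_dynabert_config(width_mult):
    # Stand-in for utils.dynabert_config(ofa_model, width_mult): every entry
    # starts out tagged with the requested width multiplier.
    names = [
        "electra.embeddings_project",
        "electra.encoder.layers.0.self_attn.q_proj",
        "electra.encoder.layers.0.self_attn.out_proj",
        "electra.encoder.layers.0.linear1",
        "electra.encoder.layers.0.linear2",
    ]
    return {name: {"expand_ratio": width_mult} for name in names}

for width_mult in (0.75, 0.5):  # hypothetical width_mult_list
    config = check_dynabert_config(fake_dynabert_config(width_mult), width_mult)
    # `config` is what would be handed to ofa_model.set_net_config(...) during
    # training/evaluation, or to ofa_model.export(...) at export time.
    print(width_mult, config)
```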
@@ -561,7 +582,9 @@ def _batch_generator_func():
             optimize_model=False)
         post_training_quantization.quantize()
         post_training_quantization.save_quantized_model(
-            save_model_path=os.path.join(model_dir, algo + str(batch_size)),
+            save_model_path=os.path.join(
+                model_dir, algo +
+                "_".join([str(batch_size), str(batch_nums)])),
             model_filename=args.output_filename_prefix + ".pdmodel",
             params_filename=args.output_filename_prefix + ".pdiparams")
 
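The quantized model's save directory now encodes both `batch_size` and `batch_nums` rather than `batch_size` alone. A small sketch of the resulting naming, using hypothetical values:

```python
import os

# Hypothetical values for illustration only.
model_dir, algo, batch_size, batch_nums = "output/width_mult_0.75", "avg", 4, 10

save_model_path = os.path.join(
    model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
print(save_model_path)  # output/width_mult_0.75/avg4_10
```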

@@ -632,6 +655,8 @@ def auto_model_forward(self,
         embedding_kwargs["input_ids"] = input_ids
 
     embedding_output = self.embeddings(**embedding_kwargs)
+    if hasattr(self, "embeddings_project"):
+        embedding_output = self.embeddings_project(embedding_output)
 
     self.encoder._use_cache = use_cache  # To be consistent with HF
 
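This branch is what makes the patched forward work for ELECTRA-style models: in ELECTRA the embedding width can differ from the encoder hidden size, so an `embeddings_project` layer bridges the two, while models without that attribute skip the step. A minimal, self-contained sketch of the shapes involved (sizes are illustrative, roughly those of an ELECTRA-small config):

```python
import paddle
import paddle.nn as nn

embedding_size, hidden_size = 128, 256  # illustrative ELECTRA-small-like sizes
embeddings = nn.Embedding(30522, embedding_size)
embeddings_project = nn.Linear(embedding_size, hidden_size)

input_ids = paddle.randint(0, 30522, shape=[1, 8])
embedding_output = embeddings(input_ids)                  # shape [1, 8, 128]
# Mirrors the `if hasattr(self, "embeddings_project")` branch added above.
embedding_output = embeddings_project(embedding_output)   # shape [1, 8, 256]
print(embedding_output.shape)
```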
