Skip to content

Commit 55fc1af

Browse files
author
gongenlei
authored
Add BART converter and Chinese BART models (#4636)
* add bart converter and transfer chinese models * update log
1 parent e4bd4f3 commit 55fc1af

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from collections import OrderedDict
17+
18+
import numpy as np
19+
import paddle
20+
import torch
21+
from transformers import BartForConditionalGeneration as hf_BartForConditionalGeneration
22+
23+
from paddlenlp.transformers import (
24+
BartForConditionalGeneration as pp_BartForConditionalGeneration,
25+
)
26+
from paddlenlp.utils import load_torch
27+
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
28+
from paddlenlp.utils.log import logger
29+
30+
# Download huggingface models
31+
hf_hub_repo = "fnlp/bart-base-chinese"
32+
base_url = f"https://huggingface.co/{hf_hub_repo}/resolve/main/"
33+
34+
pp_hf_checkpoint = hf_hub_repo.replace("/", "_")
35+
os.makedirs(pp_hf_checkpoint, exist_ok=True)
36+
37+
for i in [
38+
"config.json",
39+
"vocab.txt",
40+
"tokenizer_config.json",
41+
"special_tokens_map.json",
42+
"pytorch_model.bin",
43+
"added_tokens.json",
44+
"spiece.model",
45+
]:
46+
try:
47+
get_path_from_url_with_filelock(f"{base_url}{i}", pp_hf_checkpoint)
48+
except RuntimeError:
49+
logger.warning(f"{base_url}{i} not found.")
50+
51+
use_torch = False
52+
try:
53+
hf_model = load_torch(os.path.join(pp_hf_checkpoint, "pytorch_model.bin"))
54+
except ValueError:
55+
# Some models coming from pytorch_lighting
56+
use_torch = True
57+
hf_model = torch.load(os.path.join(pp_hf_checkpoint, "pytorch_model.bin"), map_location="cpu")
58+
59+
huggingface_to_paddle_encoder = {
60+
"model.encoder.embed_tokens": "bart.encoder.embed_tokens",
61+
"model.encoder.embed_positions": "bart.encoder.encoder_embed_positions",
62+
"model.encoder.layernorm_embedding": "bart.encoder.encoder_layernorm_embedding",
63+
".self_attn_layer_norm.": ".norm1.",
64+
".fc1.": ".linear1.",
65+
".fc2.": ".linear2.",
66+
".final_layer_norm.": ".norm2.",
67+
"model.encoder": "bart.encoder.encoder",
68+
}
69+
70+
huggingface_to_paddle_decoder = {
71+
"model.decoder.embed_tokens": "bart.decoder.embed_tokens",
72+
"model.decoder.embed_positions": "bart.decoder.decoder_embed_positions",
73+
"model.decoder.layernorm_embedding": "bart.decoder.decoder_layernorm_embedding",
74+
".self_attn_layer_norm.": ".norm1.",
75+
".encoder_attn.": ".cross_attn.",
76+
".encoder_attn_layer_norm.": ".norm2.",
77+
".fc1.": ".linear1.",
78+
".fc2.": ".linear2.",
79+
".final_layer_norm.": ".norm3.",
80+
"model.decoder": "bart.decoder.decoder",
81+
}
82+
83+
skip_weights = []
84+
85+
dont_transpose = [
86+
".embed_positions.weight",
87+
".embed_tokens.weight",
88+
"layernorm_embedding.weight",
89+
"norm.weight",
90+
".shared.weight",
91+
"lm_head.weight",
92+
]
93+
94+
paddle_state_dict = OrderedDict()
95+
96+
# Convert parameters
97+
for k, v in hf_model.items():
98+
transpose = False
99+
if k in skip_weights:
100+
continue
101+
if k[-7:] == ".weight":
102+
if not any([w in k for w in dont_transpose]):
103+
if v.ndim == 2:
104+
v = v.transpose(0, 1) if use_torch else v.transpose()
105+
transpose = True
106+
oldk = k
107+
108+
if "model.encoder." in k:
109+
for huggingface_name, paddle_name in huggingface_to_paddle_encoder.items():
110+
k = k.replace(huggingface_name, paddle_name)
111+
elif "model.decoder." in k:
112+
for huggingface_name, paddle_name in huggingface_to_paddle_decoder.items():
113+
k = k.replace(huggingface_name, paddle_name)
114+
115+
if oldk == "model.shared.weight":
116+
k = "bart.shared.weight"
117+
118+
if oldk == "lm_head.weight":
119+
k = "lm_head_weight"
120+
121+
logger.info(f"Converting: {oldk} => {k} | is_transpose {transpose}")
122+
123+
paddle_state_dict[k] = v.data.numpy() if use_torch else v
124+
125+
# Save to .pdparams
126+
paddle.save(paddle_state_dict, os.path.join(pp_hf_checkpoint, "model_state.pdparams"))
127+
128+
# Compare ppnlp with hf
129+
paddle.set_grad_enabled(False)
130+
torch.set_grad_enabled(False)
131+
pp_model = pp_BartForConditionalGeneration.from_pretrained(pp_hf_checkpoint)
132+
pp_model.eval()
133+
hf_model = hf_BartForConditionalGeneration.from_pretrained(pp_hf_checkpoint)
134+
hf_model.eval()
135+
136+
input_ids = np.random.randint(1, 10000, size=(2, 10))
137+
pp_inputs = paddle.to_tensor(input_ids)
138+
hf_inputs = torch.tensor(input_ids)
139+
140+
pp_output = pp_model(pp_inputs)
141+
hf_output = hf_model(hf_inputs)
142+
143+
diff = abs(hf_output.logits.detach().numpy() - pp_output.numpy())
144+
logger.info(f"max diff: {np.max(diff)}, min diff: {np.min(diff)}")

0 commit comments

Comments
 (0)