Skip to content

Commit 0f1cf73

Browse files
committed
Merge branch 'ernie_layout' of https://github.com/linjieccc/PaddleNLP into ernie_layout
2 parents 84fddad + 2c62652 commit 0f1cf73

23 files changed

+3334
-218
lines changed

README_en.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ For more usage please refer to [Taskflow Docs](./docs/model_zoo/taskflow.md).
8181

8282
#### 🀄 Comprehensive Chinese Transformer Models
8383

84-
We provide **45+** network architectures and over **500+** pretrained models. Not only includes all the SOTA model like ERNIE, PLATO and SKEP released by Baidu, but also integrates most of the high-quality Chinese pretrained model developed by other organizations. Use `AutoModel` API to **⚡SUPER FAST⚡** download pretrained mdoels of different architecture. We welcome all developers to contribute your Transformer models to PaddleNLP!
84+
We provide **45+** network architectures and **500+** pretrained models. This not only includes all the SOTA models such as ERNIE, PLATO and SKEP released by Baidu, but also integrates most of the high-quality Chinese pretrained models developed by other organizations. Use the `AutoModel` API to **⚡SUPER FAST⚡** download pretrained models of different architectures. We welcome all developers to contribute their Transformer models to PaddleNLP!
8585

8686
```python
8787
from paddlenlp.transformers import *

examples/code_generation/codegen/run_clm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def do_train(args):
252252
block_size)
253253
dev_set = process_ds(dev_set, tokenizer, args.overwrite_cache, block_size)
254254

255-
batchify_fn = DataCollatorWithPadding(tokenizer)
255+
batchify_fn = DataCollatorWithPadding(tokenizer, return_attention_mask=True)
256256

257257
train_batch_sampler = DistributedBatchSampler(
258258
train_set, batch_size=args.train_batch_size, shuffle=True)
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
import paddle
18+
from paddle.io import Dataset
19+
import json
20+
from paddlenlp.transformers.bert.tokenizer import BertTokenizer
21+
import collections
22+
from typing import Dict, List, Tuple
23+
import numpy as np
24+
25+
# One passage record: its body text plus its title.
BiEncoderPassage = collections.namedtuple("BiEncoderPassage", "text title")

# One collated bi-encoder batch. The variable is spelled ``BiENcoderBatch``
# while the generated class is named "BiEncoderInput"; both spellings are
# visible at runtime, so neither is changed here.
BiENcoderBatch = collections.namedtuple(
    "BiEncoderInput",
    (
        "questions_ids",
        "question_segments",
        "context_ids",
        "ctx_segments",
        "is_positive",
        "hard_negatives",
        "encoder_type",
    ),
)
36+
37+
38+
def normalize_question(question: str) -> str:
    """Return *question* with typographic apostrophes replaced by ASCII ones."""
    return question.replace("’", "'")
41+
42+
43+
def normalize_passage(ctx_text: str):
    """Clean up raw passage text.

    Newlines become spaces, typographic apostrophes become ASCII ones, and
    at most one double quote is removed from each end of the text.
    """
    cleaned = ctx_text.replace("\n", " ")
    cleaned = cleaned.replace("’", "'")

    # Strip a single leading quote first, then re-check the tail on the
    # shortened string so a lone '"' is only consumed once.
    first_kept = 1 if cleaned.startswith('"') else 0
    cleaned = cleaned[first_kept:]
    if cleaned.endswith('"'):
        cleaned = cleaned[:-1]
    return cleaned
50+
51+
52+
class BiEncoderSample:
    """
    Plain attribute holder for one training example: a query together with
    its positive, negative and hard-negative passages.

    The class only declares the attribute types; instances are created empty
    and populated by assignment.
    """

    # Question text.
    query: str
    # Passages labelled as positive for the query.
    positive_passages: List[BiEncoderPassage]
    # Passages labelled as negative for the query.
    negative_passages: List[BiEncoderPassage]
    # Passages labelled as hard negatives for the query.
    hard_negative_passages: List[BiEncoderPassage]
57+
58+
59+
class NQdataSetForDPR(Dataset):
    """
    Dataset backed by a JSON file of question records, each carrying
    "question", "positive_ctxs" and optionally "negative_ctxs" /
    "hard_negative_ctxs" entries.
    """

    def __init__(self, dataPath, query_special_suffix=None):
        """
        Load every record from ``dataPath`` and eagerly convert all of them.

        Args:
            dataPath: path of the JSON file to read.
            query_special_suffix: optional string appended to each query that
                does not already end with it.
        """
        super(NQdataSetForDPR, self).__init__()
        self.data = self._read_json_data(dataPath)
        # Stores the tokenizer class itself, not an instance.
        self.tokenizer = BertTokenizer
        self.query_special_suffix = query_special_suffix
        # Pre-built samples, one per raw record, in file order.
        self.new_data = []
        self.new_data.extend(self[idx] for idx in range(len(self)))

    def _read_json_data(self, dataPath):
        """Read the JSON document at ``dataPath`` and return its items as a list."""
        with open(dataPath, "r", encoding="utf-8") as handle:
            print("Reading file %s" % dataPath)
            records = list(json.load(handle))
        print("Aggregated data size: {}".format(len(records)))
        return records

    def __getitem__(self, index):
        """
        Build the ``BiEncoderSample`` for the record at ``index``.

        Context dicts that lack a "title" key get one set to ``None``; this
        updates the loaded records in place.
        """
        record = self.data[index]

        sample = BiEncoderSample()
        sample.query = self._porcess_query(record["question"])

        positives = record["positive_ctxs"]
        negatives = record.get("negative_ctxs", [])
        hard_negatives = record.get("hard_negative_ctxs", [])

        for group in (positives, negatives, hard_negatives):
            for ctx in group:
                ctx.setdefault("title", None)

        def to_passages(ctxs):
            # Passage text is normalized; the title is carried over untouched.
            return [
                BiEncoderPassage(normalize_passage(ctx["text"]), ctx["title"])
                for ctx in ctxs
            ]

        sample.positive_passages = to_passages(positives)
        sample.negative_passages = to_passages(negatives)
        sample.hard_negative_passages = to_passages(hard_negatives)

        return sample

    def _porcess_query(self, query):
        """Normalize ``query`` and append the configured suffix if it is missing."""
        normalized = normalize_question(query)
        suffix = self.query_special_suffix

        if not suffix or normalized.endswith(suffix):
            return normalized
        return normalized + suffix

    def __len__(self):
        """Return the number of raw records loaded from the file."""
        return len(self.data)
120+
121+
122+
class DataUtil():
    """
    Class for working with datasets: turns ``BiEncoderSample`` objects into a
    collated ``BiENcoderBatch`` of token-id tensors.
    """

    def __init__(self):
        self.tensorizer = BertTensorizer()

    def create_biencoder_input(self,
                               samples: List[BiEncoderSample],
                               inserted_title,
                               num_hard_negatives=0,
                               num_other_negatives=0,
                               shuffle=True,
                               shuffle_positives=False,
                               hard_neg_positives=False,
                               hard_neg_fallback=True,
                               query_token=None):
        """
        Tensorize a list of samples into one bi-encoder batch.

        For every sample the contexts are laid out as
        ``[positive] + other negatives + hard negatives`` and appended to one
        flat context list shared by the whole batch.

        Args:
            samples: samples to collate.
            inserted_title: when truthy, a passage's title (if it has one) is
                passed to the tensorizer together with its text.
            num_hard_negatives: maximum number of hard negatives kept per sample.
            num_other_negatives: maximum number of regular negatives kept per
                sample.
            shuffle: shuffle the negative and hard-negative passages before
                truncating them.
            shuffle_positives: together with ``shuffle``, pick a random positive
                passage instead of the first one.
            hard_neg_positives: accepted for interface compatibility; not used.
            hard_neg_fallback: when a sample has no hard negatives, use its
                first ``num_hard_negatives`` regular negatives instead.
            query_token: accepted for interface compatibility; not used.

        Returns:
            A ``BiENcoderBatch`` holding the stacked question ids, all-zero
            question segments, the stacked context ids, all-zero context
            segments, the flat index of each sample's positive context, a list
            (one entry per sample) of the flat indices of its hard negatives,
            and the literal encoder type ``"question"``.
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            if shuffle and shuffle_positives:
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                positive_ctx = sample.positive_passages[0]

            # Work on copies so that shuffling never reorders the passage
            # lists stored on the sample itself.
            neg_ctxs = list(sample.negative_passages)
            hard_neg_ctxs = list(sample.hard_negative_passages)
            question = sample.query

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            if hard_neg_fallback and not hard_neg_ctxs:
                hard_neg_ctxs = neg_ctxs[:num_hard_negatives]

            neg_ctxs = neg_ctxs[:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs

            # Flat index of this sample's positive context within the batch.
            current_ctxs_len = len(ctx_tensors)
            # Hard negatives sit after the positive AND the regular negatives,
            # so their offset has to skip both.
            hard_negative_start_idx = current_ctxs_len + 1 + len(neg_ctxs)
            hard_negative_end_idx = hard_negative_start_idx + len(hard_neg_ctxs)

            sample_ctxs_tensors = [
                self.tensorizer.text_to_tensor(
                    ctx.text,
                    title=ctx.title if (inserted_title and ctx.title) else None)
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            # Store a concrete list: a generator could be consumed only once.
            hard_neg_ctx_indices.append(
                list(range(hard_negative_start_idx, hard_negative_end_idx)))

            question_tensors.append(self.tensorizer.text_to_tensor(question))

        # Stack the per-item 1-D id tensors into 2-D batches, one row each.
        ctxs_tensor = paddle.concat(
            [paddle.reshape(ctx, [1, -1]) for ctx in ctx_tensors], axis=0)
        questions_tensor = paddle.concat(
            [paddle.reshape(q, [1, -1]) for q in question_tensors], axis=0)

        ctx_segments = paddle.zeros_like(ctxs_tensor)
        question_segments = paddle.zeros_like(questions_tensor)

        return BiENcoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
214+
215+
216+
class BertTensorizer():
    """
    Converts text (optionally paired with a title) into a tensor of BERT
    token ids using the "bert-base-uncased" tokenizer.
    """

    def __init__(self, pad_to_max=True, max_length=256):
        """
        Args:
            pad_to_max: pad shorter sequences up to ``max_length`` ids.
            max_length: maximum number of ids per encoded sequence.
        """
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.max_length = max_length
        self.pad_to_max = pad_to_max

    def text_to_tensor(
        self,
        text: str,
        title=None,
    ):
        """
        Encode ``text`` into a tensor of token ids.

        Args:
            text: text to encode; surrounding whitespace is stripped first.
            title: optional second segment, passed to the tokenizer as
                ``text_pair`` when truthy.

        Returns:
            The result of ``paddle.to_tensor`` on the list of token ids.
        """
        text = text.strip()

        # Padding is handled below, so the tokenizer is asked only to truncate.
        encode_kwargs = dict(
            max_seq_len=self.max_length,
            pad_to_max_seq_len=False,
            truncation_strategy="longest_first",
        )
        if title:
            encoded = self.tokenizer.encode(text,
                                            text_pair=title,
                                            **encode_kwargs)
        else:
            encoded = self.tokenizer.encode(text, **encode_kwargs)
        token_ids = encoded["input_ids"]

        seq_len = self.max_length
        if self.pad_to_max and len(token_ids) < seq_len:
            # Pad with the padding *token* id (not the padding segment id).
            padding = [self.tokenizer.pad_token_id] * (seq_len - len(token_ids))
            token_ids = token_ids + padding
        if len(token_ids) >= seq_len:
            token_ids = token_ids[:seq_len]
            # NOTE(review): a sequence padded above also reaches this branch,
            # so its final pad id is replaced by the separator id. Kept as-is
            # because changing it alters the ids fed to the encoders; confirm
            # whether that is intended.
            token_ids[-1] = self.tokenizer.sep_token_id

        return paddle.to_tensor(token_ids)

0 commit comments

Comments
 (0)