Commit 433e547

add convert slow tokenizer method
1 parent 4ab5392 commit 433e547

1 file changed: 324 additions, 0 deletions
@@ -0,0 +1,324 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple

import tokenizers
from packaging import version
from tokenizers import (
    AddedToken,
    Regex,
    Tokenizer,
    decoders,
    normalizers,
    pre_tokenizers,
)
from tokenizers.models import BPE, Unigram


# Copied from transformers, adapted for tokenizers >= 0.19.0
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
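    # Note: `prepend_scheme` is passed through to Metaspace. Roughly, "always"
    # prepends the replacement character to the input, "first" limits this to the
    # first split (the non-legacy behaviour), and "never" disables it.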
    if add_prefix_space:
        prepend_scheme = "always"
        if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
            prepend_scheme = "first"
    else:
        prepend_scheme = "never"
    return prepend_scheme


# Extract the vocab and merges from a sentencepiece model file
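# Usage sketch (illustrative; assumes a local sentencepiece model at "spiece.model"):
#
#   vocab, merges = SentencePieceExtractor("spiece.model").extract()
#   # `vocab` maps each piece to its id; `merges` is an ordered list of
#   # (left, right) pairs that a BPE model can replay.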
class SentencePieceExtractor:
    def __init__(self, model: str):
        from sentencepiece import SentencePieceProcessor

        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self, vocab_scores: Optional[List[Tuple[str, float]]] = None) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
        if vocab_scores is not None:
            vocab_scores, reverse = dict(vocab_scores), True
        else:
            vocab_scores, reverse = vocab, False

        # Merges
        merges = []
        for merge, piece_score in vocab_scores.items():
            local = []
            for index in range(1, len(merge)):
                piece_l, piece_r = merge[:index], merge[index:]
                if piece_l in vocab and piece_r in vocab:
                    local.append((piece_l, piece_r, piece_score))
            local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
            merges.extend(local)

        merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
        merges = [(val[0], val[1]) for val in merges]

        return vocab, merges


def check_number_comma(piece: str) -> bool:
    return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()


class Converter:
    def __init__(self, original_tokenizer):
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        raise NotImplementedError()


class SpmConverter(Converter):
    def __init__(self, *args):

        super().__init__(*args)

        from . import sentencepiece_model_pb2 as model_pb2

        m = model_pb2.ModelProto()
        if hasattr(self.original_tokenizer, "sentencepiece_model_file"):
            spm_vocab_file = self.original_tokenizer.sentencepiece_model_file
        else:
            spm_vocab_file = self.original_tokenizer.vocab_file
        with open(spm_vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback:
            if not getattr(self, "handle_byte_fallback", None):
                import warnings

                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)
        unk_id = self.unk_id(proto)

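        # trainer_spec.model_type follows sentencepiece's ModelType enum:
        # 1 == UNIGRAM and 2 == BPE; other model types are not handled here.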
        if model_type == 1:
            tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
        elif model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                )
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        return tokenizer

    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        prepend_scheme = "always"
        if hasattr(self.original_tokenizer, "legacy") and not self.original_tokenizer.legacy:
            prepend_scheme = "first"
        if version.parse(tokenizers.__version__) >= version.parse("0.19.0"):
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        else:
            return pre_tokenizers.Metaspace(
                replacement=replacement, add_prefix_space=add_prefix_space, prepend_scheme=prepend_scheme
            )

    def post_processor(self):
        return None

    def decoder(self, replacement, add_prefix_space):
        if version.parse(tokenizers.__version__) >= version.parse("0.19.0"):
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
            return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        else:
            return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)

    def converted(self) -> Tokenizer:
        tokenizer = self.tokenizer(self.proto)

        # Assemble the tokenizer
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        add_prefix_space = True
        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer


class TikTokenConverter(Converter):
    def extract(self, tiktoken_file: str):
        from .tiktoken_model_utils import bpe, bytes_to_unicode, load_tiktoken_bpe

        bpe_ranks = (
            self.original_tokenizer.mergeable_ranks
            if hasattr(self.original_tokenizer, "mergeable_ranks") and self.original_tokenizer.mergeable_ranks
            else load_tiktoken_bpe(tiktoken_file)
        )
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
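        # Rebuild a GPT-2 style vocab/merges pair from the tiktoken ranks: each
        # token maps to its rank, and for multi-byte tokens the `bpe` helper is
        # replayed with `max_rank` to recover the pair that produced the token.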
        for token, rank in bpe_ranks.items():
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            merged = tuple(bpe(bpe_ranks, token, max_rank=rank))
            if len(merged) == 2:
                merges.append(tuple(map(token_bytes_to_string, merged)))

        return vocab, merges


class LlamaConverter(SpmConverter):
    handle_byte_fallback = True

    def vocab(self, proto):
        vocab = [
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        return 0

    def decoder(self, replacement, add_prefix_space):
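        # Undo the normalizer: map "▁" back to spaces, resolve byte-fallback
        # tokens, fuse the pieces, then strip the single leading space that
        # Prepend("▁") introduced.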
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
                decoders.Strip(content=" ", left=1),
            ]
        )

    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)
        if model_type == 1:

            if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
                tokenizer = Tokenizer(Unigram(vocab_scores, 0))
            else:
                tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))

        elif model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
            )
            tokenizer.add_special_tokens(
                [
                    AddedToken("<unk>", normalized=False, special=True),
                    AddedToken("<s>", normalized=False, special=True),
                    AddedToken("</s>", normalized=False, special=True),
                ]
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        return tokenizer

    def normalizer(self, proto):
        return normalizers.Sequence(
            [
                normalizers.Prepend(prepend="▁"),
                normalizers.Replace(pattern=" ", content="▁"),
            ]
        )

    def pre_tokenizer(self, replacement, add_prefix_space):
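        # No Metaspace pre-tokenizer is needed here: the normalizer above already
        # prepends "▁" and maps spaces to "▁" before the model runs.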
        return None


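# Additional converters can be registered here by mapping a slow tokenizer class
# name to its Converter subclass.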
SLOW_TO_FAST_CONVERTERS = {
    "LlamaTokenizer": LlamaConverter,
}


def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
    """
    Utilities to convert a slow tokenizer instance into a fast tokenizer instance.

    Args:
        transformer_tokenizer ([`~tokenizer_utils_base.PretrainedTokenizer`]):
            Instance of a slow tokenizer to convert into the backend tokenizer for
            [`~tokenizer_utils_base.PretrainedTokenizerFast`].

    Return:
        An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenizer_utils_base.PretrainedTokenizerFast`].
    """

    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
        raise ValueError(
            f"An instance of tokenizer class {tokenizer_class_name} cannot be converted into a fast tokenizer instance. "
            f"No converter was found. Currently available slow->fast converters: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
        )

    converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]

    return converter_class(transformer_tokenizer).converted()
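

# A minimal usage sketch (illustrative; the import path and checkpoint name are
# assumptions, not part of this module):
#
#   from paddlenlp.transformers import LlamaTokenizer
#
#   slow_tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-checkpoint")
#   fast_backend = convert_slow_tokenizer(slow_tokenizer)
#   fast_backend.save("tokenizer.json")  # serialize the `tokenizers.Tokenizer`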
