Skip to content

Commit ab67ff2

Browse files
authored
[FastTokenizer] Fix fast_tokenizer import (#4126)
* Fix fast_tokenizer import
* Use `import_module` instead of `importlib.import_module`
* Add auto tokenizer unittest
* Update to `__internal_testing__`
* Add test
1 parent f96e787 commit ab67ff2

File tree

3 files changed

+74
-7
lines changed

3 files changed

+74
-7
lines changed

paddlenlp/transformers/auto/tokenizer.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from paddlenlp import __version__
2424
from paddlenlp.utils.downloader import COMMUNITY_MODEL_PREFIX, get_path_from_url
2525
from paddlenlp.utils.env import HF_CACHE_HOME, MODEL_HOME
26-
from paddlenlp.utils.import_utils import is_fast_tokenizer_available
26+
from paddlenlp.utils.import_utils import import_module, is_fast_tokenizer_available
2727
from paddlenlp.utils.log import logger
2828

2929
__all__ = [
@@ -154,13 +154,31 @@ def _get_tokenizer_class_from_config(cls, pretrained_model_name_or_path, config_
154154

155155
if init_class:
156156
class_name = cls._name_mapping[init_class]
157-
import_class = importlib.import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
157+
import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
158158
tokenizer_class = getattr(import_class, init_class)
159159
if use_fast:
160-
for fast_tokenizer_class, name in cls._fast_name_mapping.items():
161-
if name == class_name:
162-
import_class = importlib.import_module(f"paddlenlp.transformers.{class_name}.fast_tokenizer")
163-
tokenizer_class = getattr(import_class, fast_tokenizer_class)
160+
if is_fast_tokenizer_available():
161+
is_support_fast_tokenizer = False
162+
init_class_prefix = init_class[:-9]
163+
for fast_tokenizer_class, name in cls._fast_name_mapping.items():
164+
fast_tokenizer_class_prefix = fast_tokenizer_class[:-9]
165+
if name == class_name and fast_tokenizer_class_prefix.startswith(init_class_prefix):
166+
is_support_fast_tokenizer = True
167+
import_class = import_module(f"paddlenlp.transformers.{class_name}.fast_tokenizer")
168+
tokenizer_class = getattr(import_class, fast_tokenizer_class)
169+
break
170+
if not is_support_fast_tokenizer:
171+
logger.warning(
172+
f"The tokenizer {tokenizer_class} doesn't have the fast version."
173+
" Please check the map `paddlenlp.transformers.auto.tokenizer.FAST_TOKENIZER_MAPPING_NAMES`"
174+
" to see which fast tokenizers are currently supported."
175+
)
176+
else:
177+
logger.warning(
178+
"Can't find the fast_tokenizer package, "
179+
"please ensure install fast_tokenizer correctly. "
180+
"You can install fast_tokenizer by `pip install fast-tokenizer-python`."
181+
)
164182
return tokenizer_class
165183
# If no `init_class`, we use pattern recognition to recognize the tokenizer class.
166184
else:
@@ -170,7 +188,7 @@ def _get_tokenizer_class_from_config(cls, pretrained_model_name_or_path, config_
170188
if pattern in pretrained_model_name_or_path.lower():
171189
init_class = key
172190
class_name = cls._name_mapping[init_class]
173-
import_class = importlib.import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
191+
import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
174192
tokenizer_class = getattr(import_class, init_class)
175193
return tokenizer_class
176194

tests/transformers/auto/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
# Copyright 2019 Hugging Face inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import paddlenlp
19+
from paddlenlp.transformers import AutoTokenizer, is_fast_tokenizer_available
20+
21+
22+
class AutoTokenizerTest(unittest.TestCase):
    """Tests for AutoTokenizer's fast/slow tokenizer class resolution."""

    def test_fast_tokenizer_import(self):
        """`use_fast=True` yields the fast class only when fast_tokenizer is installed."""
        slow_tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/bert", use_fast=False)
        self.assertIsInstance(slow_tokenizer, paddlenlp.transformers.BertTokenizer)

        maybe_fast_tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/bert", use_fast=True)
        expected_cls = (
            paddlenlp.transformers.BertFastTokenizer
            if is_fast_tokenizer_available()
            else paddlenlp.transformers.BertTokenizer
        )
        self.assertIsInstance(maybe_fast_tokenizer, expected_cls)

    def test_fast_tokenizer_non_exist(self):
        """Requesting a fast tokenizer for a model without one falls back to the slow class."""
        # T5 FastTokenizer doesn't exist yet, so from_pretrained will return the normal tokenizer.
        fallback_tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=True)
        self.assertIsInstance(fallback_tokenizer, paddlenlp.transformers.T5Tokenizer)

0 commit comments

Comments (0)