From 1ceb7f23c9a4dd52398506db8a2f0532f03b1ffe Mon Sep 17 00:00:00 2001
From: w5688414
Date: Tue, 7 May 2024 02:33:07 +0000
Subject: [PATCH] Fix fast tokenizer import error

---
 paddlenlp/transformers/auto/tokenizer.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py
index f1fa2e7993b7..6a960c6f7fac 100644
--- a/paddlenlp/transformers/auto/tokenizer.py
+++ b/paddlenlp/transformers/auto/tokenizer.py
@@ -193,10 +193,23 @@ def _get_tokenizer_class_from_config(cls, pretrained_model_name_or_path, config_
         if init_class in cls._name_mapping:
             class_name = cls._name_mapping[init_class]
             import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
-            tokenizer_class = getattr(import_class, init_class)
-            if use_fast:
-                fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
-                tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
+            tokenizer_class = None
+            try:
+                if use_fast:
+                    tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
+            except:
+                # use the non fast tokenizer as default
+                logger.warning(
+                    "`use_fast` is set to `True` but the tokenizer class does not have a fast version. "
+                    " Falling back to the slow version."
+                )
+            try:
+                if tokenizer_class is None:
+                    tokenizer_class = getattr(import_class, init_class)
+            except:
+                raise ValueError(
+                    f"Tokenizer class {init_class} is not currently imported, if you use fast tokenizer, please set use_fast to True."
+                )
             return tokenizer_class
         else:
             import_class = import_module("paddlenlp.transformers")