diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index 3a4def7db46c..df7a22a0cb95 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -45,6 +45,7 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
+    Llama3Tokenizer,
     LlamaTokenizer,
 )
 from paddlenlp.utils.log import logger
@@ -232,7 +233,7 @@ def neft_post_hook(module, input, output):
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False

-    if isinstance(tokenizer, LlamaTokenizer):
+    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id

     if data_args.dataset_name_or_path is None:
diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py
index 1e6e1215fe5d..f1fa2e7993b7 100644
--- a/paddlenlp/transformers/auto/tokenizer.py
+++ b/paddlenlp/transformers/auto/tokenizer.py
@@ -190,13 +190,20 @@ def _get_tokenizer_class_from_config(cls, pretrained_model_name_or_path, config_
         init_class = init_kwargs.pop("tokenizer_class", None)

         if init_class:
-            class_name = cls._name_mapping[init_class]
-            import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
-            tokenizer_class = getattr(import_class, init_class)
-            if use_fast:
-                fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
-                tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
-            return tokenizer_class
+            if init_class in cls._name_mapping:
+                class_name = cls._name_mapping[init_class]
+                import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
+                tokenizer_class = getattr(import_class, init_class)
+                if use_fast:
+                    fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
+                    tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
+                return tokenizer_class
+            else:
+                import_class = import_module("paddlenlp.transformers")
+                tokenizer_class = getattr(import_class, init_class, None)
+                assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
+                return tokenizer_class
+
         # If no `init_class`, we use pattern recognition to recognize the tokenizer class.
         else:
             # TODO: Potential issue https://github.com/PaddlePaddle/PaddleNLP/pull/3786#discussion_r1024689810
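The `auto/tokenizer.py` hunk above adds a fallback: when the `tokenizer_class` recorded in `tokenizer_config.json` has no `_name_mapping` entry, the class is resolved straight from the `paddlenlp.transformers` namespace, which is what lets `AutoTokenizer` pick up `Llama3Tokenizer`. A minimal sketch of that resolution order (illustrative only, not part of the patch; `resolve_tokenizer_class` is a made-up helper name):

```python
from importlib import import_module


def resolve_tokenizer_class(init_class: str, name_mapping: dict):
    """Sketch mirroring the lookup in _get_tokenizer_class_from_config."""
    if init_class in name_mapping:
        # Known classes keep the old path: paddlenlp.transformers.<name>.tokenizer
        module = import_module(f"paddlenlp.transformers.{name_mapping[init_class]}.tokenizer")
        return getattr(module, init_class)
    # New fallback: anything exported at the package root (e.g. Llama3Tokenizer)
    # now resolves even without a _name_mapping entry.
    module = import_module("paddlenlp.transformers")
    tokenizer_class = getattr(module, init_class, None)
    assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
    return tokenizer_class
```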
diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py
index 68459f025fe4..e0b051b7434f 100644
--- a/paddlenlp/transformers/llama/configuration.py
+++ b/paddlenlp/transformers/llama/configuration.py
@@ -147,6 +147,7 @@ def __init__(
         num_key_value_heads=None,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
+        rope_theta=10000.0,
         use_cache=True,
         use_recompute=False,
         recompute_granularity="full",
@@ -188,6 +189,7 @@
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta
         self.use_cache = use_cache
         self.use_recompute = use_recompute
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index d70e63ffa484..379bd3f8874d 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -813,24 +813,28 @@ def _init_rope(self):
             self.rotary_emb = LlamaRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "linear":
             self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "ntk":
             self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "dynamic_ntk":
             self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
@@ -903,6 +907,7 @@ def forward(
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
+
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -1027,7 +1032,6 @@ def forward(
             value_states = paddle.concat([past_key_value[1], value_states], axis=1)

         past_key_value = (key_states, value_states) if use_cache else None
-
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
@@ -1036,7 +1040,7 @@ def forward(
         # repeat k/v heads if n_kv_heads < n_heads
         # paddle version > 2.6 or develop support flash-attn with gqa/mqa
         paddle_version = float(paddle.__version__[:3])
-        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+        if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)):
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -1560,7 +1564,6 @@ def forward(
         else:
             attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
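The new `rope_theta` field defaults to 10000.0, so existing Llama 1/2 configs behave as before, while Llama 3 checkpoints (which ship `rope_theta=500000.0` in `config.json`) now reach the rotary embeddings through the `base` argument threaded in above. A rough sketch of what the base changes, using the standard RoPE inverse-frequency formula (numpy stand-in, not PaddleNLP code):

```python
import numpy as np


def rope_inv_freq(head_dim: int, base: float) -> np.ndarray:
    # Standard RoPE: inv_freq_i = base ** (-2i / head_dim) for i = 0, 2, 4, ...
    return 1.0 / (base ** (np.arange(0, head_dim, 2, dtype="float64") / head_dim))


default_freqs = rope_inv_freq(128, base=10000.0)   # previous hard-coded default
llama3_freqs = rope_inv_freq(128, base=500000.0)   # value found in Llama 3 configs

# A larger base slows the rotation of the higher dimensions, stretching the
# positional signal over longer contexts; the slowest pair here differs by a
# factor of roughly (500000 / 10000) ** (126 / 128) ≈ 47.
print(default_freqs[-1] / llama3_freqs[-1])
```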
diff --git a/paddlenlp/transformers/llama/tokenizer.py b/paddlenlp/transformers/llama/tokenizer.py
index 4efaa48f797c..6f19530c05cb 100644
--- a/paddlenlp/transformers/llama/tokenizer.py
+++ b/paddlenlp/transformers/llama/tokenizer.py
@@ -24,7 +24,7 @@
 from .. import PretrainedTokenizer
 from ..tokenizer_utils_base import BatchEncoding, EncodedInput, PaddingStrategy

-__all__ = ["LlamaTokenizer"]
+__all__ = ["LlamaTokenizer", "Llama3Tokenizer"]


 class LlamaTokenizer(PretrainedTokenizer):
@@ -199,6 +199,7 @@ def create_token_type_ids_from_sequences(
         """
         Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
         use of token type ids, therefore a list of zeros is returned.
+
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs.
@@ -270,3 +271,289 @@ def _pad(
                     constant_values=0,
                 )
         return encoded_inputs
+
+
+"""Copied Tokenization classes for QWen."""
+
+import base64
+import unicodedata
+from typing import Collection, Dict, List, Optional, Set, Tuple, Union
+
+from ...utils.import_utils import is_tiktoken_available
+from .. import PretrainedTokenizer
+from ..tokenizer_utils_base import (
+    AddedToken,
+    BatchEncoding,
+    EncodedInput,
+    PaddingStrategy,
+)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PAT_STR = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+BEGINOFTEXT = "<|begin_of_text|>"
+ENDOFTEXT = "<|end_of_text|>"
+IMSTART = "<|start_header_id|>"
+IMEND = "<|end_header_id|>"
+# as the default behavior is changed to allow special tokens in
+# regular texts, the surface forms of special tokens need to be
+# as different as possible to minimize the impact
+EXTRAS = tuple((f"<|reserved_special_token_{i}|>" for i in range(250)))
+SPECIAL_TOKENS = (BEGINOFTEXT, ENDOFTEXT) + EXTRAS[0:4] + (IMSTART, IMEND) + EXTRAS[4:]
+
+tiktoken = None
+
+
+def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+    with open(tiktoken_bpe_file, "rb") as f:
+        contents = f.read()
+    return {
+        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
+    }
+
+
+class Llama3Tokenizer(PretrainedTokenizer):
+    """Llama3 tokenizer (tiktoken-based, adapted from the QWen tokenizer)."""
+
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
+    resource_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        errors="replace",
+        padding_side="left",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not is_tiktoken_available():
+            raise ValueError("tiktoken is not installed, please install it with: pip install tiktoken")
+
+        import tiktoken as tk
+
+        tiktoken = tk
+
+        self.errors = errors  # how to handle errors in decoding
+
+        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+        self.special_tokens = {
+            token: index for index, token in enumerate(SPECIAL_TOKENS, start=len(self.mergeable_ranks))
+        }
+        enc = tiktoken.Encoding(
+            "Llama3",
+            pat_str=PAT_STR,
+            mergeable_ranks=self.mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        assert (
+            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
+
+        self.decoder = {v: k for k, v in self.mergeable_ranks.items()}  # type: dict[int, bytes|str]
+        self.decoder.update({v: k for k, v in self.special_tokens.items()})
+
+        self.tokenizer = enc  # type: tiktoken.Encoding
+
+        self.eod_id = self.special_tokens[ENDOFTEXT]
+        self.start_header_id = self.special_tokens[IMSTART]
+        self.end_header_id = self.special_tokens[IMEND]
+
+        if "pad_token_id" in kwargs:
+            self.pad_token_id = kwargs["pad_token_id"]
+        if "eos_token_id" in kwargs:
+            self.eos_token_id = kwargs["eos_token_id"]
= kwargs["eos_token_id"] + + def __len__(self) -> int: + return self.tokenizer.n_vocab + + def get_vocab(self) -> Dict[bytes, int]: + return self.mergeable_ranks + + def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]: + ids = [] + if isinstance(tokens, (str, bytes)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.mergeable_ranks.get(tokens) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.mergeable_ranks.get(token)) + return ids + + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + if not special_tokens and new_tokens: + raise ValueError("Adding regular tokens is not supported") + for token in new_tokens: + surface_form = token.content if isinstance(token, AddedToken) else token + if surface_form not in SPECIAL_TOKENS: + raise ValueError("Adding unknown special tokens is not supported") + return 0 + + def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: + """ + Save only the vocabulary of the tokenizer (vocabulary). + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + file_path = os.path.join(save_directory, "tokenizer.model") + with open(file_path, "w", encoding="utf8") as w: + for k, v in self.mergeable_ranks.items(): + line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" + w.write(line) + return (file_path,) + + def tokenize( + self, + text: str, + allowed_special: Union[Set, str] = "all", + disallowed_special: Union[Collection, str] = (), + **kwargs, + ) -> List[Union[bytes, str]]: + """ + Converts a string in a sequence of tokens. + + Args: + text (`str`): + The sequence to be encoded. + allowed_special (`Literal["all"]` or `set`): + The surface forms of the tokens to be encoded as special tokens in regular texts. + Default to "all". + disallowed_special (`Literal["all"]` or `Collection`): + The surface forms of the tokens that should not be in regular texts and trigger errors. + Default to an empty tuple. + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific encode method. + + Returns: + `List[bytes|str]`: The list of tokens. + """ + tokens = [] + text = unicodedata.normalize("NFC", text) + + # this implementation takes a detour: text -> token id -> token surface forms + for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special): + tokens.append(self.decoder[t]) + return tokens + + def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: + """ + Converts a sequence of tokens in a single string. 
+ """ + text = "" + temp = b"" + for t in tokens: + if isinstance(t, str): + if temp: + text += temp.decode("utf-8", errors=self.errors) + temp = b"" + text += t + elif isinstance(t, bytes): + temp += t + else: + raise TypeError("token should only be of type types or str") + if temp: + text += temp.decode("utf-8", errors=self.errors) + return text + + @property + def vocab_size(self): + return self.tokenizer.n_vocab + + def _convert_id_to_token(self, index: int) -> Union[bytes, str]: + """Converts an id to a token, special tokens included""" + if index in self.decoder: + return self.decoder[index] + raise ValueError("unknown ids") + + def _convert_token_to_id(self, token: Union[bytes, str]) -> int: + """Converts a token to an id using the vocab, special tokens included""" + if token in self.special_tokens: + return self.special_tokens[token] + if token in self.mergeable_ranks: + return self.mergeable_ranks[token] + raise ValueError("unknown token") + + def _tokenize(self, text: str, **kwargs): + """ + Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + errors: str = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if skip_special_tokens: + token_ids = [i for i in token_ids if i < self.eod_id] + return self.tokenizer.decode(token_ids, errors=errors or self.errors) + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). 
+            return_attention_mask:
+                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+        """
+        # Load from model defaults
+
+        # attention_mask shape [1,seq_len,seq_len]
+        if "attention_mask" in encoded_inputs and len(np.shape(encoded_inputs["attention_mask"])) > 2:
+            attention_mask = encoded_inputs["attention_mask"]
+            encoded_inputs.pop("attention_mask")
+        else:
+            attention_mask = None
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+        encoded_inputs = super()._pad(
+            encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask
+        )
+        if attention_mask is not None and len(np.shape(attention_mask)) > 2:
+            encoded_inputs["attention_mask"] = attention_mask
+            needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+            if needs_to_be_padded:
+                difference = max_length - len(required_input)
+                if "attention_mask" in encoded_inputs:
+                    encoded_inputs["attention_mask"] = np.pad(
+                        encoded_inputs["attention_mask"],
+                        pad_width=[(0, 0), (difference, 0), (difference, 0)],
+                        mode="constant",
+                        constant_values=0,
+                    )
+        return encoded_inputs
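For reference, a rough usage sketch of the new tokenizer (not part of the patch). It assumes a hypothetical local checkpoint directory whose `tokenizer_config.json` sets `"tokenizer_class": "Llama3Tokenizer"`, which contains the tiktoken-style `tokenizer.model` rank file, and that `Llama3Tokenizer` is re-exported from the `paddlenlp.transformers` package root as the `__all__` change above suggests:

```python
from paddlenlp.transformers import AutoTokenizer

# Hypothetical local checkpoint path; requires `pip install tiktoken`.
tokenizer = AutoTokenizer.from_pretrained("./llama3-8b-checkpoint")

ids = tokenizer("Hello, Llama 3!")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))  # byte-level tokens from the BPE ranks
print(tokenizer.decode(ids))                 # round-trips back to the input text

# The tokenizer does not ship a dedicated pad token, which is why the
# finetune_generation.py hunk maps pad_token_id to eos_token_id for it.
```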