diff --git a/llm/README.md b/llm/README.md index adeae7b03edd..745fe0ea6ca4 100644 --- a/llm/README.md +++ b/llm/README.md @@ -23,6 +23,7 @@ | [LLaMA](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [Qwen](./config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | | [Mixtral](./config/mixtral) | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 | +| [Mistral](./config/mistral) | ❌ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | | [Baichuan/Baichuan2](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | | [ChatGLM-6B](./config/chatglm) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ | | [ChatGLM2/ChatGLM3](./config/chatglm2) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | diff --git a/llm/config/mistral/README.md b/llm/config/mistral/README.md new file mode 100644 index 000000000000..a20090e4688c --- /dev/null +++ b/llm/config/mistral/README.md @@ -0,0 +1,20 @@ +# Mistral + +## 1. Model Introduction + +**Supported model weights:** + +| Model | +|--------------------------------------| +| mistralai/Mistral-7B-Instruct-v0.3 | +| mistralai/Mistral-7B-v0.1 | + + + +Usage: + +```python +from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer +model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") +``` diff --git a/llm/config/mistral/dpo_argument.json b/llm/config/mistral/dpo_argument.json new file mode 100644 index 000000000000..11480b5fd659 --- /dev/null +++ b/llm/config/mistral/dpo_argument.json @@ -0,0 +1,38 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3", + "train_dataset_path": "./dpo_data/train.jsonl", + "dev_dataset_path": "./dpo_data/train.jsonl", + "output_dir": "./checkpoints/dpo_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 8, + "per_device_eval_batch_size": 1, + "num_train_epochs": 1, + "max_steps": 100, + "learning_rate": 1e-06, + "warmup_steps": 10, + "logging_steps": 1, + "evaluation_strategy": "steps", + "save_strategy": "steps", + "eval_steps": 100, + "save_steps": 500, + "max_seq_len": 4096, + "max_prompt_len": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "tensor_parallel_degree": 8, + "sharding_parallel_degree": 1, + "sharding": "stage1", + "use_flash_attention": true, + "recompute": false, + "recompute_granularity": "full", + "dpo_beta": 0.1, + "benchmark": false, + "dpo_loss_type": "sigmoid", + "dpo_label_smoothing": 0.0, + "unified_checkpoint": true, + "autotuner_benchmark":false + } diff --git a/llm/config/mistral/lora_argument.json b/llm/config/mistral/lora_argument.json new file mode 100644 index 000000000000..a04db794f356 --- /dev/null +++ b/llm/config/mistral/lora_argument.json @@ -0,0 +1,32 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/mistral_lora_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "recompute": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, +
"use_flash_attention": true, + "zero_padding": true, + "lora": true + } diff --git a/llm/config/mistral/pt_argument.json b/llm/config/mistral/pt_argument.json new file mode 100644 index 000000000000..b3728227e5ca --- /dev/null +++ b/llm/config/mistral/pt_argument.json @@ -0,0 +1,30 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/mistral_pt_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-02, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": false, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "prefix_tuning": true + } diff --git a/llm/config/mistral/sft_argument.json b/llm/config/mistral/sft_argument.json new file mode 100644 index 000000000000..3532e86404c5 --- /dev/null +++ b/llm/config/mistral/sft_argument.json @@ -0,0 +1,30 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/mistral_sft_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-05, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "zero_padding": true, + "tensor_parallel_degree": 8, + "pipeline_parallel_degree": 1 + } diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 8df3705fe335..de31240d2ae3 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -338,11 +338,11 @@ def neft_post_hook(module, input, output): if data_args.zero_padding: if ( - model.base_model_prefix not in ["llama", "bloom", "chatglm", "chatglm_v2", "qwen"] + model.base_model_prefix not in ["llama", "bloom", "chatglm", "chatglm_v2", "qwen", "mistral"] and training_args.pipeline_parallel_degree < 1 ): raise NotImplementedError( - "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far." + "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM, QWen and Mistral so far." 
) train_ds = ( train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)) diff --git a/llm/utils/data.py b/llm/utils/data.py index a97e08d926c7..f9c48f8a58a7 100644 --- a/llm/utils/data.py +++ b/llm/utils/data.py @@ -50,6 +50,7 @@ def get_convert_example(model): "opt", "qwen", "mixtral", + "mistral", "gemma", "qwen2", "qwen2_moe", diff --git a/llm/utils/utils.py b/llm/utils/utils.py index 2f51711b496b..65254f002167 100644 --- a/llm/utils/utils.py +++ b/llm/utils/utils.py @@ -85,6 +85,14 @@ def get_prefix_tuning_params(model): hidden_size = model.config.hidden_size postprocess_past_key_value = llama_postprocess_past_key_value multi_query_group_num = None + elif model.base_model_prefix == "mistral": + from paddlenlp.peft.prefix import mistral_postprocess_past_key_value + + num_attention_heads = model.config.num_attention_heads + num_hidden_layers = model.config.num_hidden_layers + hidden_size = model.config.hidden_size + postprocess_past_key_value = mistral_postprocess_past_key_value + multi_query_group_num = model.config.num_key_value_heads elif model.base_model_prefix == "qwen": from paddlenlp.peft.prefix import qwen_postprocess_past_key_value @@ -190,6 +198,17 @@ def get_lora_target_modules(model): ".*w2.*", ".*w3.*", ] + elif model.base_model_prefix == "mistral": + target_modules = [ + ".*q_proj.*", + ".*k_proj.*", + ".*v_proj.*", + ".*o_proj.*", + ".*gate.*", + ".*w1.*", + ".*w2.*", + ".*w3.*", + ] elif model.base_model_prefix == "qwen2_moe": target_modules = [ ".*q_proj.*", @@ -279,6 +298,7 @@ def prediction_step( )[0] all_preds = [] for pred_tokens in generated_tokens: + pred_tokens = pred_tokens.numpy() pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist() all_preds.append(pred_tokens) max_pred_length = max([len(x) for x in all_preds]) diff --git a/paddlenlp/peft/prefix/__init__.py b/paddlenlp/peft/prefix/__init__.py index 90edebf53b66..c8bd6e6f07a5 100644 --- a/paddlenlp/peft/prefix/__init__.py +++ b/paddlenlp/peft/prefix/__init__.py @@ -18,5 +18,6 @@ bloom_postprocess_past_key_value, chatglm_postprocess_past_key_value, llama_postprocess_past_key_value, + mistral_postprocess_past_key_value, qwen_postprocess_past_key_value, ) diff --git a/paddlenlp/peft/prefix/utils.py b/paddlenlp/peft/prefix/utils.py index 684245c2f380..50584684568c 100644 --- a/paddlenlp/peft/prefix/utils.py +++ b/paddlenlp/peft/prefix/utils.py @@ -38,6 +38,13 @@ def llama_postprocess_past_key_value(past_key_values): return tuple(zip(keys, values)) +def mistral_postprocess_past_key_value(past_key_values): + # (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2 + keys, values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2) + + return tuple(zip(keys, values)) + + def qwen_postprocess_past_key_value(past_key_values): # (layer_num, bs, prefixlen, head_num/tensor_parallel_degree, head_dim)*2 keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2) diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 6be82645192d..0e35faa3c4ce 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -285,6 +285,8 @@ from .rw.modeling import * from .rw.configuration import * from .rw.tokenizer import * +from .mistral.modeling import * +from .mistral.configuration import * from .qwen import * from .mixtral.modeling import * from .mixtral.configuration import * diff --git a/paddlenlp/transformers/auto/modeling.py 
b/paddlenlp/transformers/auto/modeling.py index f22126f76c8a..d702567b4df9 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -117,6 +117,7 @@ ("Blip", "blip"), ("Bloom", "bloom"), ("QWen", "qwen"), + ("Mistral", "mistral"), ("Mixtral", "mixtral"), ("Qwen2", "qwen2"), ("Qwen2Moe", "qwen2_moe"), diff --git a/paddlenlp/transformers/mistral/__init__.py b/paddlenlp/transformers/mistral/__init__.py new file mode 100644 index 000000000000..0b41cc3d8c54 --- /dev/null +++ b/paddlenlp/transformers/mistral/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .configuration import MistralConfig +from .modeling import MistralForCausalLM diff --git a/paddlenlp/transformers/mistral/configuration.py b/paddlenlp/transformers/mistral/configuration.py new file mode 100644 index 000000000000..11237e5c840a --- /dev/null +++ b/paddlenlp/transformers/mistral/configuration.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Mistral model configuration""" + +from ..configuration_utils import PretrainedConfig + + +class MistralConfig(PretrainedConfig): + model_type = "mistral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/paddlenlp/transformers/mistral/modeling.py b/paddlenlp/transformers/mistral/modeling.py new file mode 100644 index 000000000000..f973390f0c5c --- /dev/null +++ b/paddlenlp/transformers/mistral/modeling.py @@ -0,0 +1,962 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute +from paddle.nn import CrossEntropyLoss + +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.utils.log import logger + +from ..activations import ACT2FN +from ..model_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, +) +from ..model_utils import PretrainedModel +from .configuration import MistralConfig + + +def _make_causal_mask( + input_ids_shape: paddle.shape, + dtype: paddle.dtype, + past_key_values_length: int = 0, +): + """ + Make causal mask used for sliding window attention + """ + bsz, tgt_len = input_ids_shape + + tensor = paddle.full( + (tgt_len, tgt_len), + fill_value=1, + ) + mask = paddle.tril(tensor, diagonal=0) + mask = paddle.log(mask).astype(dtype) + + if past_key_values_length > 0: + mask = paddle.concat([paddle.zeros([tgt_len, past_key_values_length], dtype=dtype), mask], axis=-1) + return mask[None, None, :, :].expand([bsz, 1, tgt_len, tgt_len + past_key_values_length]) + + +def _expand_mask(mask: paddle.Tensor, dtype: paddle.dtype, tgt_len: Optional[int] = None): + expanded_mask = mask + if len(mask.shape) == 2: + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).astype(dtype) + elif len(mask.shape) == 3: + """ + Expands attention_mask from `[bsz, tgt_seq_len, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + expanded_mask = mask.unsqueeze(1).astype(dtype) + + inverted_mask = 1.0 - expanded_mask + + return paddle.where(inverted_mask > 0.5, paddle.full_like(inverted_mask, paddle.finfo(dtype).min), inverted_mask) + + +class MistralRMSNorm(nn.Layer): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = paddle.create_parameter( + shape=[hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.astype(paddle.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * paddle.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.astype(input_dtype) + + +class MistralRotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / (self.base ** (paddle.arange(0, self.dim, 2).astype("float32") / self.dim)) + + # Build here to make `paddle.jit.trace` work. 
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=paddle.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, dtype): + self.max_seq_len_cached = seq_len + t = paddle.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype) + + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = paddle.concat((freqs, freqs), axis=-1) + self.cos_cached = emb.cos().astype(dtype) + self.sin_cached = emb.sin().astype(dtype) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].astype(dtype=x.dtype), + self.sin_cached[:seq_len].astype(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + if config.tensor_parallel_degree > 1: + self.gate_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + + self.down_proj = mpu.RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(x, axis=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand([batch, num_key_value_heads, n_rep, slen, head_dim]) + return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim]) + + +class MistralAttention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __init__(self, config: MistralConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + if config.tensor_parallel_degree > 1: + if self.num_key_value_heads % config.tensor_parallel_degree != 0: + raise ValueError( + f"num_key_value_heads must be divisible by tensor_parallel_degree (got `num_key_value_heads`: {self.num_key_value_heads}" + f" and `tensor_parallel_degree`: {config.tensor_parallel_degree})." + ) + + self.q_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + self.k_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + self.v_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + else: + self.q_proj = nn.Linear( + self.hidden_size, + self.num_heads * self.head_dim, + bias_attr=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + if config.tensor_parallel_degree > 1: + self.o_proj = mpu.RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + has_bias=False, + input_is_parallel=True, + ) + self.num_heads = self.num_heads // config.tensor_parallel_degree + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + else: + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, + self.hidden_size, + bias_attr=False, + ) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + key_states = key_states.reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose([0, 2, 1, 3]) + value_states = value_states.reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if 
past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=2) + value_states = paddle.concat([past_key_value[1], value_states], axis=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if not self.config.use_flash_attention: + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) / math.sqrt(self.head_dim) + + if attn_weights.shape != [bsz, self.num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of size {[bsz, self.num_heads, q_len, kv_seq_len]}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is not None: + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype( + query_states.dtype + ) + attn_output = paddle.matmul(attn_weights, value_states) + else: + query_states = query_states.transpose([0, 2, 1, 3]) + key_states = key_states.transpose([0, 2, 1, 3]) + value_states = value_states.transpose([0, 2, 1, 3]) + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if attn_output.shape != [bsz, self.num_heads, q_len, self.head_dim]: + raise ValueError( + f"`attn_output` should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is" + f" {attn_output.shape}" + ) + + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape([bsz, q_len, self.num_heads * self.head_dim]) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralDecoderLayer(nn.Layer): + def __init__(self, config: MistralConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MistralAttention(config=config) + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MistralPreTrainedModel(PretrainedModel): + config_class = MistralConfig + base_model_prefix = "mistral" + + @classmethod + def _get_name_mappings(cls, config: MistralConfig) -> List[StateDictNameMapping]: + mappings: List[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "mistral." + mapping[1] + + if "MistralModel" not in config.architectures: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: MistralConfig, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. 
+ if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + """Initialization hook""" + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.ColumnParallelLinear, + mpu.RowParallelLinear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.mistral.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, MistralMLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, MistralAttention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if config.tensor_parallel_degree > 1: + self.embed_tokens = mpu.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + ) + self.layers = nn.LayerList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.enable_recompute = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = paddle.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64 + ) + position_ids = position_ids.unsqueeze(0).expand((batch_size, seq_length)) + else: + position_ids = position_ids.reshape([-1, seq_length]).astype("int64") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) 
+ + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.enable_recompute and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and has_gradient: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = recompute( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def parallel_matmul(x: paddle.Tensor, y: paddle.Tensor, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=False) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=False) + return logits + + +class MistralLMHead(nn.Layer): + def __init__(self, config: MistralConfig): + super(MistralLMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + 
) + # Must set distributed attr for Tensor Parallel ! + self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + + def forward(self, hidden_states, tensor_parallel_output=None): + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + return logits + + +class MistralPretrainingCriterion(paddle.nn.Layer): + """ + Criterion for Mistral. + It calculates the final loss. + """ + + def __init__(self, config): + + super(MistralPretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be split: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + # skip ignore_index which loss == 0 + masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32") + loss = paddle.mean(masked_lm_loss) + + return loss + + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.mistral = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = MistralLMHead(config) + self.criterion = MistralPretrainingCriterion(config) + + def get_input_embeddings(self): + return self.mistral.embed_tokens + + def set_input_embeddings(self, value): + self.mistral.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.mistral = decoder + + def get_decoder(self): + return self.mistral + + def prepare_inputs_for_generation( + self, input_ids, use_cache=False, past_key_values=None, inputs_embeds=None, **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs,
is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs.pop("attention_mask", None) + + if attention_mask is not None and len(attention_mask.shape) == 2: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.mistral( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.astype("float32") + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/transformers/mistral/__init__.py b/tests/transformers/mistral/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/tests/transformers/mistral/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
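The modeling code and the auto-class registration above are what make the `AutoModelForCausalLM` snippet in the new config README resolve to `MistralForCausalLM`. A minimal end-to-end sketch of that path, assuming the `mistralai/Mistral-7B-Instruct-v0.3` weights are reachable by `from_pretrained` and that the generic PaddleNLP `generate` API applies (the dtype and length values here are illustrative, not part of this diff):

```python
# Hedged sketch: load the newly registered Mistral model through the Auto classes
# and run a short greedy generation. Assumes the remote weights resolve via
# from_pretrained; dtype/max_length choices are illustrative.
import paddle
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, dtype="bfloat16")
model.eval()

input_ids = tokenizer("Explain sliding window attention in one sentence.", return_tensors="pd")["input_ids"]
with paddle.no_grad():
    # PaddleNLP's generate returns a (generated_ids, scores) tuple.
    output_ids, _ = model.generate(input_ids=input_ids, max_length=64, decode_strategy="greedy_search")
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```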
diff --git a/tests/transformers/mistral/test_modeling.py b/tests/transformers/mistral/test_modeling.py new file mode 100644 index 000000000000..a177ba5c43fa --- /dev/null +++ b/tests/transformers/mistral/test_modeling.py @@ -0,0 +1,317 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import unittest + +import paddle + +from paddlenlp.transformers import MistralConfig, MistralForCausalLM, MistralModel +from tests.transformers.test_configuration_common import ConfigTester +from tests.transformers.test_generation_utils import GenerationTesterMixin +from tests.transformers.test_modeling_common import ( + ModelTesterMixin, + ids_tensor, + random_attention_mask, +) + + +class MistralModelTester: + def __init__( + self, + parent, + vocab_size=32000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + masked_softmax_fusion=True, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + is_training=True, + use_cache=False, + bos_token_id=1, + eos_token_id=2, + pad_token_id=3, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + attention_softmax_in_fp32=True, + pretraining_tp=1, # TP rank used when training with megatron + dtype="bfloat16", + slow_but_exact=False, + batch_size: int = 2, + seq_length: int = 10, + type_sequence_label_size=2, + activation_function="gelu", + num_labels=3, + num_choices=4, + scope=None, + dropout=0.56, + use_input_mask: bool = False, + use_labels: bool = False, + return_dict=False, + ): + self.parent: MistralModelTest = parent + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.is_training = is_training + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.pretraining_tp = pretraining_tp + self.dtype = dtype + self.slow_but_exact = slow_but_exact + + self.batch_size = batch_size + self.seq_length = seq_length + self.type_sequence_label_size = type_sequence_label_size + self.activation_function = activation_function + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.dropout = dropout + + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.return_dict = return_dict + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64) + + input_mask = 
None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self) -> MistralConfig: + return MistralConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + masked_softmax_fusion=self.masked_softmax_fusion, + layer_norm_epsilon=self.layer_norm_epsilon, + initializer_range=self.initializer_range, + use_cache=self.use_cache, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, + hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_dropout, + attention_softmax_in_fp32=self.attention_softmax_in_fp32, + pretraining_tp=self.pretraining_tp, + dtype=self.dtype, + slow_but_exact=self.slow_but_exact, + activation_function=self.activation_function, + ) + + def create_and_check_model( + self, config: MistralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MistralModel(config) + model.eval() + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + + def create_and_check_model_attention_mask( + self, config: MistralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MistralModel(config) + model.eval() + attn_mask_2d = random_attention_mask([self.batch_size, self.seq_length]) + result_2d = model(input_ids, attention_mask=attn_mask_2d)[0] + batch, seq_length = input_ids.shape + causal_mask = paddle.tril(paddle.ones((batch, seq_length, seq_length), dtype=attn_mask_2d.dtype)) + attn_mask_3d = causal_mask & attn_mask_2d.unsqueeze(-1) + result_3d = model(input_ids, attention_mask=attn_mask_3d)[0] + attn_mask_4d = attn_mask_3d.unsqueeze(1) + result_4d = model(input_ids, attention_mask=attn_mask_4d)[0] + result_no_attention_mask = model(input_ids, attention_mask=None)[0] + # Assert non-padding tokens have the same logits with different attention_mask shape + self.parent.assertTrue((result_2d[attn_mask_2d] == result_3d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_4d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all()) + + def create_and_check_model_past_large_inputs( + self, + config: MistralConfig, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MistralModel(config) + model.eval() + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True, return_dict=self.return_dict) + past_key_values = outputs.past_key_values if self.return_dict else outputs[2] + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), self.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = paddle.concat([input_ids, 
next_tokens], axis=-1) + next_attention_mask = paddle.concat([input_mask, next_mask], axis=-1) + + outputs = model( + next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True, return_dict=self.return_dict + ) + + output_from_no_past = outputs[2][0] + + outputs = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + return_dict=self.return_dict, + ) + + output_from_past = outputs[2][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(paddle.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args): + model = MistralForCausalLM(config) + model.eval() + + result = model( + input_ids, + use_cache=True, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertIsInstance(result[0].item(), float) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size]) + else: + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def check_model_position_ids(self, config, input_ids, input_mask, *args): + model = MistralForCausalLM(config) + model.eval() + + result_no_position_id = model( + input_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + batch_size, seq_len = input_ids.shape + position_ids = paddle.arange(seq_len).expand((batch_size, seq_len)) + result_position_id = model( + input_ids, + position_ids=position_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all()) + else: + self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all()) + + +class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + base_model_class = MistralModel + return_dict = False + use_labels = False + use_test_model_name_list = False + + all_model_classes = (MistralModel, MistralForCausalLM) + all_generative_model_classes = {MistralForCausalLM: (MistralModel, "Mistral")} + + def setUp(self): + super().setUp() + + self.model_tester = MistralModelTester(self) + self.config_tester = ConfigTester(self, config_class=MistralConfig, vocab_size=256, hidden_size=24) + + def _get_input_ids_and_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_ids = inputs_dict[self.input_name] + attention_mask = paddle.ones_like(input_ids, dtype=paddle.int64) + + max_batch_size = 2 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:max_batch_size, :sequence_length] + attention_mask = 
attention_mask[:max_batch_size, :sequence_length] + max_length = 3 + + return config, input_ids, attention_mask, max_length + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_attention_mask(*config_and_inputs) + + def test_model_position_ids(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_model_position_ids(*config_and_inputs) + + def test_Mistral_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + + +if __name__ == "__main__": + unittest.main()
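Beyond the unittest harness above, a quick shape-and-cache smoke test with a tiny randomly initialized model is a cheap way to exercise the new modeling code locally. A sketch mirroring `MistralModelTester`, under the assumption that `PretrainedConfig` defaults cover the remaining fields (flash attention is disabled explicitly; all sizes are arbitrary toy values):

```python
# Hedged sketch: tiny random Mistral model, checks logits shape and KV-cache length.
# Mirrors the tester above; sizes and values are illustrative only.
import paddle
from paddlenlp.transformers import MistralConfig, MistralForCausalLM

config = MistralConfig(
    vocab_size=256,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,
    use_flash_attention=False,  # use the plain matmul attention path
)
model = MistralForCausalLM(config)
model.eval()

input_ids = paddle.randint(0, config.vocab_size, shape=[2, 10], dtype="int64")
with paddle.no_grad():
    outputs = model(input_ids, use_cache=True, return_dict=True)

assert outputs.logits.shape == [2, 10, config.vocab_size]
assert len(outputs.past_key_values) == config.num_hidden_layers
```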