diff --git a/paddlenlp/transformers/bart/modeling.py b/paddlenlp/transformers/bart/modeling.py index 341a75484c08..3112487b3658 100644 --- a/paddlenlp/transformers/bart/modeling.py +++ b/paddlenlp/transformers/bart/modeling.py @@ -12,20 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -import numpy as np +import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F -import paddle.tensor as tensor -from paddle.nn import Layer, Embedding +from paddle.nn import Embedding, Layer +from ...utils.log import logger from .. import PretrainedModel, register_base_model from ..model_outputs import ( ModelOutput, - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, Seq2SeqQuestionAnsweringModelOutput, @@ -200,6 +197,7 @@ def forward( self, input_ids=None, attention_mask=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -213,6 +211,8 @@ def forward( See :class:`BartModel`. attention_mask (Tensor, optional): See :class:`BartModel`. + inputs_embeds (Tensor, optional): + See :class:`BartModel`. output_attentions (bool, optional): See :class:`BartModel`. output_hidden_states (bool, optional): @@ -230,15 +230,25 @@ def forward( Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size]. """ - if input_ids is None: - raise ValueError("Input_ids cannot be None.") - inputs_embeds = self.embed_tokens(input_ids) - inputs_embed_pos = self.encoder_embed_positions(paddle.shape(input_ids)) + if input_ids is None and inputs_embeds is None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + inputs_shape = paddle.shape(input_ids) + input_ids = input_ids.reshape((-1, inputs_shape[-1])) + elif inputs_embeds is not None: + inputs_shape = paddle.shape(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embed_pos = self.encoder_embed_positions(inputs_shape) hidden_states = inputs_embeds + inputs_embed_pos hidden_states = self.encoder_layernorm_embedding(hidden_states) encoder_input = self.encoder_dropout(hidden_states) - if attention_mask is None: + if attention_mask is None and input_ids is not None: attention_mask = ( paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4 ) @@ -308,6 +318,7 @@ def forward( decoder_attention_mask=None, encoder_output=None, memory_mask=None, + decoder_inputs_embeds=None, cache=None, output_attentions=False, output_hidden_states=False, @@ -325,6 +336,8 @@ def forward( See :class:`BartModel`. memory_mask (Tensor, optional): See :class:`BartModel`. + decoder_inputs_embeds (Tensor, optional): + See :class:`BartModel`. cache (Tensor, optional): See :class:`BartModel`. output_attentions (bool, optional): @@ -344,16 +357,28 @@ def forward( Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size]. 
""" + # retrieve input_ids and inputs_embeds + if decoder_input_ids is not None and decoder_inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif decoder_input_ids is not None: + inputs_shape = paddle.shape(decoder_input_ids) + decoder_input_ids = decoder_input_ids.reshape((-1, inputs_shape[-1])) + elif decoder_inputs_embeds is not None: + inputs_shape = paddle.shape(decoder_inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if decoder_attention_mask is None: - decoder_length = paddle.shape(decoder_input_ids)[-1] + decoder_length = inputs_shape[-1] decoder_attention_mask = paddle.tensor.triu( (paddle.full((decoder_length, decoder_length), -np.inf, dtype=paddle.get_default_dtype())), 1 ) - decoder_inputs_embeds = self.embed_tokens(decoder_input_ids) + + if decoder_inputs_embeds is None: + decoder_inputs_embeds = self.embed_tokens(decoder_input_ids) + past_key_values_length = paddle.shape(cache[0][0].k)[2] if cache is not None else 0 - decoder_inputs_embed_pos = self.decoder_embed_positions( - paddle.shape(decoder_input_ids), past_key_values_length - ) + decoder_inputs_embed_pos = self.decoder_embed_positions(inputs_shape, past_key_values_length) hidden_states = decoder_inputs_embeds + decoder_inputs_embed_pos hidden_states = self.decoder_layernorm_embedding(hidden_states) decoder_input = self.decoder_dropout(hidden_states) @@ -515,11 +540,13 @@ def set_input_embeddings(self, value): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, + inputs_embeds=None, + decoder_inputs_embeds=None, use_cache=False, cache=None, output_attentions=False, @@ -530,7 +557,7 @@ def forward( The BartModel forward method, overrides the `__call__()` special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. @@ -560,6 +587,19 @@ def forward( For all element in the tuple, its data type should be float32 and its shape is [`batch_size, sequence_length, hidden_size`]. `attentions` is attentions of all layers of in the Transformer encoder. The length of `attentions` is `num_hidden_layers`. For all element in the tuple, its data type should be float32 and its shape is [`batch_size, num_attention_heads, sequence_length, sequence_length`]. + inputs_embeds (Tensor, optional): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation + of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over + how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + Default to None. + decoder_inputs_embeds (Tensor, optional): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation of shape `(batch_size, target_sequence_length, hidden_size)`. If `cache` is used, + optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). + This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. Default to None. 
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. use_cache (bool, optional): Whether or not to use cache. Defaults to `False`. If set to `True`, key value states will be returned and can be used to speed up decoding. @@ -601,35 +641,46 @@ def forward( """ # different to other models, Bart automatically creates decoder_input_ids from # inputBartForSequenceClassification_ids if no decoder_input_ids are provided - if input_ids is None and encoder_output is None: + if input_ids is None and inputs_embeds is None and encoder_output is None: raise ValueError("You have to specify either input_ids or encoder_output") - if decoder_input_ids is None: - assert input_ids is not None, "input_ids should be " "specified when generating decoder_input_ids" + + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) decoder_input_ids = shift_tokens_right(input_ids, self.decoder_start_token_id) - if attention_mask is None: - assert input_ids is not None, "input_ids should be " "specified when generating attention_mask" + if attention_mask is None and input_ids is not None: + # only generate attention_mask when input_ids is specified attention_mask = ( paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4 ) + if inputs_embeds is not None and input_ids is None and attention_mask is None: + logger.warning("provided inputs_embeds without attention_mask") # For 2D attention_mask from tokenizer elif attention_mask.ndim == 2: attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype()) attention_mask = (1.0 - attention_mask) * -1e4 attention_mask.stop_gradient = True + + input_type = type(decoder_input_ids) if decoder_input_ids is not None else type(decoder_inputs_embeds) if encoder_output is None: encoder_output = self.encoder( input_ids, attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True elif return_dict and not isinstance(encoder_output, ModelOutput): - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_output = (encoder_output,) encoder_output = convert_encoder_output(encoder_output) - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_last_hidden_state = encoder_output else: encoder_last_hidden_state = encoder_output[0] @@ -643,15 +694,16 @@ def forward( decoder_attention_mask, encoder_last_hidden_state, attention_mask, - cache, + cache=cache, + decoder_inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if not return_dict: - if isinstance(decoder_output, type(decoder_input_ids)): + if isinstance(decoder_output, input_type): decoder_output = (decoder_output,) - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_output = (encoder_output,) return decoder_output + encoder_output @@ -722,11 +774,13 @@ def __init__(self, bart, num_labels=2, 
dropout=None): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, + inputs_embeds=None, + decoder_inputs_embeds=None, use_cache=False, cache=None, labels=None, @@ -738,7 +792,7 @@ def forward( The BartForSequenceClassification forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`BartModel`. attention_mask (Tensor, optional): See :class:`BartModel`. @@ -748,8 +802,12 @@ def forward( See :class:`BartModel`. encoder_output (Tensor, optonal): See :class:`BartModel`. - use_cache (bool, optional): + inputs_embeds (Tensor, optional): See :class:`BartModel`. + decoder_inputs_embeds (Tensor, optional): + See :class:`BartModel`. + use_cache (bool, optional): + See :class:`BartModel`. Forcely set to `False` when `labels` is provided that can save memory during training. cache (Tensor, optional): See :class:`BartModel`. labels (Tensor, optional): @@ -786,26 +844,41 @@ def forward( inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} logits = model(**inputs) """ + if labels is not None: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + + if input_ids is None and inputs_embeds is not None: + logger.warning( + f"{self.__class__.__name__} will not detect eos tokens in `inputs_embeds`. Results may be " + "unexpected if using eos tokens in conjunction with `inputs_embeds.`" + ) + outputs = self.bart( input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_output, - use_cache, - cache, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache=cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) output = outputs[0] - eos_mask = paddle.cast(input_ids == self.bart.config["eos_token_id"], dtype="int64") - if len(paddle.unique(paddle.sum(eos_mask, axis=1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - output_shape = paddle.shape(output) - # TODO(gongenlei): support bool tensor index - output = output.masked_select(eos_mask.unsqueeze(-1).astype("bool").tile([1, 1, output_shape[-1]])) + + if input_ids is not None: + eos_mask = paddle.cast(input_ids == self.bart.config["eos_token_id"], dtype="int64") + if len(paddle.unique(paddle.sum(eos_mask, axis=1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + + # TODO(gongenlei): support bool tensor index + output = output.masked_select(eos_mask.unsqueeze(-1).astype("bool").tile([1, 1, output_shape[-1]])) + sentence_representation = output.reshape([output_shape[0], -1, output_shape[-1]])[:, -1, :] logits = self.classifier(sentence_representation) @@ -858,11 +931,13 @@ def __init__(self, bart): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, + inputs_embeds=None, + decoder_inputs_embeds=None, use_cache=False, cache=None, start_positions=None, @@ -875,7 +950,7 @@ def forward( The BartForQuestionAnswering forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`BartModel`. attention_mask (Tensor, optional): See :class:`BartModel`. @@ -885,8 +960,12 @@ def forward( See :class:`BartModel`. encoder_output (Tensor, optonal): See :class:`BartModel`. 
- use_cache (bool, optional): + inputs_embeds (Tensor, optional): + See :class:`BartModel`. + decoder_inputs_embeds (Tensor, optional): See :class:`BartModel`. + use_cache (bool, optional): + See :class:`BartModel`. Forcely set to `False` when `start_positions` and `end_positions` are provided that can save memory during training. cache (Tensor, optional): See :class:`BartModel`. start_positions (Tensor, optional): @@ -939,14 +1018,22 @@ def forward( start_logits = outputs[0] end_logits =outputs[1] """ + if start_positions is not None and end_positions is not None: + logger.warning( + "The `use_cache` argument is changed to `False` since `start_positions` and `end_positions` are provided." + ) + use_cache = False + outputs = self.bart( input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_output, - use_cache, - cache, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache=cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -963,7 +1050,7 @@ def forward( if start_positions.ndim > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.shape[1] + ignored_index = paddle.shape(start_logits)[1] start_positions = start_positions.clip(0, ignored_index) end_positions = end_positions.clip(0, ignored_index) @@ -1047,11 +1134,13 @@ def prepare_faster_entry(self, kwargs): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, + inputs_embeds=None, + decoder_inputs_embeds=None, use_cache=False, cache=None, labels=None, @@ -1063,7 +1152,7 @@ def forward( The BartForConditionalGeneration forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`BartModel`. attention_mask (Tensor, optional): See :class:`BartModel`. @@ -1073,6 +1162,10 @@ def forward( See :class:`BartModel`. encoder_output (Tensor, optonal): See :class:`BartModel`. + inputs_embeds (Tensor, optional): + See :class:`BartModel`. + decoder_inputs_embeds (Tensor, optional): + See :class:`BartModel`. use_cache (bool, optional): See :class:`BartModel`. cache (Tensor, optional): @@ -1117,14 +1210,21 @@ def forward( outputs = model(**inputs) """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + outputs = self.bart( input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_output, - use_cache, - cache, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache=cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/paddlenlp/transformers/codegen/modeling.py b/paddlenlp/transformers/codegen/modeling.py index 9dad50524336..8fd71ecb3bc8 100644 --- a/paddlenlp/transformers/codegen/modeling.py +++ b/paddlenlp/transformers/codegen/modeling.py @@ -18,6 +18,7 @@ import paddle.nn.functional as F from paddle.nn import Layer +from ...utils.log import logger from .. 
import PretrainedModel, register_base_model from ..model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -415,6 +416,7 @@ def forward( token_type_ids=None, use_cache=False, cache=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -422,7 +424,7 @@ def forward( r""" The CodeGenModel forward method, overrides the `__call__()` special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. @@ -445,6 +447,11 @@ def forward( See `TransformerDecoder.gen_cache `__ for more details. It is only used for inference and should be None for training. Default to `None`. + inputs_embeds (Tensor, optional): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation + of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over + how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + Default to None. output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. Defaults to `False`. @@ -473,12 +480,17 @@ def forward( output = model(**inputs) """ - if input_ids is not None: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: input_shape = input_ids.shape - input_ids = input_ids.reshape(shape=(-1, input_shape[-1])) + input_ids = input_ids.reshape((-1, input_shape[-1])) batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + batch_size = inputs_embeds.shape[0] else: - raise ValueError("You have to specify input_ids") + raise ValueError("You have to specify either input_ids or inputs_embeds") if cache is None: past_length = 0 @@ -488,22 +500,27 @@ def forward( # Attention mask. if attention_mask is None: - assert input_ids is not None, "input_ids should be " "specified when generating attention_mask" - if batch_size == 1 and past_length != 0: - batch_size, seq_len = input_shape - attention_mask = paddle.zeros( - [batch_size, 1, 1, seq_len + past_length], dtype=paddle.get_default_dtype() - ) + if input_ids is not None: + if batch_size == 1 and past_length != 0: + batch_size, seq_len = input_shape + attention_mask = paddle.zeros( + [batch_size, 1, 1, seq_len + past_length], dtype=paddle.get_default_dtype() + ) + else: + attention_mask = ( + paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) + * -1e4 + ) else: - attention_mask = ( - paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) - * -1e4 + logger.warning( + "Provided inputs_embeds while attention_mask is None, attention weights will not be masked during forwarding." ) # For 2D attention_mask from tokenizer elif attention_mask.ndim == 2: attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype()) attention_mask = (1.0 - attention_mask) * -1e4 - attention_mask.stop_gradient = True + if attention_mask is not None: + attention_mask.stop_gradient = True # TODO: CodeGen Attention Mask is TOO confusion. 
# When it's 2D, it must be int and it's denoted by 1/0. # When using model.generate() without providing attention mask @@ -511,7 +528,8 @@ def forward( # the attention mask's dtype must be float and it's denoted by 0/-inf. # Moreover, cannot support 3D attention mask. - inputs_embeds = self.wte(input_ids) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) inputs_embeds = inputs_embeds + token_type_embeds @@ -641,6 +659,7 @@ def forward( use_cache=False, cache=None, labels=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -648,7 +667,7 @@ def forward( r""" The CodeGenForCausalLM forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`CodeGenModel`. attention_mask (Tensor, optional): See :class:`CodeGenModel`. @@ -660,12 +679,14 @@ def forward( Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., vocab_size]` + inputs_embeds (Tensor, optional): + See :class:`CodeGenModel`. output_attentions (bool, optional): - See :class: `CodeGenModel` + See :class: `CodeGenModel`. output_hidden_states (bool, optional): - See :class: `CodeGenModel` + See :class: `CodeGenModel`. return_dict (bool, optional): - See :class: `CodeGenModel` + See :class: `CodeGenModel`. Returns: An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors corresponding @@ -691,6 +712,7 @@ def forward( token_type_ids=token_type_ids, use_cache=use_cache, cache=cache, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/paddlenlp/transformers/mbart/modeling.py b/paddlenlp/transformers/mbart/modeling.py index a56302f4eef1..38e12121fbb9 100644 --- a/paddlenlp/transformers/mbart/modeling.py +++ b/paddlenlp/transformers/mbart/modeling.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -import numpy as np +import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F -import paddle.tensor as tensor -from paddle.nn import Layer, Embedding +from paddle.nn import Embedding, Layer +from ...utils.log import logger from .. import PretrainedModel, register_base_model from ..model_outputs import ( ModelOutput, - BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, Seq2SeqQuestionAnsweringModelOutput, @@ -265,6 +263,7 @@ def forward( self, input_ids=None, attention_mask=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -278,21 +277,42 @@ def forward( See :class:`MBartModel`. attention_mask (Tensor, optional): See :class:`MBartModel`. + input_embeds (Tensor, optional): + See :class:`MBartModel`. + output_attentions (bool, optional): + See :class:`MBartModel`. + output_hidden_states (bool, optional): + See :class:`MBartModel`. + return_dict (bool, optional): + See :class:`MBartModel`. Returns: - Tensor: Returns tensor `encoder_output`, which is the output at the last layer of the model. 
+ An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`. + Especially, When `return_dict=output_hidden_states=output_attentions=False`, + returns tensor `encoder_outputs` which is the output at the last layer of the model. Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size]. - """ - if input_ids is None: - raise ValueError("Input_ids cannot be None.") - inputs_embeds = self.d_model**0.5 * self.embed_tokens(input_ids) - inputs_embed_pos = self.encoder_embed_positions(paddle.shape(input_ids)) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = paddle.shape(input_ids) + elif inputs_embeds is not None: + input_shape = paddle.shape(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.d_model**0.5 * self.embed_tokens(input_ids) + + inputs_embed_pos = self.encoder_embed_positions(input_shape) hidden_states = inputs_embeds + inputs_embed_pos hidden_states = self.encoder_layernorm_embedding(hidden_states) encoder_input = self.encoder_dropout(hidden_states) - if attention_mask is None: + if attention_mask is None and input_ids is not None: attention_mask = ( paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4 ) @@ -365,6 +385,7 @@ def forward( encoder_output=None, memory_mask=None, cache=None, + decoder_inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -383,20 +404,48 @@ def forward( See :class:`MBartModel`. cache (Tensor, optional): See :class:`MBartModel`. + decoder_inputs_embeds (Tensor, optional): + See :class:`MBartModel`. + output_attentions (bool, optional): + See :class:`MBartModel`. + output_hidden_states (bool, optional): + See :class:`MBartModel`. + return_dict (bool, optional): + See :class:`MBartModel`. Returns: - Tensor: Returns tensor `decoder_output`, which is the output at the last layer of the model. + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`. + Especially, When `return_dict=output_hidden_states=output_attentions=False`, + returns tensor `decoder_outputs` which is the output at the last layer of the model. Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size]. 
""" + # retrieve input_ids and inputs_embeds + if decoder_input_ids is not None and decoder_inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif decoder_input_ids is not None: + decoder_input_shape = paddle.shape(decoder_input_ids) + decoder_input_ids = decoder_input_ids.reshape((-1, decoder_input_shape[-1])) + elif decoder_inputs_embeds is not None: + decoder_input_shape = paddle.shape(decoder_inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if decoder_attention_mask is None: - decoder_length = paddle.shape(decoder_input_ids)[-1] + + decoder_length = decoder_input_shape[-1] decoder_attention_mask = paddle.tensor.triu( (paddle.full((decoder_length, decoder_length), -np.inf, dtype=paddle.get_default_dtype())), 1 ) - decoder_inputs_embeds = self.d_model**0.5 * self.embed_tokens(decoder_input_ids) + if decoder_inputs_embeds is None: + decoder_inputs_embeds = self.d_model**0.5 * self.embed_tokens(decoder_input_ids) + past_key_values_length = paddle.shape(cache[0][0].k)[2] if cache is not None else 0 - decoder_inputs_embed_pos = self.decoder_embed_positions(decoder_input_ids.shape, past_key_values_length) + decoder_inputs_embed_pos = self.decoder_embed_positions(decoder_input_shape, past_key_values_length) + hidden_states = decoder_inputs_embeds + decoder_inputs_embed_pos hidden_states = self.decoder_layernorm_embedding(hidden_states) decoder_input = self.decoder_dropout(hidden_states) @@ -558,13 +607,15 @@ def set_input_embeddings(self, value): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, use_cache=False, cache=None, + inputs_embeds=None, + decoder_inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -573,7 +624,7 @@ def forward( The MBartModel forward method, overrides the `__call__()` special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. @@ -603,6 +654,19 @@ def forward( For all element in the tuple, its data type should be float32 and its shape is [`batch_size, sequence_length, hidden_size`]. `attentions` is attentions of all layers of in the Transformer encoder. The length of `attentions` is `num_hidden_layers`. For all element in the tuple, its data type should be float32 and its shape is [`batch_size, num_attention_heads, sequence_length, sequence_length`]. + inputs_embeds (Tensor, optional): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation + of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over + how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + Default to None. + decoder_inputs_embeds (Tensor, optional): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation of shape `(batch_size, target_sequence_length, hidden_size)`. If `cache` is used, + optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). 
+ This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. Default to None. + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. use_cache (bool, optional): Whether or not to use cache. Defaults to `False`. If set to `True`, key value states will be returned and can be used to speed up decoding. @@ -645,13 +709,18 @@ def forward( """ # different to other models, MBart automatically creates decoder_input_ids from # input MBartForSequenceClassification_ids if no decoder_input_ids are provided - if input_ids is None and encoder_output is None: - raise ValueError("You have to specify either input_ids or encoder_output") - if decoder_input_ids is None: - assert input_ids is not None, "input_ids should be " "specified when generating decoder_input_ids" + if input_ids is None and inputs_embeds is None and encoder_output is None: + raise ValueError("You have to specify one of input_ids, inputs_embeds and encoder_output") + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) decoder_input_ids = shift_tokens_right(input_ids, self.pad_token_id) - if attention_mask is None: - assert input_ids is not None, "input_ids should be " "specified when generating attention_mask" + if attention_mask is None and input_ids is not None: + logger.warning("input_ids should be specified when generating attention_mask") attention_mask = ( paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4 ) @@ -660,20 +729,24 @@ def forward( attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype()) attention_mask = (1.0 - attention_mask) * -1e4 attention_mask.stop_gradient = True + + input_type = type(decoder_input_ids) if decoder_input_ids is not None else type(decoder_inputs_embeds) + if encoder_output is None: encoder_output = self.encoder( input_ids, attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True elif return_dict and not isinstance(encoder_output, ModelOutput): - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_output = (encoder_output,) encoder_output = convert_encoder_output(encoder_output) - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_last_hidden_state = encoder_output else: encoder_last_hidden_state = encoder_output[0] @@ -689,15 +762,16 @@ def forward( encoder_last_hidden_state, attention_mask, cache, + decoder_inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if not return_dict: - if isinstance(decoder_output, type(decoder_input_ids)): + if isinstance(decoder_output, input_type): decoder_output = (decoder_output,) - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_output = (encoder_output,) return decoder_output + 
encoder_output @@ -767,13 +841,15 @@ def __init__(self, mbart, num_labels=2, dropout=None): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, use_cache=False, cache=None, + inputs_embeds=None, + decoder_inputs_embeds=None, labels=None, output_attentions=False, output_hidden_states=False, @@ -783,7 +859,7 @@ def forward( The MBartForSequenceClassification forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`MBartModel`. attention_mask (Tensor, optional): See :class:`MBartModel`. @@ -797,6 +873,10 @@ def forward( See :class:`MBartModel`. cache (Tensor, optional): See :class:`MBartModel`. + inputs_embeds (Tensor, optional): + See :class:`MBartModel`. + decoder_inputs_embeds (Tensor, optional): + See :class:`MBartModel`. labels (Tensor, optional): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., num_labels - 1]`. If `num_labels > 1` a classification loss is computed (Cross-Entropy). @@ -830,26 +910,39 @@ def forward( inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} logits = model(**inputs) """ + if labels is not None: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + + if input_ids is None and inputs_embeds is not None: + logger.warning( + f"{self.__class__.__name__} will not detect eos tokens in `inputs_embeds`. Results may be " + "unexpected if using eos tokens in conjunction with `inputs_embeds.`" + ) + outputs = self.mbart( input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_output, - use_cache, - cache, + use_cache=use_cache, + cache=cache, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) output = outputs[0] - eos_mask = paddle.cast(input_ids == self.mbart.config["eos_token_id"], dtype="int64") - if len(paddle.unique(paddle.sum(eos_mask, axis=1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - output_shape = paddle.shape(output) - # TODO(gongenlei): support bool tensor index - output = output.masked_select(eos_mask.unsqueeze(-1).astype("bool").tile([1, 1, output_shape[-1]])) + if input_ids is not None: + eos_mask = paddle.cast(input_ids == self.mbart.config["eos_token_id"], dtype="int64") + if len(paddle.unique(paddle.sum(eos_mask, axis=1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + + # TODO(gongenlei): support bool tensor index + output = output.masked_select(eos_mask.unsqueeze(-1).astype("bool").tile([1, 1, output_shape[-1]])) sentence_representation = output.reshape([output_shape[0], -1, output_shape[-1]])[:, -1, :] logits = self.classifier(sentence_representation) @@ -902,13 +995,15 @@ def __init__(self, mbart): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, use_cache=False, cache=None, + inputs_embeds=None, + decoder_inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=False, @@ -919,7 +1014,7 @@ def forward( The MBartForQuestionAnswering forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`MBartModel`. attention_mask (Tensor, optional): See :class:`MBartModel`. 
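A minimal usage sketch of the encoder-side `inputs_embeds` path added above for MBart. The checkpoint name is an illustrative assumption, and the manual scaling reflects that the encoder applies the `d_model**0.5` factor only when it performs the token lookup itself:

import paddle
from paddlenlp.transformers import MBartModel, MBartTokenizer

# Illustrative checkpoint; any MBart weights behave the same way.
tokenizer = MBartTokenizer.from_pretrained("mbart-large-cc25")
model = MBartModel.from_pretrained("mbart-large-cc25")
model.eval()

input_ids = paddle.to_tensor([tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")["input_ids"]])

# Embed the tokens outside the model; scale manually to stay numerically
# equivalent to the input_ids path, where the encoder scales during lookup.
inputs_embeds = model.encoder.d_model**0.5 * model.encoder.embed_tokens(input_ids)

# With input_ids=None the model cannot derive a padding mask itself, so pass
# a 2D mask (1 = attend); it is converted to the additive form internally.
attention_mask = paddle.ones(input_ids.shape, dtype="int64")

outputs = model(
    input_ids=None,
    attention_mask=attention_mask,
    decoder_input_ids=input_ids,  # a real setup would shift the target tokens right
    inputs_embeds=inputs_embeds,
)

For this single unpadded sentence the result should match the equivalent `input_ids` call, since the all-ones mask reduces to an all-zero additive mask.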
@@ -929,6 +1024,10 @@ def forward( See :class:`MBartModel`. encoder_output (Tensor, optonal): See :class:`MBartModel`. + inputs_embeds (Tensor, optional): + See :class:`MBartModel`. + decoder_inputs_embeds (Tensor, optional): + See :class:`MBartModel`. use_cache (bool, optional): See :class:`MBartModel`. cache (Tensor, optional): @@ -983,14 +1082,21 @@ def forward( start_logits = outputs[0] end_logits =outputs[1] """ + if start_positions is not None and end_positions is not None: + logger.warning( + "The `use_cache` argument is changed to `False` since `start_positions` and `end_positions` are provided." + ) + use_cache = False outputs = self.mbart( input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_output, - use_cache, - cache, + use_cache=use_cache, + cache=cache, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1007,7 +1113,7 @@ def forward( if start_positions.ndim > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.shape[1] + ignored_index = paddle.shape(start_logits)[1] start_positions = start_positions.clip(0, ignored_index) end_positions = end_positions.clip(0, ignored_index) @@ -1083,13 +1189,15 @@ def prepare_faster_entry(self, kwargs): def forward( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, use_cache=False, cache=None, + inputs_embeds=None, + decoder_inputs_embeds=None, labels=None, output_attentions=False, output_hidden_states=False, @@ -1099,7 +1207,7 @@ def forward( The MBartForConditionalGeneration forward method, overrides the __call__() special method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`MBartModel`. attention_mask (Tensor, optional): See :class:`MBartModel`. @@ -1109,11 +1217,15 @@ def forward( See :class:`MBartModel`. encoder_output (Tensor, optonal): See :class:`MBartModel`. + See :class:`MBartModel`. use_cache (bool, optional): See :class:`MBartModel`. cache (Tensor, optional): See :class:`MBartModel`. - abels (Tensor, optional): + inputs_embeds (Tensor, optional): + See :class:`MBartModel`. + decoder_inputs_embeds (Tensor, optional): + labels (Tensor, optional): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`. 
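The `labels`-driven `use_cache` override that the following hunk adds to `MBartForConditionalGeneration` (mirroring the Bart heads above) can be exercised roughly like this; the checkpoint name is again an illustrative assumption:

import paddle
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("mbart-large-cc25")  # illustrative checkpoint
model = MBartForConditionalGeneration.from_pretrained("mbart-large-cc25")

input_ids = paddle.to_tensor([tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")["input_ids"]])

# Providing labels switches the call onto the training path: a warning is
# logged and use_cache is forced to False, so the True below is effectively
# ignored and no decoding cache is built.
outputs = model(input_ids=input_ids, labels=input_ids, use_cache=True)

Skipping the decoding cache in this situation saves memory during training.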
@@ -1156,14 +1268,21 @@ def forward( outputs = model(**inputs) """ + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + outputs = self.mbart( input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_output, - use_cache, - cache, + use_cache=use_cache, + cache=cache, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/paddlenlp/transformers/t5/modeling.py b/paddlenlp/transformers/t5/modeling.py index 36a4e617e6d8..861d5a343a97 100644 --- a/paddlenlp/transformers/t5/modeling.py +++ b/paddlenlp/transformers/t5/modeling.py @@ -1385,6 +1385,8 @@ def forward( logits = output[1] """ + + input_type = type(decoder_input_ids) if decoder_input_ids is not None else type(decoder_inputs_embeds) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) @@ -1399,7 +1401,7 @@ def forward( return_dict=return_dict, ) else: - if isinstance(encoder_output, type(decoder_input_ids)): + if isinstance(encoder_output, input_type): encoder_output = (encoder_output,) if return_dict and not isinstance(encoder_output, BaseModelOutput): encoder_output = convert_encoder_output(encoder_output) diff --git a/paddlenlp/transformers/unified_transformer/modeling.py b/paddlenlp/transformers/unified_transformer/modeling.py index 1b37eca6a702..410c9d0fc320 100644 --- a/paddlenlp/transformers/unified_transformer/modeling.py +++ b/paddlenlp/transformers/unified_transformer/modeling.py @@ -16,8 +16,8 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.nn import TransformerEncoder +from ...utils.log import logger from .. import PretrainedModel, register_base_model from ..model_outputs import CausalLMOutputWithCrossAttentions @@ -169,24 +169,38 @@ def __init__( self.pad_token_id = pad_token_id - def forward(self, input_ids, token_type_ids=None, position_ids=None, role_ids=None): + def forward(self, input_ids, token_type_ids=None, position_ids=None, role_ids=None, input_embeddings=None): + if input_ids is None and input_embeddings is None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + inputs_shape = paddle.shape(input_ids) + elif input_embeddings is not None: + inputs_shape = paddle.shape(input_embeddings)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + if input_embeddings is None: + input_embeddings = self.word_embeddings(input_ids) + if position_ids is None: if self.pad_token_id is None: - position_ids = paddle.expand_as( - paddle.arange(end=paddle.shape(input_ids)[1], dtype="int64"), input_ids - ) + position_ids = paddle.expand(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) else: - # NOTE: If there is a unk_token_id in input_ids, the following logic is wrong. - # In that case, the position_ids must be provided. - # And this is for left padding input_ids. 
- num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True) - position_ids = F.relu( - paddle.expand_as(paddle.arange(end=paddle.shape(input_ids)[1], dtype="float32"), input_ids) - - num_pad - ).astype("int64") + if input_ids is not None: + # NOTE: If there is a unk_token_id in input_ids, the following logic is wrong. + # In that case, the position_ids must be provided. + # And this is for left padding input_ids. + num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True) + position_ids = F.relu( + paddle.expand(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) - num_pad + ).astype("int64") + else: + logger.warning( + "Position_ids or pad_token_ids should be provided when input_embeds is specified, " + "otherwise an unexpected result may be returned since `[0, 1, ..., sequence length - 1]` will be generated as a default position_ids." + ) + position_ids = paddle.expand(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) position_ids.stop_gradient = True - input_embedings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) if token_type_ids is None: @@ -194,7 +208,7 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, role_ids=No token_type_ids.stop_gradient = True token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = input_embedings + position_embeddings + token_type_embeddings + embeddings = input_embeddings + position_embeddings + token_type_embeddings # A model with role_embeddings can generate without role_ids. if role_ids is not None: embeddings += self.role_embeddings(role_ids) @@ -332,13 +346,14 @@ def set_input_embeddings(self, value): def forward( self, - input_ids, + input_ids=None, token_type_ids=None, position_ids=None, attention_mask=None, use_cache=False, cache=None, role_ids=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -348,7 +363,7 @@ def forward( :meth:`__call__` method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. It's data type should be `int64` and has a shape of @@ -391,6 +406,11 @@ def forward( Indices of role ids indicated different roles. It's data type should be `int64` and has a shape of [batch_size, sequence_length]. Defaults to None. + inputs_embeds (Tensor, optional): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation + of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over + how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + Default to None. output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. Defaults to `False`. 
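A minimal sketch of the matching `inputs_embeds` path on `UnifiedTransformerModel`. The `plato-mini` checkpoint is an illustrative assumption; the attention mask and position ids come from the tokenizer because the model cannot recover padding information from embeddings alone and only logs warnings or falls back to default position ids:

from paddlenlp.transformers import (
    UnifiedTransformerModel,
    UnifiedTransformerTokenizer,
)

tokenizer = UnifiedTransformerTokenizer.from_pretrained("plato-mini")  # illustrative checkpoint
model = UnifiedTransformerModel.from_pretrained("plato-mini")
model.eval()

inputs = tokenizer.dialogue_encode(
    "Welcome to use PaddlePaddle and PaddleNLP!",
    return_tensors=True,
    is_split_into_words=False,
)

# Look the tokens up outside the model and hand over the embeddings instead of ids.
inputs_embeds = model.embeddings.word_embeddings(inputs["input_ids"])

# Without input_ids the model can no longer derive the padding-based attention
# mask or position_ids on its own, so keep passing the tokenizer-built ones.
outputs = model(
    input_ids=None,
    token_type_ids=inputs["token_type_ids"],
    position_ids=inputs["position_ids"],
    attention_mask=inputs["attention_mask"],
    inputs_embeds=inputs_embeds,
)

Only the embedding lookup moves outside the model here, so the output should match the plain `model(**inputs)` call.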
@@ -427,13 +447,22 @@ def forward( is_split_into_words=False) outputs = model(**inputs) """ + if attention_mask is None: - attention_mask = ((input_ids == self.pad_token_id).astype(paddle.get_default_dtype()) * -1e4).unsqueeze( - [1, 2] - ) + if input_ids is not None: + attention_mask = ( + (input_ids == self.pad_token_id).astype(paddle.get_default_dtype()) * -1e4 + ).unsqueeze([1, 2]) + else: + logger.warning( + "Provided inputs_embeds while attention_mask is None, attention weights will not be masked during forwarding." + ) + if attention_mask is not None: attention_mask.stop_gradient = True - embedding_output = self.embeddings(input_ids, token_type_ids, position_ids, role_ids=role_ids) + embedding_output = self.embeddings( + input_ids, token_type_ids, position_ids, role_ids=role_ids, input_embeddings=inputs_embeds + ) if use_cache and cache is None: cache = self.encoder.gen_cache(embedding_output) @@ -495,7 +524,7 @@ def __init__(self, unified_transformer): def forward( self, - input_ids, + input_ids=None, token_type_ids=None, position_ids=None, attention_mask=None, @@ -504,6 +533,7 @@ def forward( cache=None, role_ids=None, labels=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -513,7 +543,7 @@ def forward( :meth:`__call__` method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`UnifiedTransformerModel`. token_type_ids (Tensor): See :class:`UnifiedTransformerModel`. @@ -531,6 +561,8 @@ def forward( Labels for computing the left-to-right language modeling loss. Indices should be in `[-100, 0, ..., vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., vocab_size]` + inputs_embeds (Tensor, optional): + See :class:`UnifiedTransformerModel`. output_attentions (bool, optional): See :class: `UnifiedTransformerModel` output_hidden_states (bool, optional): @@ -573,11 +605,13 @@ def forward( use_cache, cache, role_ids=role_ids, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - sequence_output = outputs if isinstance(outputs, type(input_ids)) else outputs[0] + input_type = type(input_ids) if input_ids is not None else type(inputs_embeds) + sequence_output = outputs if isinstance(outputs, input_type) else outputs[0] logits = self.lm_head(sequence_output, masked_positions) lm_loss = None @@ -585,7 +619,7 @@ def forward( loss_fct = nn.CrossEntropyLoss() lm_loss = loss_fct(logits.reshape((-1, logits.shape[-1])), labels.reshape([-1])) if not return_dict: - if isinstance(outputs, type(input_ids)): + if isinstance(outputs, input_type): return (lm_loss, logits) if lm_loss is not None else logits else: outputs = (logits,) + outputs[1:] diff --git a/paddlenlp/transformers/unimo/modeling.py b/paddlenlp/transformers/unimo/modeling.py index c2b5e1eb7aa0..02cefa832369 100644 --- a/paddlenlp/transformers/unimo/modeling.py +++ b/paddlenlp/transformers/unimo/modeling.py @@ -16,8 +16,8 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.nn import TransformerEncoder +from ...utils.log import logger from .. 
import PretrainedModel, register_base_model from ..model_outputs import CausalLMOutputWithCrossAttentions @@ -242,20 +242,33 @@ def __init__( self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) self.pad_token_id = pad_token_id - def forward(self, input_ids, token_type_ids=None, position_ids=None): - input_embedings = self.word_embeddings(input_ids) + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeddings=None): + if input_ids is None and input_embeddings is None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + inputs_shape = paddle.shape(input_ids) + elif input_embeddings is not None: + inputs_shape = paddle.shape(input_embeddings)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + if input_embeddings is None: + input_embeddings = self.word_embeddings(input_ids) if position_ids is None: if self.pad_token_id is None: - position_ids = paddle.expand_as( - paddle.arange(end=paddle.shape(input_ids)[1], dtype="int64"), input_ids - ) + position_ids = paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) else: - num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True) - position_ids = F.relu( - paddle.expand_as(paddle.arange(end=paddle.shape(input_ids)[1], dtype="float32"), input_ids) - - num_pad - ).astype("int64") + if input_ids is not None: + num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True) + position_ids = F.relu( + paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) - num_pad + ).astype("int64") + else: + logger.warning( + "Position_ids or pad_token_ids should be provided when input_embeds is specified, " + "otherwise an unexpected result may be returned since `[0, 1, ..., sequence length - 1]` will be generated as a default position_ids." + ) + position_ids = paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) position_ids.stop_gradient = True position_embeddings = self.position_embeddings(position_ids) @@ -264,7 +277,7 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None): token_type_ids.stop_gradient = True token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = input_embedings + position_embeddings + token_type_embeddings + embeddings = input_embeddings + position_embeddings + token_type_embeddings return embeddings @@ -403,12 +416,13 @@ def set_input_embeddings(self, value): def forward( self, - input_ids, + input_ids=None, token_type_ids=None, position_ids=None, attention_mask=None, use_cache=False, cache=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False, @@ -417,7 +431,7 @@ def forward( The UNIMOModel forward method, overrides the special :meth:`__call__` method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. It's data type should be `int64` and has a shape of [batch_size, sequence_length]. @@ -455,6 +469,11 @@ def forward( method. See :meth:`paddle.nn.TransformerEncoder.gen_cache` method for more details. It is only used for inference and should be None for training. Defaults to `None`. 
+ inputs_embeds (Tensor, optional): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation + of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over + how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + Default to None. output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. Defaults to `False`. @@ -486,13 +505,21 @@ def forward( inputs = tokenizer.gen_encode("Welcome to use PaddlePaddle and PaddleNLP!", return_tensors=True) outputs = model(**inputs) """ + if attention_mask is None: - attention_mask = ((input_ids == self.pad_token_id).astype(paddle.get_default_dtype()) * -1e4).unsqueeze( - [1, 2] - ) + if input_ids is not None: + attention_mask = ( + (input_ids == self.pad_token_id).astype(paddle.get_default_dtype()) * -1e4 + ).unsqueeze([1, 2]) + else: + logger.warning( + "Provided inputs_embeds while attention_mask is None, attention weights will not be masked during forwarding." + ) + + if attention_mask is not None: attention_mask.stop_gradient = True - embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids, inputs_embeds) embedding_output = self.encoder_norm(embedding_output) embedding_output = self.dropout(embedding_output) @@ -557,13 +584,14 @@ def __init__(self, unimo): def forward( self, - input_ids, + input_ids=None, token_type_ids=None, position_ids=None, attention_mask=None, masked_positions=None, use_cache=False, cache=None, + inputs_embeds=None, labels=None, output_attentions=False, output_hidden_states=False, @@ -574,7 +602,7 @@ def forward( :meth:`__call__` method. Args: - input_ids (Tensor): + input_ids (Tensor, optional): See :class:`UNIMOModel`. token_type_ids (Tensor): See :class:`UNIMOModel`. @@ -586,6 +614,8 @@ def forward( See :class:`UNIMOModel`. cache (list, optional): See :class:`UNIMOModel`. + inputs_embeds (Tensor, optional): + See :class:`UNIMOModel`. labels (Tensor, optional): Labels for computing the left-to-right language modeling loss. 
Indices should be in `[-100, 0, ..., vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are @@ -630,12 +660,13 @@ def forward( attention_mask, use_cache, cache, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - sequence_output = outputs if isinstance(outputs, type(input_ids)) else outputs[0] + input_type = type(input_ids) if input_ids is not None else type(inputs_embeds) + sequence_output = outputs if isinstance(outputs, input_type) else outputs[0] logits = self.lm_head(sequence_output, masked_positions) @@ -645,7 +676,7 @@ def forward( lm_loss = loss_fct(logits.reshape((-1, self.unimo.config["vocab_size"])), labels.reshape((-1,))) if not return_dict: - if isinstance(outputs, type(input_ids)): + if isinstance(outputs, input_type): return (lm_loss, logits) if lm_loss is not None else logits else: outputs = (logits,) + outputs[1:] diff --git a/tests/transformers/bart/test_modeling.py b/tests/transformers/bart/test_modeling.py index 0a7e90665fc6..dfec8a7fca31 100644 --- a/tests/transformers/bart/test_modeling.py +++ b/tests/transformers/bart/test_modeling.py @@ -13,30 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import tempfile -import unittest -import numpy as np import random -from parameterized import parameterized_class - -from tests.testing_utils import slow - -from ..test_generation_utils import GenerationTesterMixin -from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor -from paddlenlp.transformers.tokenizer_utils_base import PaddingStrategy, TruncationStrategy +import unittest +import numpy as np import paddle +from parameterized import parameterized_class from paddlenlp.transformers import ( - AutoModelForSequenceClassification, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, BartModel, BartTokenizer, ) -from paddlenlp.transformers.bart.modeling import BartDecoder, BartEncoder, shift_tokens_right +from paddlenlp.transformers.bart.modeling import shift_tokens_right +from paddlenlp.transformers.tokenizer_utils_base import ( + PaddingStrategy, + TruncationStrategy, +) +from tests.testing_utils import slow + +from ..test_generation_utils import GenerationTesterMixin +from ..test_modeling_common import ModelTesterMixin, ids_tensor def prepare_bart_inputs_dict( @@ -408,6 +407,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): test_missing_keys = False use_labels = False return_dict = False + use_test_inputs_embeds = True def setUp(self): self.model_tester = BartModelTester(self) @@ -873,24 +873,22 @@ def test_cnn_summarization_same_as_fairseq(self): max_length=1024, ) - EXPECTED = [ - "A French prosecutor says he is not aware of any video footage from on board the plane. Two German " - "magazines claim to have found a cell phone video showing the crash. The publications say they watched " - "the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight " - "9525 were killed.", - "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court " - "jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the " - "Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a " - "move toward greater justice.", - "U.S. 
and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The " - "debate that has already begun will likely result in more heat than light. He says critics have made " - "dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a " - "nuclear weapon.", - "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors " - "say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the " - "Bronx on Friday. If convicted, she faces up to four years in prison.", - ] - - generated_summaries = tok.batch_decode( - hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True - ) + # EXPECTED = [ + # "A French prosecutor says he is not aware of any video footage from on board the plane. Two German " + # "magazines claim to have found a cell phone video showing the crash. The publications say they watched " + # "the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight " + # "9525 were killed.", + # "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court " + # "jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the " + # "Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a " + # "move toward greater justice.", + # "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The " + # "debate that has already begun will likely result in more heat than light. He says critics have made " + # "dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a " + # "nuclear weapon.", + # "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors " + # "say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the " + # "Bronx on Friday. If convicted, she faces up to four years in prison.", + # ] + + tok.batch_decode(hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True) diff --git a/tests/transformers/codegen/test_modeling.py b/tests/transformers/codegen/test_modeling.py index 8409d05b1291..0912486c8a91 100644 --- a/tests/transformers/codegen/test_modeling.py +++ b/tests/transformers/codegen/test_modeling.py @@ -13,24 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import datetime -import unittest -import numpy as np import random +import unittest +import numpy as np import paddle +from parameterized import parameterized_class + from paddlenlp.transformers import ( CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST, AutoTokenizer, CodeGenForCausalLM, CodeGenModel, - CodeGenTokenizer, ) -from ...testing_utils import slow +from ...testing_utils import slow from ..test_generation_utils import GenerationTesterMixin -from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask -from parameterized import parameterized_class +from ..test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, + random_attention_mask, +) class CodeGenModelTester: @@ -299,7 +303,7 @@ def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, loss, logits = model(input_ids, return_dict=self.parent.return_dict, labels=input_ids)[:2] self.parent.assertEqual(loss.shape, [1]) self.parent.assertEqual(logits.shape, [self.batch_size, self.seq_length, self.vocab_size]) - result.loss.backward() + loss.backward() def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -341,6 +345,7 @@ class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas use_test_model_name_list = False return_dict = False use_labels = False + use_test_inputs_embeds = True # attention mask issue def _get_input_ids_and_config(self): @@ -446,7 +451,7 @@ def test_model_name_list(self): @slow def test_auto_tokenizer(self): for model_name in CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST: - tokenizer = AutoTokenizer.from_pretrained(model_name) + AutoTokenizer.from_pretrained(model_name) class CodeGenModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/transformers/mbart/test_modeling.py b/tests/transformers/mbart/test_modeling.py index fd95e2f9d65b..c2fda789587b 100644 --- a/tests/transformers/mbart/test_modeling.py +++ b/tests/transformers/mbart/test_modeling.py @@ -13,15 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
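For orientation before the MBart-specific test changes: the `use_test_inputs_embeds = True` flags added across these suites opt into a shared equivalence check defined in `tests/transformers/test_modeling_common.py`, which this diff does not show. The sketch below illustrates what that check presumably does (helper names are assumptions, mirrored from the MBart override added further down in this file); MBart cannot reuse it unchanged because its token embeddings are multiplied by `d_model ** 0.5` inside the model, hence `test_inputs_embeds_for_mbart`.

# Hedged sketch of the shared inputs_embeds check; not part of this diff.
import copy

import paddle


def check_inputs_embeds_equivalence(self, config, inputs_dict, model_class):
    model = self._make_model_instance(config, model_class)
    model.eval()

    inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
    with paddle.no_grad():
        ids_output = model(**inputs)

    # Swap token ids for the corresponding embedding vectors.
    wte = model.get_input_embeddings()
    if not self.is_encoder_decoder:
        inputs["inputs_embeds"] = wte(inputs.pop("input_ids"))
    else:
        encoder_input_ids = inputs.pop("input_ids")
        decoder_input_ids = inputs.pop("decoder_input_ids", encoder_input_ids)
        inputs["inputs_embeds"] = wte(encoder_input_ids)
        inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

    with paddle.no_grad():
        embeds_output = model(**inputs)

    # Feeding embeddings directly should reproduce the id-based forward pass.
    self.assertTrue(paddle.allclose(ids_output, embeds_output, rtol=1e-4, atol=1e-4))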
+import copy import tempfile -from tests.testing_utils import slow, PaddleNLPModelTest - -from ..test_generation_utils import GenerationTesterMixin -from ..test_modeling_common import ModelTesterMixin, ids_tensor -from parameterized import parameterized_class - import paddle +from parameterized import parameterized_class from paddlenlp.transformers import ( AutoTokenizer, @@ -30,7 +26,11 @@ MBartForSequenceClassification, MBartModel, ) -from paddlenlp.transformers.mbart.modeling import MBartDecoder, MBartEncoder +from paddlenlp.transformers.mbart.modeling import MBartDecoder +from tests.testing_utils import PaddleNLPModelTest, slow + +from ..test_generation_utils import GenerationTesterMixin +from ..test_modeling_common import ModelTesterMixin, ids_tensor def prepare_mbart_inputs_dict( @@ -221,12 +221,48 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model2 = model_class.from_pretrained(tmpdirname) + model_class.from_pretrained(tmpdirname) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + def test_inputs_embeds_for_mbart(self): + # NOTE: rewrite test inputs embeds for mbart model since scaler not equal to 1.0 + # get config for model and inputs_dict for model forward + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + scaler = config["d_model"] ** 0.5 + # test all model classes + for model_class in self.all_model_classes: + model = self._make_model_instance(config, model_class) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + with paddle.no_grad(): + ids_output = model(**inputs) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) * scaler + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) * scaler + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) * scaler + + with paddle.no_grad(): + embeds_output = model(**inputs) + + self.assertTrue(paddle.allclose(ids_output, embeds_output, rtol=1e-4, atol=1e-4)) + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" @@ -459,7 +495,7 @@ def create_and_check_decoder_model_past( # first forward pass outputs = model(input_ids, cache=origin_cache, return_dict=self.parent.return_dict) - outputs_use_cache_conf = model(input_ids, return_dict=self.parent.return_dict) + # outputs_use_cache_conf = model(input_ids, return_dict=self.parent.return_dict) outputs_no_past = model(input_ids, cache=None, return_dict=self.parent.return_dict) # self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) # didn't support using cache by config yet @@ -508,9 +544,6 @@ def create_and_check_decoder_model_attention_mask_past( encoder_output = paddle.randn(shape=input_ids.shape + [self.d_model]) origin_cache = model.decoder.gen_cache(encoder_output) - cache = model.decoder.gen_cache( - paddle.randn(shape=[input_ids.shape[0], input_ids.shape[1], config["d_model"]]) - ) # first forward 
pass past_key_values = model( diff --git a/tests/transformers/test_generation_utils.py b/tests/transformers/test_generation_utils.py index 02f352d26638..d5eb559c3e03 100644 --- a/tests/transformers/test_generation_utils.py +++ b/tests/transformers/test_generation_utils.py @@ -399,7 +399,7 @@ def test_greedy_generate(self): for model_class in self.all_generative_model_classes.keys(): config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - paddle.seed(128) + paddle.seed(124) model = self._make_model_instance(config, model_class) model.eval() @@ -414,7 +414,7 @@ def test_sample_generate(self): for model_class in self.all_generative_model_classes.keys(): config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - paddle.seed(128) + paddle.seed(124) model = self._make_model_instance(config, model_class) model.eval() diff --git a/tests/transformers/unified_transformer/test_modeling.py b/tests/transformers/unified_transformer/test_modeling.py index b73c119c16f5..851b80fdb436 100644 --- a/tests/transformers/unified_transformer/test_modeling.py +++ b/tests/transformers/unified_transformer/test_modeling.py @@ -12,28 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import math -import unittest -import numpy as np import random +import unittest -from tests.testing_utils import slow -from parameterized import parameterized_class - -from ..test_generation_utils import GenerationTesterMixin -from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask - +import numpy as np import paddle import paddle.nn as nn +from parameterized import parameterized_class + +from paddlenlp.data import Pad from paddlenlp.transformers import ( - UnifiedTransformerModel, UnifiedTransformerLMHeadModel, - UnifiedTransformerForMaskedLM, + UnifiedTransformerModel, UnifiedTransformerTokenizer, ) -from paddlenlp.data import Pad -from paddlenlp.data import DataCollatorWithPadding +from tests.testing_utils import slow + +from ..test_generation_utils import GenerationTesterMixin +from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask UNIFIED_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "unified_transformer-12L-cn", @@ -408,7 +404,7 @@ class UnifiedTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unitt all_model_classes = (UnifiedTransformerModel, UnifiedTransformerLMHeadModel) all_generative_model_classes = {UnifiedTransformerLMHeadModel: (UnifiedTransformerModel, "unified_transformer")} test_missing_keys = False - + use_test_inputs_embeds = True use_labels = False return_dict = False diff --git a/tests/transformers/unimo/test_modeling.py b/tests/transformers/unimo/test_modeling.py index 830e1f0d05fb..b8bbd06f9910 100644 --- a/tests/transformers/unimo/test_modeling.py +++ b/tests/transformers/unimo/test_modeling.py @@ -12,28 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
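Before the UNIMO test updates, one usage-level consequence of the UNIMO modeling change earlier in this diff is worth spelling out: when a caller supplies `inputs_embeds` without `input_ids`, the model can no longer derive `attention_mask` from `pad_token_id` and only logs a warning, so the mask (and the other tensors the tokenizer produces) should be passed explicitly. A minimal sketch, assuming the `unimo-text-1.0` checkpoint and the usual keys returned by `gen_encode`, which may vary by version:

import paddle

from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer

tokenizer = UNIMOTokenizer.from_pretrained("unimo-text-1.0")
model = UNIMOLMHeadModel.from_pretrained("unimo-text-1.0")
model.eval()

inputs = tokenizer.gen_encode("Welcome to use PaddlePaddle and PaddleNLP!", return_tensors=True)
input_ids = inputs.pop("input_ids")

# Build the embedded representation by hand instead of passing token ids.
# get_input_embeddings() is assumed here to return the word-embedding layer.
inputs_embeds = model.get_input_embeddings()(input_ids)

# Keep attention_mask / token_type_ids / position_ids from the tokenizer output; with
# inputs_embeds alone the model cannot build a mask from pad_token_id and leaves
# attention unmasked.
with paddle.no_grad():
    logits = model(inputs_embeds=inputs_embeds, **inputs)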
-import datetime -import math +import random import unittest + import numpy as np -import random +import paddle +import paddle.nn as nn from parameterized import parameterized_class +from paddlenlp.data import Pad +from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOModel, UNIMOTokenizer from tests.testing_utils import slow from ..test_generation_utils import GenerationTesterMixin -from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask - -import paddle -import paddle.nn as nn -from paddlenlp.transformers import ( - UNIMOModel, - UNIMOLMHeadModel, - UNIMOForMaskedLM, - UNIMOTokenizer, -) -from paddlenlp.data import Pad -from paddlenlp.data import DataCollatorWithPadding +from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask UNIMO_PRETRAINED_MODEL_ARCHIVE_LIST = [ "unimo-text-1.0", @@ -406,6 +398,7 @@ class UNIMOModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) use_labels = False return_dict = False + use_test_inputs_embeds = True # special case for DoubleHeads model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
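Finally, an end-to-end view of the new arguments on the BART side. The sketch below is illustrative only: the `bart-base` checkpoint name and the bare-tensor return value (with the default output flags) are assumptions consistent with the tests above, BART is assumed not to scale its token embeddings (unlike MBart), and for padded batches an explicit `attention_mask` should be supplied because it cannot be recovered from `inputs_embeds`.

import paddle

from paddlenlp.transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
model = BartModel.from_pretrained("bart-base")
model.eval()

input_ids = paddle.to_tensor([tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")["input_ids"]])
decoder_input_ids = input_ids  # illustrative choice, mirroring the test helpers above

embed_tokens = model.get_input_embeddings()
inputs_embeds = embed_tokens(input_ids)
decoder_inputs_embeds = embed_tokens(decoder_input_ids)

with paddle.no_grad():
    out_ids = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    out_embeds = model(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)

# For a single, unpadded sequence the two forward passes should agree closely.
print(paddle.allclose(out_ids, out_embeds, rtol=1e-4, atol=1e-4))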