
Commit 1858416

add inputs_embeds to Bart/MBart/Unified_Transformer/Unimo/CodeGen (#3769)
1 parent 271f3c1 commit 1858416

File tree

12 files changed: +585 -252 lines


paddlenlp/transformers/bart/modeling.py
149 additions & 49 deletions (large diff not rendered by default)

paddlenlp/transformers/codegen/modeling.py
41 additions & 19 deletions
@@ -18,6 +18,7 @@
 import paddle.nn.functional as F
 from paddle.nn import Layer
 
+from ...utils.log import logger
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -415,14 +416,15 @@ def forward(
         token_type_ids=None,
         use_cache=False,
         cache=None,
+        inputs_embeds=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=False,
     ):
         r"""
         The CodeGenModel forward method, overrides the `__call__()` special method.
         Args:
-            input_ids (Tensor):
+            input_ids (Tensor, optional):
                 Indices of input sequence tokens in the vocabulary. They are
                 numerical representations of tokens that build the input sequence.
                 Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
@@ -445,6 +447,11 @@ def forward(
                 See `TransformerDecoder.gen_cache <https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/nn/layer/transformer.py#L1060>`__ for more details.
                 It is only used for inference and should be None for training.
                 Default to `None`.
+            inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
+                of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
+                how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+                Default to None.
             output_attentions (bool, optional):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                 tensors for more detail. Defaults to `False`.
@@ -473,12 +480,17 @@ def forward(
                 output = model(**inputs)
         """
 
-        if input_ids is not None:
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
             input_shape = input_ids.shape
-            input_ids = input_ids.reshape(shape=(-1, input_shape[-1]))
+            input_ids = input_ids.reshape((-1, input_shape[-1]))
             batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+            batch_size = inputs_embeds.shape[0]
         else:
-            raise ValueError("You have to specify input_ids")
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if cache is None:
             past_length = 0
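
With the branch above, CodeGenModel can take pre-computed embeddings instead of token ids. A minimal usage sketch follows; the checkpoint name is illustrative, and reusing the model's `wte` embedding table is only one way to build `inputs_embeds`:

import paddle
from paddlenlp.transformers import CodeGenModel, CodeGenTokenizer

# Illustrative checkpoint; any CodeGen checkpoint supported by PaddleNLP works the same way.
tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = CodeGenModel.from_pretrained("Salesforce/codegen-350M-mono")
model.eval()

input_ids = paddle.to_tensor([tokenizer("def hello_world():")["input_ids"]])

# Build the embeddings manually from the model's token embedding table (wte),
# then pass them via inputs_embeds; input_ids must then be omitted.
inputs_embeds = model.wte(input_ids)
outputs = model(inputs_embeds=inputs_embeds, return_dict=True)
print(outputs.last_hidden_state.shape)  # [batch_size, sequence_length, hidden_size]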
@@ -488,30 +500,36 @@ def forward(
 
         # Attention mask.
         if attention_mask is None:
-            assert input_ids is not None, "input_ids should be " "specified when generating attention_mask"
-            if batch_size == 1 and past_length != 0:
-                batch_size, seq_len = input_shape
-                attention_mask = paddle.zeros(
-                    [batch_size, 1, 1, seq_len + past_length], dtype=paddle.get_default_dtype()
-                )
+            if input_ids is not None:
+                if batch_size == 1 and past_length != 0:
+                    batch_size, seq_len = input_shape
+                    attention_mask = paddle.zeros(
+                        [batch_size, 1, 1, seq_len + past_length], dtype=paddle.get_default_dtype()
+                    )
+                else:
+                    attention_mask = (
+                        paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2])
+                        * -1e4
+                    )
             else:
-                attention_mask = (
-                    paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2])
-                    * -1e4
+                logger.warning(
+                    "Provided inputs_embeds while attention_mask is None, attention weights will not be masked during forwarding."
                 )
         # For 2D attention_mask from tokenizer
         elif attention_mask.ndim == 2:
             attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
             attention_mask = (1.0 - attention_mask) * -1e4
-        attention_mask.stop_gradient = True
+        if attention_mask is not None:
+            attention_mask.stop_gradient = True
         # TODO: CodeGen Attention Mask is TOO confusion.
         # When it's 2D, it must be int and it's denoted by 1/0.
         # When using model.generate() without providing attention mask
         # or using 4D attention mask,
         # the attention mask's dtype must be float and it's denoted by 0/-inf.
         # Moreover, cannot support 3D attention mask.
 
-        inputs_embeds = self.wte(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
             inputs_embeds = inputs_embeds + token_type_embeds
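
For reference, the `input_ids` branch above builds an additive mask: padding positions get a large negative bias so they receive near-zero attention weight after softmax. A standalone sketch of the same computation, using a made-up pad id:

import paddle

pad_token_id = 50256  # placeholder value for illustration; the model reads self.pad_token_id
input_ids = paddle.to_tensor([[10, 11, 12, pad_token_id, pad_token_id]])

# 1.0 at padding positions, 0.0 elsewhere, reshaped to [batch, 1, 1, seq_len]
# and scaled by -1e4 so softmax pushes padded positions toward zero weight.
attention_mask = (
    paddle.cast(input_ids == pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
)
print(attention_mask.shape)  # [1, 1, 1, 5]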
@@ -641,14 +659,15 @@ def forward(
         use_cache=False,
         cache=None,
         labels=None,
+        inputs_embeds=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=False,
     ):
         r"""
         The CodeGenForCausalLM forward method, overrides the __call__() special method.
         Args:
-            input_ids (Tensor):
+            input_ids (Tensor, optional):
                 See :class:`CodeGenModel`.
             attention_mask (Tensor, optional):
                 See :class:`CodeGenModel`.
@@ -660,12 +679,14 @@ def forward(
                 Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set
                 `labels = input_ids` Indices are selected in `[-100, 0, ..., vocab_size]` All labels set to `-100`
                 are ignored (masked), the loss is only computed for labels in `[0, ..., vocab_size]`
+            inputs_embeds (Tensor, optional):
+                See :class:`CodeGenModel`.
             output_attentions (bool, optional):
-                See :class: `CodeGenModel`
+                See :class: `CodeGenModel`.
             output_hidden_states (bool, optional):
-                See :class: `CodeGenModel`
+                See :class: `CodeGenModel`.
             return_dict (bool, optional):
-                See :class: `CodeGenModel`
+                See :class: `CodeGenModel`.
         Returns:
             An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions` if
             `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
@@ -691,6 +712,7 @@ def forward(
             token_type_ids=token_type_ids,
             use_cache=use_cache,
             cache=cache,
+            inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
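
Because CodeGenForCausalLM now forwards `inputs_embeds` to the base model, the same pattern works end to end. A minimal sketch, assuming the wrapped CodeGenModel is exposed as `model.transformer` (an attribute name not shown in this diff) and using an illustrative checkpoint name:

import paddle
from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
model.eval()

encoded = tokenizer("def fibonacci(n):", return_attention_mask=True)
input_ids = paddle.to_tensor([encoded["input_ids"]])
attention_mask = paddle.to_tensor([encoded["attention_mask"]])

# Assumption: the wrapped CodeGenModel is reachable as model.transformer, so its
# wte table can be reused to produce the embeddings passed through inputs_embeds.
inputs_embeds = model.transformer.wte(input_ids)
outputs = model(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    return_dict=True,
)
print(outputs.logits.shape)  # [batch_size, sequence_length, vocab_size]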
