
add inputs_embeds to Bart/MBart/Unified_Transformer/Unimo/CodeGen #3769

Merged
merged 25 commits into from
Dec 16, 2022
Changes from 22 commits
Commits (25)
8a99eb4
fix bug for t5 which occurred when encoder_output is not None
Yam0214 Nov 15, 2022
460f3a3
Merge branch 'develop' of github.com:PaddlePaddle/PaddleNLP into inpu…
Yam0214 Nov 15, 2022
6174326
add inputs_embeds to bart and force use_cache=False when labels is pr…
Yam0214 Nov 15, 2022
64b23f2
add inputs_embeds to mbart and force use_cache=False when labels is p…
Yam0214 Nov 15, 2022
53096e1
fix conflicts
Yam0214 Nov 15, 2022
7ff588e
add inputs_embeds to codegen
Yam0214 Nov 16, 2022
ed51fba
add inputs_embeds to unimo
Yam0214 Nov 16, 2022
89e1922
add inputs_embeds to unified
Yam0214 Nov 16, 2022
4ae2205
Merge branch 'develop' of github.com:PaddlePaddle/PaddleNLP into inpu…
Yam0214 Nov 16, 2022
3ffba1b
change assertion to warning with default position_ids
Yam0214 Nov 16, 2022
0569bd2
Merge branch 'develop' of github.com:PaddlePaddle/PaddleNLP into inpu…
Yam0214 Nov 16, 2022
2ce2f74
Merge branch 'codestyle_before' into inputs_embed
Yam0214 Dec 2, 2022
32148d1
merge and fix conflicts
Yam0214 Dec 2, 2022
15112d8
check code style
Yam0214 Dec 5, 2022
005761f
merge and fix conflicts
Yam0214 Dec 5, 2022
2cd6639
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
Yam0214 Dec 6, 2022
8b5766f
change tensor.shape to paddle.shape(tensor)
Yam0214 Dec 6, 2022
cbb5fc9
merge and fix conflicts
Yam0214 Dec 6, 2022
48c1db1
Merge branch 'develop' into inputs_embed
FrostML Dec 8, 2022
7ed06df
fix documents and change expand_as to expand
Yam0214 Dec 8, 2022
0f884f8
Merge branch 'develop' into inputs_embed
wj-Mcat Dec 9, 2022
eb27b2a
Merge branch 'develop' into inputs_embed
FrostML Dec 9, 2022
556b751
Merge branch 'develop' into inputs_embed
FrostML Dec 9, 2022
791dfce
Merge branch 'develop' into inputs_embed
FrostML Dec 9, 2022
d0aded0
Merge branch 'develop' into inputs_embed
FrostML Dec 16, 2022
198 changes: 149 additions & 49 deletions paddlenlp/transformers/bart/modeling.py

Large diffs are not rendered by default.
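
Note: since the Bart diff is collapsed here, the sketch below shows how the new `inputs_embeds` argument is expected to be used with `BartModel`. This is a minimal sketch, not an example from this PR: the checkpoint name is illustrative, `get_input_embeddings()` is assumed to return the shared token embedding, and the keyword arguments follow the CodeGen changes further down.

```python
# Minimal sketch (not from this PR): bypass the encoder's embedding lookup by
# passing inputs_embeds directly. The checkpoint name and get_input_embeddings()
# are assumptions; the new keyword argument mirrors the CodeGen diff below.
import paddle
from paddlenlp.transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
model = BartModel.from_pretrained("bart-base")
model.eval()

encoded = tokenizer("PaddleNLP now accepts precomputed embeddings.", return_tensors="pd")
# Assumption: get_input_embeddings() returns the shared token embedding layer.
inputs_embeds = model.get_input_embeddings()(encoded["input_ids"])

with paddle.no_grad():
    outputs = model(
        inputs_embeds=inputs_embeds,             # embeddings instead of token ids
        decoder_input_ids=encoded["input_ids"],  # decoder side still receives ids here
        return_dict=True,
    )
```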

60 changes: 41 additions & 19 deletions paddlenlp/transformers/codegen/modeling.py
@@ -18,6 +18,7 @@
import paddle.nn.functional as F
from paddle.nn import Layer

from ...utils.log import logger
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -415,14 +416,15 @@ def forward(
token_type_ids=None,
use_cache=False,
cache=None,
inputs_embeds=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
r"""
The CodeGenModel forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
input_ids (Tensor, optional):
Indices of input sequence tokens in the vocabulary. They are
numerical representations of tokens that build the input sequence.
Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
@@ -445,6 +447,11 @@ def forward(
See `TransformerDecoder.gen_cache <https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/nn/layer/transformer.py#L1060>`__ for more details.
It is only used for inference and should be None for training.
Default to `None`.
inputs_embeds (Tensor, optional):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
Defaults to `None`.
output_attentions (bool, optional):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. Defaults to `False`.
@@ -473,12 +480,17 @@ def forward(
output = model(**inputs)
"""

if input_ids is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.shape
input_ids = input_ids.reshape(shape=(-1, input_shape[-1]))
input_ids = input_ids.reshape((-1, input_shape[-1]))
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.shape[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify input_ids")
raise ValueError("You have to specify either input_ids or inputs_embeds")

if cache is None:
past_length = 0
@@ -488,30 +500,36 @@

# Attention mask.
if attention_mask is None:
assert input_ids is not None, "input_ids should be " "specified when generating attention_mask"
if batch_size == 1 and past_length != 0:
batch_size, seq_len = input_shape
attention_mask = paddle.zeros(
[batch_size, 1, 1, seq_len + past_length], dtype=paddle.get_default_dtype()
)
if input_ids is not None:
if batch_size == 1 and past_length != 0:
batch_size, seq_len = input_shape
attention_mask = paddle.zeros(
[batch_size, 1, 1, seq_len + past_length], dtype=paddle.get_default_dtype()
)
else:
attention_mask = (
paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2])
* -1e4
)
else:
attention_mask = (
paddle.cast(input_ids == self.pad_token_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2])
* -1e4
logger.warning(
"Provided inputs_embeds while attention_mask is None, attention weights will not be masked during forwarding."
)
# For 2D attention_mask from tokenizer
elif attention_mask.ndim == 2:
attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask.stop_gradient = True
if attention_mask is not None:
attention_mask.stop_gradient = True
# TODO: the CodeGen attention mask handling is too confusing.
# When it's 2D, it must be int and is denoted by 1/0.
# When using model.generate() without providing an attention mask,
# or when using a 4D attention mask,
# its dtype must be float and it is denoted by 0/-inf.
# Moreover, 3D attention masks are not supported.

inputs_embeds = self.wte(input_ids)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
inputs_embeds = inputs_embeds + token_type_embeds
@@ -641,14 +659,15 @@ def forward(
use_cache=False,
cache=None,
labels=None,
inputs_embeds=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
r"""
The CodeGenForCausalLM forward method, overrides the __call__() special method.
Args:
input_ids (Tensor):
input_ids (Tensor, optional):
See :class:`CodeGenModel`.
attention_mask (Tensor, optional):
See :class:`CodeGenModel`.
@@ -660,12 +679,14 @@
Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., vocab_size]`.
inputs_embeds (Tensor, optional):
See :class:`CodeGenModel`.
output_attentions (bool, optional):
See :class: `CodeGenModel`
See :class: `CodeGenModel`.
output_hidden_states (bool, optional):
See :class: `CodeGenModel`
See :class: `CodeGenModel`.
return_dict (bool, optional):
See :class: `CodeGenModel`
See :class: `CodeGenModel`.
Returns:
An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions` if
`return_dict=True`. Otherwise it returns a tuple of tensors corresponding
@@ -691,6 +712,7 @@ def forward(
token_type_ids=token_type_ids,
use_cache=use_cache,
cache=cache,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
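
Note: a minimal usage sketch of the CodeGen change above. The checkpoint name is illustrative, and `get_input_embeddings()` is assumed to expose the `wte` embedding used in the diff; the keyword arguments (`inputs_embeds`, `attention_mask`, `return_dict`) are the ones shown in the diff.

```python
# Minimal usage sketch of the inputs_embeds path added above.
# Assumptions: the checkpoint name is illustrative, and get_input_embeddings()
# returns the model's wte embedding layer.
import paddle
from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
model.eval()

encoded = tokenizer("def hello_world():", return_tensors="pd", return_attention_mask=True)
inputs_embeds = model.get_input_embeddings()(encoded["input_ids"])  # assumed accessor

with paddle.no_grad():
    # Passing both input_ids and inputs_embeds raises the new ValueError;
    # supplying attention_mask avoids the new "will not be masked" warning.
    out = model(
        inputs_embeds=inputs_embeds,
        attention_mask=encoded["attention_mask"],
        return_dict=True,
    )
```

When only `inputs_embeds` is given and no `attention_mask` is supplied, the model now logs a warning and runs unmasked, which is why the mask is passed explicitly in this sketch.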