paddlenlp/transformers/codegen/modeling.py
41 additions & 19 deletions
@@ -18,6 +18,7 @@
 import paddle.nn.functional as F
 from paddle.nn import Layer
 
+from ...utils.log import logger
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -415,14 +416,15 @@ def forward(
         token_type_ids=None,
         use_cache=False,
         cache=None,
+        inputs_embeds=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=False,
     ):
         r"""
         The CodeGenModel forward method, overrides the `__call__()` special method.
         Args:
-            input_ids (Tensor):
+            input_ids (Tensor, optional):
                 Indices of input sequence tokens in the vocabulary. They are
                 numerical representations of tokens that build the input sequence.
                 Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
@@ -445,6 +447,11 @@ def forward(
                 See `TransformerDecoder.gen_cache <https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/nn/layer/transformer.py#L1060>`__ for more details.
                 It is only used for inference and should be None for training.
                 Default to `None`.
+            inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
+                of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
+                how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+                Default to None.
             output_attentions (bool, optional):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
                 tensors for more detail. Defaults to `False`.
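The new `inputs_embeds` argument follows the usual pattern of letting callers supply their own token embeddings. A minimal usage sketch is below; the checkpoint name and the use of `get_input_embeddings()` to fetch the word-embedding layer are illustrative assumptions, not part of this diff.

```python
import paddle
from paddlenlp.transformers import CodeGenModel, CodeGenTokenizer

# Checkpoint name is illustrative; substitute whichever CodeGen weights you use.
MODEL_NAME = "Salesforce/codegen-350M-mono"
tokenizer = CodeGenTokenizer.from_pretrained(MODEL_NAME)
model = CodeGenModel.from_pretrained(MODEL_NAME)
model.eval()

inputs = tokenizer("def hello_world():", return_tensors="pd")

# Look up the token embeddings manually (assumes get_input_embeddings()
# returns the model's word-embedding layer, as in other PaddleNLP models).
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

# Feed the embeddings instead of the ids; with this change, passing both
# input_ids and inputs_embeds raises a ValueError.
with paddle.no_grad():
    outputs = model(inputs_embeds=inputs_embeds)
```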
@@ -473,12 +480,17 @@ def forward(
                 output = model(**inputs)
         """
 
-        if input_ids is not None:
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
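The visible diff stops right after the new `ValueError`, so the rest of the guard is not shown here. As a rough sketch of the usual mutually-exclusive input pattern (the branch bodies below are assumptions, not taken from this PR):

```python
import paddle


def resolve_inputs(input_ids=None, inputs_embeds=None):
    """Illustrative guard in the style this diff introduces; only the first
    check is confirmed by the diff, the remaining branches are assumed."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        # Flatten to (batch_size, sequence_length) and remember the shape.
        input_shape = input_ids.shape
        return input_ids.reshape(shape=(-1, input_shape[-1])), input_shape
    elif inputs_embeds is not None:
        return inputs_embeds, inputs_embeds.shape[:-1]
    raise ValueError("You have to specify either input_ids or inputs_embeds")


# Passing both inputs triggers the error path added in this diff.
try:
    resolve_inputs(
        input_ids=paddle.to_tensor([[1, 2, 3]]),
        inputs_embeds=paddle.randn([1, 3, 8]),
    )
except ValueError as exc:
    print(exc)
```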