Commit 985a9a4

complete t5 more output (#3370)
1 parent bc23b8b commit 985a9a4

File tree

3 files changed: +328 -72 lines changed


paddlenlp/transformers/model_outputs.py

Lines changed: 132 additions & 0 deletions
@@ -733,3 +733,135 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
     hidden_states: Optional[Tuple[paddle.Tensor]] = None
     attentions: Optional[Tuple[paddle.Tensor]] = None
     cross_attentions: Optional[Tuple[paddle.Tensor]] = None
+
+
+@dataclass
+class Seq2SeqModelOutput(ModelOutput):
+    """
+    Base class for model encoder outputs that also contain pre-computed hidden states that can speed up sequential
+    decoding.
+
+    Args:
+        last_hidden_state (`paddle.Tensor`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model, whose shape is `(batch_size, sequence_length, hidden_size)`.
+
+            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(paddle.Tensor))`, optional):
+            Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+            Returned when `use_cache=True` is passed or when `config.use_cache=True`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
+
+            Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+        decoder_attentions (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
+
+            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
+
+            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`paddle.Tensor`, optional):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model, whose shape is `(batch_size, sequence_length, hidden_size)`.
+        encoder_hidden_states (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
+
+            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+        encoder_attentions (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
+
+            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    last_hidden_state: paddle.Tensor = None
+    past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
+    decoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
+    decoder_attentions: Optional[Tuple[paddle.Tensor]] = None
+    cross_attentions: Optional[Tuple[paddle.Tensor]] = None
+    encoder_last_hidden_state: Optional[paddle.Tensor] = None
+    encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
+    encoder_attentions: Optional[Tuple[paddle.Tensor]] = None
+
+
+@dataclass
+class Seq2SeqLMOutput(ModelOutput):
+    """
+    Base class for sequence-to-sequence language model outputs.
+
+    Args:
+        loss (`paddle.Tensor`, optional):
+            Language modeling loss whose shape is `(1,)`. Returned when `labels` is provided.
+        logits (`paddle.Tensor`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax), whose shape is `(batch_size, sequence_length, config.vocab_size)`.
+        past_key_values (`tuple(tuple(paddle.Tensor))`, optional):
+            Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+            Returned when `use_cache=True` is passed or when `config.use_cache=True`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
+
+            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
+
+            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`paddle.Tensor`, optional):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model, whose shape is `(batch_size, sequence_length, hidden_size)`.
+        encoder_hidden_states (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(paddle.Tensor)`, optional):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
+
+            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: Optional[paddle.Tensor] = None
+    logits: paddle.Tensor = None
+    past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
+    decoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
+    decoder_attentions: Optional[Tuple[paddle.Tensor]] = None
+    cross_attentions: Optional[Tuple[paddle.Tensor]] = None
+    encoder_last_hidden_state: Optional[paddle.Tensor] = None
+    encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
+    encoder_attentions: Optional[Tuple[paddle.Tensor]] = None
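
Note: the two dataclasses above follow the existing ModelOutput pattern in model_outputs.py — every field is optional and defaults to None, so callers that fill only a few fields keep working. A minimal sketch of how they are meant to be consumed (the `paddle.randn` shapes are illustrative and the assumption is simply that unset fields stay `None`; none of this is part of the commit itself):

    import paddle
    from paddlenlp.transformers.model_outputs import Seq2SeqLMOutput

    # Illustrative logits with shape (batch_size, sequence_length, vocab_size).
    logits = paddle.randn([2, 5, 32128])
    output = Seq2SeqLMOutput(logits=logits)

    # Fields are read by name; anything not supplied stays None.
    assert output.logits is logits
    assert output.loss is None and output.past_key_values is None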

paddlenlp/transformers/t5/modeling.py

Lines changed: 107 additions & 27 deletions
@@ -26,6 +26,13 @@
 
 from ..model_utils import PretrainedModel, register_base_model
 from ..nezha.modeling import ACT2FN
+from ..model_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqModelOutput,
+    Seq2SeqLMOutput,
+    BaseModelOutput,
+    ModelOutput,
+)
 
 __all__ = [
     'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
@@ -944,7 +951,8 @@ def forward(self,
                 cache=None,
                 use_cache=False,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                return_dict=False):
         assert input_ids is not None, "input_ids can not be None"
         input_shape = input_ids.shape
         input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
@@ -1051,13 +1059,22 @@ def forward(self,
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states, )
 
-        return tuple(v for v in [
-            hidden_states,
-            present_key_value_states,
-            all_hidden_states,
-            all_attentions,
-            all_cross_attentions,
-        ] if v is not None)
+        if not return_dict:
+            return tuple(v for v in [
+                hidden_states,
+                present_key_value_states,
+                all_hidden_states,
+                all_attentions,
+                all_cross_attentions,
+            ] if v is not None)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=present_key_value_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
 
     def get_extended_attention_mask(self, attention_mask, input_shape):
         if attention_mask.ndim == 3:
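
Note: this stack-level change is what the rest of the commit builds on — with `return_dict=False` the old filtered tuple comes back unchanged, with `return_dict=True` the same values arrive as named fields. A rough sketch of the difference, assuming the encoder stack is reachable as `model.encoder` (that attribute name and the "t5-small" weights are assumptions about T5Model, mirroring how T5EncoderModel calls `self.encoder` further down):

    import paddle
    from paddlenlp.transformers import T5Model

    model = T5Model.from_pretrained("t5-small")      # assumed available weight name
    input_ids = paddle.to_tensor([[21, 525, 7, 1]])  # arbitrary token ids

    enc_tuple = model.encoder(input_ids=input_ids)                    # legacy: tuple
    enc_dict = model.encoder(input_ids=input_ids, return_dict=True)   # new: dataclass
    assert enc_tuple[0].shape == enc_dict.last_hidden_state.shape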
@@ -1293,7 +1310,8 @@ def forward(self,
                 cache=None,
                 use_cache=True,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                return_dict=False):
         r"""
         The T5Model forward method, overrides the `__call__()` special method.
 
@@ -1343,8 +1361,16 @@ def forward(self,
             output_hidden_states (bool, optional):
                 Whether or not to return the output of all hidden layers.
                 Defaults to `False`.
+            return_dict (bool, optional):
+                Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`. If `False`, the output
+                will be a tuple of tensors. Defaults to `False`.
+
 
         Returns:
+            An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` if `return_dict=True`.
+            Otherwise it returns a tuple of tensors corresponding to the ordered, not-None fields (depending on the input arguments) of
+            :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`.
+
             tuple: Returns tuple (`last_hidden_state`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
             `cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)
 
@@ -1419,8 +1445,10 @@ def forward(self,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states)
-
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
+        elif return_dict and not isinstance(encoder_output, BaseModelOutput):
+            encoder_output = convert_encoder_output(encoder_output)
         hidden_states = encoder_output[0]
 
         # Decode
@@ -1432,9 +1460,22 @@ def forward(self,
             encoder_attention_mask=attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states)
-
-        return decoder_outputs + encoder_output
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict)
+
+        if not return_dict:
+            return decoder_outputs + encoder_output
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_output.last_hidden_state,
+            encoder_hidden_states=encoder_output.hidden_states,
+            encoder_attentions=encoder_output.attentions,
+        )
 
 
 class T5ForConditionalGeneration(T5PretrainedModel):
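
Note: with the plumbing above, T5Model.forward serves both calling styles. A hedged usage sketch (the "t5-small" weights, the tokenizer call with `return_tensors="pd"`, and using pad id 0 as the decoder start token are conventional T5 assumptions, not something this diff introduces):

    import paddle
    from paddlenlp.transformers import T5Model, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5Model.from_pretrained("t5-small")

    input_ids = tokenizer("translate English to German: Hello", return_tensors="pd")["input_ids"]
    decoder_input_ids = paddle.to_tensor([[0]])  # T5 conventionally starts decoding from pad_token_id

    # Backwards-compatible tuple return (default).
    last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

    # New dataclass return.
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
    print(outputs.last_hidden_state.shape)           # [1, 1, hidden_size]
    print(outputs.encoder_last_hidden_state.shape)   # [1, source_length, hidden_size]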
@@ -1490,7 +1531,8 @@ def forward(self,
                 labels=None,
                 use_cache=True,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                return_dict=False):
         r"""
 
         Args:
@@ -1518,8 +1560,15 @@ def forward(self,
                 See :class:`T5Model`.
             output_hidden_states (bool, optional):
                 See :class:`T5Model`.
+            return_dict (bool, optional):
+                Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`. If `False`, the output
+                will be a tuple of tensors. Defaults to `False`.
 
         Returns:
+            An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` if `return_dict=True`.
+            Otherwise it returns a tuple of tensors corresponding to the ordered, not-None fields (depending on the input arguments) of
+            :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`.
+
             tuple: Returns tuple (`loss`, `logits`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
             `cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)
 
@@ -1581,12 +1630,15 @@ def forward(self,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states)
-
-        if isinstance(encoder_output, (tuple, list)):
-            hidden_states = encoder_output[0]
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
         else:
-            hidden_states = encoder_output
+            if isinstance(encoder_output, paddle.Tensor):
+                encoder_output = (encoder_output, )
+            if return_dict and not isinstance(encoder_output, BaseModelOutput):
+                encoder_output = convert_encoder_output(encoder_output)
+
+        hidden_states = encoder_output[0]
 
         if labels is not None and decoder_input_ids is None:
             # get decoder inputs from shifting lm labels to the right
@@ -1610,7 +1662,8 @@ def forward(self,
             encoder_attention_mask=attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states)
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict)
 
         sequence_output = decoder_outputs[0]
 
@@ -1631,11 +1684,21 @@ def forward(self,
             loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
                             labels.flatten())
 
-        if not isinstance(encoder_output, (list, tuple)):
-            encoder_output = (encoder_output, )
-
-        output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
-        return ((loss, ) + output) if loss is not None else output
+        if not return_dict:
+            output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
+            return ((loss, ) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_output.last_hidden_state,
+            encoder_hidden_states=encoder_output.hidden_states,
+            encoder_attentions=encoder_output.attentions,
+        )
 
     @staticmethod
     def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
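
Note: the same pattern applies one level up. When `labels` are supplied, the loss is the first element of the tuple or the `loss` field of the dataclass, and the label-to-decoder-input shift happens inside forward, as the context above shows. A hedged sketch of a training-style call (model and tokenizer names are assumptions, not part of the commit):

    from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = tokenizer("translate English to German: Hello", return_tensors="pd")["input_ids"]
    labels = tokenizer("Hallo", return_tensors="pd")["input_ids"]

    # Tuple style: (loss, logits, cache, ...) when labels are given.
    loss, logits = model(input_ids=input_ids, labels=labels)[:2]

    # Dataclass style: the same values behind named fields.
    out = model(input_ids=input_ids, labels=labels, return_dict=True)
    print(out.loss, out.logits.shape)  # scalar loss, [1, target_length, vocab_size]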
@@ -1809,6 +1872,7 @@ def forward(
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = False,
     ):
         encoder_outputs = self.encoder(
             input_ids=input_ids,
@@ -1819,9 +1883,25 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-        )
+            return_dict=return_dict)
 
         return encoder_outputs
 
 
 T5EncoderModel.base_model_class = T5EncoderModel
+
+
+def convert_encoder_output(encoder_output):
+    """
+    Convert `encoder_output` from a tuple to a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutput`.
+
+    Args:
+        encoder_output (tuple or ModelOutput):
+            The output of the encoder: a tuple consisting of `last_hidden_state`, `hidden_states` (optional) and `attentions` (optional).
+            The data type of `last_hidden_state` is float32 and its shape is [batch_size, sequence_length, hidden_size].
+    """
+    return BaseModelOutput(
+        last_hidden_state=encoder_output[0],
+        hidden_states=encoder_output[1] if len(encoder_output) > 1 else None,
+        attentions=encoder_output[2] if len(encoder_output) > 2 else None,
+    )
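
Note: `convert_encoder_output` is a small adapter for the case where a caller pre-computed the encoder pass and hands in a plain tuple while asking for `return_dict=True`. A quick sketch of what it produces (the tensor shape is illustrative only):

    import paddle
    from paddlenlp.transformers.t5.modeling import convert_encoder_output

    last_hidden_state = paddle.randn([1, 7, 512])  # [batch_size, sequence_length, hidden_size]
    encoder_output = convert_encoder_output((last_hidden_state, ))

    assert encoder_output.last_hidden_state is last_hidden_state
    assert encoder_output.hidden_states is None and encoder_output.attentions is None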
