from ..model_utils import PretrainedModel, register_base_model
from ..nezha.modeling import ACT2FN
+ from ..model_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     Seq2SeqModelOutput,
+     Seq2SeqLMOutput,
+     BaseModelOutput,
+     ModelOutput,
+ )

__all__ = [
    'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
@@ -944,7 +951,8 @@ def forward(self,
                cache=None,
                use_cache=False,
                output_attentions=False,
-                 output_hidden_states=False):
+                 output_hidden_states=False,
+                 return_dict=False):
        assert input_ids is not None, "input_ids can not be None"
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
@@ -1051,13 +1059,22 @@ def forward(self,
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

-         return tuple(v for v in [
-             hidden_states,
-             present_key_value_states,
-             all_hidden_states,
-             all_attentions,
-             all_cross_attentions,
-         ] if v is not None)
+         if not return_dict:
+             return tuple(v for v in [
+                 hidden_states,
+                 present_key_value_states,
+                 all_hidden_states,
+                 all_attentions,
+                 all_cross_attentions,
+             ] if v is not None)
+
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=present_key_value_states,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+             cross_attentions=all_cross_attentions,
+         )

    def get_extended_attention_mask(self, attention_mask, input_shape):
        if attention_mask.ndim == 3:
@@ -1293,7 +1310,8 @@ def forward(self,
                cache=None,
                use_cache=True,
                output_attentions=False,
-                 output_hidden_states=False):
+                 output_hidden_states=False,
+                 return_dict=False):
        r"""
        The T5Model forward method, overrides the `__call__()` special method.

@@ -1343,8 +1361,16 @@ def forward(self,
            output_hidden_states (bool, optional):
                Whether or not to return the output of all hidden layers.
                Defaults to `False`.
+             return_dict (bool, optional):
+                 Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`. If `False`,
+                 the output will be a tuple of tensors. Defaults to `False`.
+

        Returns:
+             An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` if `return_dict=True`.
+             Otherwise it returns a tuple of tensors corresponding to the ordered and not-None fields (depending on
+             the input arguments) of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`.
+
            tuple: Returns tuple (`last_hidden_state`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
            `cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)

@@ -1419,8 +1445,10 @@ def forward(self,
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states)
-
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict)
+         elif return_dict and not isinstance(encoder_output, BaseModelOutput):
+             encoder_output = convert_encoder_output(encoder_output)
        hidden_states = encoder_output[0]

        # Decode
@@ -1432,9 +1460,22 @@ def forward(self,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states)
-
-         return decoder_outputs + encoder_output
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict)
+
+         if not return_dict:
+             return decoder_outputs + encoder_output
+
+         return Seq2SeqModelOutput(
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_output.last_hidden_state,
+             encoder_hidden_states=encoder_output.hidden_states,
+             encoder_attentions=encoder_output.attentions,
+         )


class T5ForConditionalGeneration(T5PretrainedModel):
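
For context, here is a minimal usage sketch of the new flag on `T5Model` (not part of this diff; the `t5-small` checkpoint, the tokenizer call, and the variable names are illustrative assumptions). With `return_dict=False` the forward keeps returning the positional tuple, while `return_dict=True` yields a `Seq2SeqModelOutput` whose fields can be read by name:

    # Hedged sketch, assuming a paddlenlp build that ships these classes and the 't5-small' weights.
    import paddle
    from paddlenlp.transformers import T5Model, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5Model.from_pretrained("t5-small")
    model.eval()

    input_ids = paddle.to_tensor([tokenizer("Welcome to use PaddleNLP!")["input_ids"]])
    decoder_input_ids = paddle.to_tensor([[tokenizer.pad_token_id]])  # T5 starts decoding from the pad token

    # Legacy tuple path (default): positional indexing, unchanged by this PR.
    last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

    # New dataclass path: named fields on Seq2SeqModelOutput.
    outputs = model(input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    return_dict=True)
    print(outputs.last_hidden_state.shape)  # [batch_size, target_length, hidden_size]
    print(type(outputs.past_key_values))    # decoder cache if use_cache is enabled, otherwise None
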
@@ -1490,7 +1531,8 @@ def forward(self,
                labels=None,
                use_cache=True,
                output_attentions=False,
-                 output_hidden_states=False):
+                 output_hidden_states=False,
+                 return_dict=False):
        r"""

        Args:
@@ -1518,8 +1560,15 @@ def forward(self,
                See :class:`T5Model`.
            output_hidden_states (bool, optional):
                See :class:`T5Model`.
+             return_dict (bool, optional):
+                 Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`. If `False`,
+                 the output will be a tuple of tensors. Defaults to `False`.

        Returns:
+             An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` if `return_dict=True`.
+             Otherwise it returns a tuple of tensors corresponding to the ordered and not-None fields (depending on
+             the input arguments) of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`.
+
            tuple: Returns tuple (`loss`, `logits`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
            `cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)

@@ -1581,12 +1630,15 @@ def forward(self,
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states)
-
-         if isinstance(encoder_output, (tuple, list)):
-             hidden_states = encoder_output[0]
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict)
        else:
-             hidden_states = encoder_output
+             if isinstance(encoder_output, paddle.Tensor):
+                 encoder_output = (encoder_output, )
+             if return_dict and not isinstance(encoder_output, BaseModelOutput):
+                 encoder_output = convert_encoder_output(encoder_output)
+
+         hidden_states = encoder_output[0]

        if labels is not None and decoder_input_ids is None:
            # get decoder inputs from shifting lm labels to the right
@@ -1610,7 +1662,8 @@ def forward(self,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states)
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict)

        sequence_output = decoder_outputs[0]

@@ -1631,11 +1684,21 @@ def forward(self,
            loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
                            labels.flatten())

-         if not isinstance(encoder_output, (list, tuple)):
-             encoder_output = (encoder_output, )
-
-         output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
-         return ((loss, ) + output) if loss is not None else output
+         if not return_dict:
+             output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
+             return ((loss, ) + output) if loss is not None else output
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_output.last_hidden_state,
+             encoder_hidden_states=encoder_output.hidden_states,
+             encoder_attentions=encoder_output.attentions,
+         )

    @staticmethod
    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
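
A similar sketch for the conditional-generation head (again not part of the diff; the checkpoint name and strings are illustrative assumptions). When `labels` is passed, the loss is computed on either path; `return_dict=True` exposes it as `Seq2SeqLMOutput.loss` instead of tuple position 0:

    # Hedged sketch, assuming the 't5-small' weights are available.
    import paddle
    from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = paddle.to_tensor([tokenizer("translate English to German: Hello")["input_ids"]])
    labels = paddle.to_tensor([tokenizer("Hallo")["input_ids"]])

    # Tuple path: (loss, logits, cache, ...) as before.
    loss, lm_logits = model(input_ids=input_ids, labels=labels)[:2]

    # Dataclass path: named fields on Seq2SeqLMOutput.
    output = model(input_ids=input_ids, labels=labels, return_dict=True)
    print(float(output.loss))     # scalar cross-entropy loss
    print(output.logits.shape)    # [batch_size, target_length, vocab_size]
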
@@ -1809,6 +1872,7 @@ def forward(
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
+         return_dict: Optional[bool] = False,
    ):
        encoder_outputs = self.encoder(
            input_ids=input_ids,
@@ -1819,9 +1883,25 @@ def forward(
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
-         )
+             return_dict=return_dict)

        return encoder_outputs


T5EncoderModel.base_model_class = T5EncoderModel
+
+
+ def convert_encoder_output(encoder_output):
+     """
+     Convert encoder_output from a tuple to a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutput`.
+
+     Args:
+         encoder_output (tuple or ModelOutput):
+             The output of the encoder, a tuple consisting of `last_hidden_state`, `hidden_states` (optional) and `attentions` (optional).
+             The data type of `last_hidden_state` is float32 and its shape is [batch_size, sequence_length, hidden_size].
+     """
+     return BaseModelOutput(
+         last_hidden_state=encoder_output[0],
+         hidden_states=encoder_output[1] if len(encoder_output) > 1 else None,
+         attentions=encoder_output[2] if len(encoder_output) > 2 else None,
+     )
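
To make the helper's behaviour concrete, a short illustration with dummy tensors (not code from this PR; it assumes `convert_encoder_output` stays module-level in `paddlenlp.transformers.t5.modeling`). A bare encoder tuple is wrapped into a `BaseModelOutput`; because `ModelOutput` subclasses allow integer indexing as well as attribute access, downstream lines such as `hidden_states = encoder_output[0]` keep working after the conversion:

    # Hedged sketch of the helper's effect, using random tensors.
    import paddle
    from paddlenlp.transformers.model_outputs import BaseModelOutput
    from paddlenlp.transformers.t5.modeling import convert_encoder_output

    last_hidden_state = paddle.randn([2, 7, 512])   # [batch_size, seq_len, hidden_size]
    wrapped = convert_encoder_output((last_hidden_state, ))

    assert isinstance(wrapped, BaseModelOutput)
    assert wrapped.hidden_states is None and wrapped.attentions is None
    # Attribute access and index access see the same tensor.
    assert bool(paddle.all(wrapped.last_hidden_state == wrapped[0]))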