@@ -949,16 +949,32 @@ def forward(self,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
+               inputs_embeds=None,
                cache=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False):
-        assert input_ids is not None, "input_ids can not be None"
-        input_shape = input_ids.shape
-        input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])

-        inputs_embeds = self.embed_tokens(input_ids)
+        if input_ids is not None and inputs_embeds is not None:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.shape
+            input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+        else:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
+            )
+
+        if inputs_embeds is None:
+            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

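Both branches of the new dispatch leave `input_shape` as `(batch_size, seq_length)`, which the unchanged `batch_size, seq_length = input_shape` line relies on: the `input_ids` path keeps the full shape, while the `inputs_embeds` path drops the trailing hidden dimension. A minimal shape-contract sketch (illustration only, not part of the diff; 512 is just t5-small's `d_model`):

```python
import paddle

# Token-id path: shape is already (batch_size, seq_length).
input_ids = paddle.randint(0, 32100, shape=[2, 7])
print(input_ids.shape)           # [2, 7]

# Embedding path: drop the trailing hidden dimension to get the same contract.
inputs_embeds = paddle.randn([2, 7, 512])
print(inputs_embeds.shape[:-1])  # [2, 7]
```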
@@ -1309,6 +1325,8 @@ def forward(self,
                decoder_attention_mask=None,
                encoder_output=None,
                cache=None,
+               inputs_embeds=None,
+               decoder_inputs_embeds=None,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
@@ -1352,6 +1370,20 @@ def forward(self,
                The `input_ids` which have their past given to this model should not be
                passed as input ids as they have already been computed.
                Defaults to `None`.
+            inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
+                of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
+                how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+                Defaults to `None`.
+            decoder_inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+                representation of shape `(batch_size, target_sequence_length, hidden_size)`. If `cache` is used,
+                optionally only the last `decoder_inputs_embeds` have to be input.
+                This is useful if you want more control over how to convert `decoder_input_ids` indices
+                into associated vectors than the model's internal embedding lookup matrix. Defaults to `None`.
+
+                If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+                of `inputs_embeds`.
            use_cache (bool, optional):
                Whether or not to use cache. If set to `True`, `past_buckets_states` states are returned
                and can be used to speed up decoding.
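A hedged usage sketch for the new `T5Model` arguments documented above, assuming the usual `paddlenlp.transformers` entry points and the `t5-small` checkpoint (`d_model` = 512); the random tensors stand in for embeddings you would normally compute yourself:

```python
import paddle
from paddlenlp.transformers import T5Model

model = T5Model.from_pretrained("t5-small")
model.eval()

# Pre-computed embeddings instead of token ids, shape [batch, seq_len, d_model].
encoder_embeds = paddle.randn([2, 8, 512])
decoder_embeds = paddle.randn([2, 4, 512])

with paddle.no_grad():
    outputs = model(inputs_embeds=encoder_embeds,
                    decoder_inputs_embeds=decoder_embeds)

# With the default return_dict=False, the first element is expected to be the
# decoder's last hidden state, shape [2, 4, 512].
print(outputs[0].shape)
```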
@@ -1445,6 +1477,7 @@ def forward(self,
            encoder_output = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
+               inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)
@@ -1456,6 +1489,7 @@ def forward(self,
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
+           inputs_embeds=decoder_inputs_embeds,
            cache=cache,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
@@ -1530,6 +1564,8 @@ def forward(self,
                encoder_output=None,
                cache=None,
                labels=None,
+               inputs_embeds=None,
+               decoder_inputs_embeds=None,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
@@ -1555,6 +1591,20 @@ def forward(self,
                selected in `[-100, 0, ..., vocab_size]` All labels set to `-100` are
                ignored (masked), the loss is only computed for labels in `[0, ..., vocab_size]`.
                Shape is [batch_size, sequence_length] and dtype is int64.
+            inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
+                of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
+                how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+                Defaults to `None`.
+            decoder_inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+                representation of shape `(batch_size, target_sequence_length, hidden_size)`. If `cache` is used,
+                optionally only the last `decoder_inputs_embeds` have to be input. This is useful
+                if you want more control over how to convert `decoder_input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix. Defaults to `None`.
+
+                If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+                of `inputs_embeds`.
            use_cache (bool, optional):
                See :class:`T5Model`.
            output_attentions (bool, optional):
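Similarly for `T5ForConditionalGeneration`: combined with the relaxed condition in the `_shift_right` hunk further below, passing only `inputs_embeds` and `labels` should still let the model derive `decoder_input_ids` from the labels and return a loss. A sketch under the same assumptions as above (paddlenlp `t5-small`, `d_model` = 512, random stand-in embeddings):

```python
import paddle
from paddlenlp.transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

inputs_embeds = paddle.randn([2, 8, 512])        # [batch, src_len, d_model]
labels = paddle.randint(0, 32100, shape=[2, 5])  # target token ids

outputs = model(inputs_embeds=inputs_embeds, labels=labels)
# With labels given (and return_dict=False), the first element should be the loss.
print(float(outputs[0]))
```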
@@ -1630,6 +1680,7 @@ def forward(self,
            encoder_output = self.t5.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
+               inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)
@@ -1641,7 +1692,7 @@ def forward(self,

        hidden_states = encoder_output[0]

-        if labels is not None and decoder_input_ids is None:
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

@@ -1658,6 +1709,7 @@ def forward(self,
        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
+           inputs_embeds=decoder_inputs_embeds,
            cache=cache,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
@@ -1870,6 +1922,7 @@ def forward(
        encoder_hidden_states: Optional[Tuple[Tensor]] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        cache=None,
+       inputs_embeds: Optional[Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
@@ -1878,6 +1931,7 @@ def forward(
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
+           inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            cache=cache,
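The last two hunks touch the encoder-only wrapper. If this diff targets PaddleNLP's `t5/modeling.py`, that class is `T5EncoderModel`; treat both the class name and the checkpoint below as assumptions. The same pattern applies:

```python
import paddle
from paddlenlp.transformers import T5EncoderModel  # assumed export name

encoder = T5EncoderModel.from_pretrained("t5-small")
encoder.eval()

embeds = paddle.randn([2, 16, 512])          # [batch, seq_len, d_model]
with paddle.no_grad():
    out = encoder(inputs_embeds=embeds)      # instead of input_ids
```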