Commit 3ca3547

Add embedding inputs to T5 model (#3668)

1 parent b63f356

File tree

3 files changed (+103, -5 lines)


paddlenlp/transformers/t5/modeling.py

Lines changed: 59 additions & 5 deletions
@@ -949,16 +949,32 @@ def forward(self,
                 attention_mask=None,
                 encoder_hidden_states=None,
                 encoder_attention_mask=None,
+                inputs_embeds=None,
                 cache=None,
                 use_cache=False,
                 output_attentions=False,
                 output_hidden_states=False,
                 return_dict=False):
-        assert input_ids is not None, "input_ids can not be None"
-        input_shape = input_ids.shape
-        input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
 
-        inputs_embeds = self.embed_tokens(input_ids)
+        if input_ids is not None and inputs_embeds is not None:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.shape
+            input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+        else:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
+            )
+
+        if inputs_embeds is None:
+            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_shape
 

@@ -1309,6 +1325,8 @@ def forward(self,
                 decoder_attention_mask=None,
                 encoder_output=None,
                 cache=None,
+                inputs_embeds=None,
+                decoder_inputs_embeds=None,
                 use_cache=True,
                 output_attentions=False,
                 output_hidden_states=False,

@@ -1352,6 +1370,20 @@ def forward(self,
                 The `input_ids` which have their past given to this model should not be
                 passed as input ids as they have already been computed.
                 Defaults to `None`.
+            inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
+                of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
+                how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+                Default to None.
+            decoder_inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+                representation of shape `(batch_size, target_sequence_length, hidden_size)`. If `cache` is used,
+                optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`).
+                This is useful if you want more control over how to convert `decoder_input_ids` indices
+                into associated vectors than the model's internal embedding lookup matrix. Default to None.
+
+                If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+                of `inputs_embeds`.
             use_cache (bool, optional):
                 Whether or not to use cache. If set to `True`, `past_buckets_states` states are returned
                 and can be used to speed up decoding.

@@ -1445,6 +1477,7 @@ def forward(self,
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
+           inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)

@@ -1456,6 +1489,7 @@ def forward(self,
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
+           inputs_embeds=decoder_inputs_embeds,
            cache=cache,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,

@@ -1530,6 +1564,8 @@ def forward(self,
                encoder_output=None,
                cache=None,
                labels=None,
+               inputs_embeds=None,
+               decoder_inputs_embeds=None,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,

@@ -1555,6 +1591,20 @@ def forward(self,
                selected in `[-100, 0, ..., vocab_size]` All labels set to `-100` are
                ignored (masked), the loss is only computed for labels in `[0, ..., vocab_size]`.
                Shape is [batch_size, sequence_length] and dtype is int64.
+            inputs_embeds (Tensor, optional):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation
+                of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over
+                how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+                Default to None.
+            decoder_inputs_embeds (Tensor , optional):
+                Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+                representation of shape `(batch_size, target_sequence_length, hidden_size)`. If `past_key_values` is used,
+                optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). This is useful
+                if you want more control over how to convert `decoder_input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix. Default to None.
+
+                If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+                of `inputs_embeds`.
            use_cache (bool, optional):
                See :class:`T5Model`.
            output_attentions (bool, optional):

@@ -1630,6 +1680,7 @@ def forward(self,
        encoder_output = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
+           inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)

@@ -1641,7 +1692,7 @@ def forward(self,
 
        hidden_states = encoder_output[0]
 
-       if labels is not None and decoder_input_ids is None:
+       if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)
 

@@ -1658,6 +1709,7 @@ def forward(self,
        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
+           inputs_embeds=decoder_inputs_embeds,
            cache=cache,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,

@@ -1870,6 +1922,7 @@ def forward(
        encoder_hidden_states: Optional[Tuple[Tensor]] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        cache=None,
+       inputs_embeds: Optional[Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,

@@ -1878,6 +1931,7 @@ def forward(
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
+           inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            cache=cache,
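
For context, a minimal usage sketch of the new arguments on T5ForConditionalGeneration. The "t5-small" checkpoint name and the example sentences are illustrative assumptions, not part of this change; the call pattern follows the forward signature added above, and the exact return layout may differ by version.

import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

# Illustrative checkpoint; any T5 checkpoint exposes the same interface.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

src = tokenizer("translate English to German: How are you?")
tgt = tokenizer("Wie geht es dir?")
input_ids = paddle.to_tensor([src["input_ids"]])
labels = paddle.to_tensor([tgt["input_ids"]])

# Look up the encoder embeddings manually and pass them instead of token ids.
embedding = model.get_input_embeddings()
inputs_embeds = embedding(input_ids)

with paddle.no_grad():
    # With labels set and neither decoder_input_ids nor decoder_inputs_embeds
    # given, decoder inputs are still derived by shifting the labels right.
    output = model(inputs_embeds=inputs_embeds, labels=labels)
    loss = output[0]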

tests/transformers/t5/test_modeling.py

Lines changed: 1 addition & 0 deletions
@@ -537,6 +537,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     test_pruning = False
     test_resize_embeddings = True
     test_model_parallel = True
+    use_test_inputs_embeds = True
     is_encoder_decoder = True
     # The small T5 model needs higher percentages for CPU/MP tests
     model_split_percents = [0.8, 0.9]

tests/transformers/test_modeling_common.py

Lines changed: 43 additions & 0 deletions
@@ -67,6 +67,7 @@ class ModelTesterMixin:
     test_resize_position_embeddings = False
     test_mismatched_shapes = True
     test_missing_keys = True
+    use_test_inputs_embeds = False
     use_test_model_name_list = True
     is_encoder_decoder = False
     has_attentions = True

@@ -508,6 +509,48 @@ def test_resize_tokens_embeddings(self):
 
         self.assertTrue(models_equal)
 
+    def test_inputs_embeds(self):
+        # pass the test if don't need to test inputs embeddings
+        if not self.use_test_inputs_embeds:
+            return
+        # get config for model and inputs_dict for model forward
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
+        )
+        # test all model classes
+        for model_class in self.all_model_classes:
+            model = self._make_model_instance(config, model_class)
+            model.eval()
+
+            inputs = copy.deepcopy(
+                self._prepare_for_class(inputs_dict, model_class))
+
+            with paddle.no_grad():
+                ids_output = model(**inputs)
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids",
+                                               encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+
+            wte = model.get_input_embeddings()
+            if not self.is_encoder_decoder:
+                inputs["inputs_embeds"] = wte(input_ids)
+            else:
+                inputs["inputs_embeds"] = wte(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+            with paddle.no_grad():
+                embeds_output = model(**inputs)
+
+            self.assertTrue(
+                paddle.allclose(ids_output, embeds_output, rtol=1e-4,
+                                atol=1e-4))
+
     def test_model_name_list(self):
         if not self.use_test_model_name_list:
             return
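
Outside the test harness, the equivalence that test_inputs_embeds asserts can be reproduced by hand. A minimal sketch, assuming a locally available "t5-small" checkpoint (the checkpoint name, sample text, and tolerances are illustrative):

import paddle
from paddlenlp.transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
model.eval()

encoded = tokenizer("Welcome to use PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])
decoder_input_ids = input_ids  # reuse the same ids as toy decoder inputs

wte = model.get_input_embeddings()

with paddle.no_grad():
    # Same batch, fed once as token ids and once as pre-computed embeddings.
    ids_out = model(input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids)[0]
    embeds_out = model(inputs_embeds=wte(input_ids),
                       decoder_inputs_embeds=wte(decoder_input_ids))[0]

# Both paths should produce (numerically) the same decoder hidden states.
print(paddle.allclose(ids_out, embeds_out, rtol=1e-4, atol=1e-4))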
