
Commit 2a61982

Remove caching logic for local & tglobal attention
1 parent 62d4462 · commit 2a61982

File tree

2 files changed (+33, -135 lines)

src/transformers/models/longt5/modeling_flax_longt5.py

Lines changed: 9 additions & 16 deletions
@@ -16,7 +16,7 @@
 
 
 import copy
-from typing import Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 
 
 import numpy as np
 
@@ -730,10 +730,8 @@ def __call__(
         attention_mask=None,
         key_value_states=None,
         position_bias=None,
-        use_cache=False,
         output_attentions=False,
         deterministic=True,
-        init_cache=False,
     ):
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
@@ -867,13 +865,12 @@ def setup(self):
             embedding_init=jax.nn.initializers.normal(kv_init_std),
         )
 
-        # Relativen attention bias & Layer norm for global attention
-        if self.has_relative_attention_bias:
-            self.global_relative_attention_bias = nn.Embed(
-                self.relative_attention_num_buckets,
-                self.n_heads,
-                embedding_init=jax.nn.initializers.normal(kv_init_std),
-            )
+        # Relative attention bias & Layer norm for global attention - global relative attention bias is always applied
+        self.global_relative_attention_bias = nn.Embed(
+            self.relative_attention_num_buckets,
+            self.n_heads,
+            embedding_init=jax.nn.initializers.normal(kv_init_std),
+        )
         self.global_input_layer_norm = FlaxLongT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -980,10 +977,8 @@ def __call__(
         attention_mask=None,
         key_value_states=None,
         position_bias=None,
-        use_cache=False,
         output_attentions=False,
         deterministic=True,
-        init_cache=False,
     ):
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
@@ -1127,7 +1122,7 @@ def __call__(
         position_bias=None,
         output_attentions=False,
         deterministic=True,
-        init_cache=False,
+        **kwargs: Any,  # to accept init_cache kwargs
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.LocalSelfAttention(
@@ -1136,7 +1131,6 @@ def __call__(
             position_bias=position_bias,
             output_attentions=output_attentions,
             deterministic=deterministic,
-            init_cache=init_cache,
         )
         hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
@@ -1166,7 +1160,7 @@ def __call__(
         position_bias=None,
         output_attentions=False,
         deterministic=True,
-        init_cache=False,
+        **kwargs: Any,  # to accept init_cache kwargs
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.TransientGlobalSelfAttention(
@@ -1175,7 +1169,6 @@ def __call__(
             position_bias=position_bias,
             output_attentions=output_attentions,
             deterministic=deterministic,
-            init_cache=init_cache,
        )
         hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
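Why the added **kwargs: Any parameter: the stack calls every layer with the same set of keyword arguments, so a layer that simply dropped init_cache from its signature would raise a TypeError whenever a caller still passes it. Absorbing the leftover keyword keeps such callers working while making it explicit that the value is ignored. A minimal standalone sketch of the pattern, assuming a hypothetical LocalLayer class rather than the actual transformers module:

from typing import Any


class LocalLayer:
    # Hypothetical stand-in for a layer like FlaxLongT5LayerLocalSelfAttention:
    # caching-related keywords such as init_cache land in **kwargs and are never used.
    def __call__(self, hidden_states, position_bias=None, output_attentions=False, **kwargs: Any):
        ignored = kwargs  # e.g. {"init_cache": False} passed by a generic caller
        return (hidden_states,)


layer = LocalLayer()
outputs = layer([[0.0, 1.0]], init_cache=False)  # no TypeError: init_cache is silently absorbed

The same trick is used in modeling_longt5.py below, where the swallowed keywords are past_key_value and use_cache.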

src/transformers/models/longt5/modeling_longt5.py

Lines changed: 24 additions & 119 deletions
@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -656,22 +656,11 @@ def forward(
         self,
         hidden_states,
         mask=None,
-        key_value_states=None,
         position_bias=None,
-        past_key_value=None,
         layer_head_mask=None,
-        query_length=None,
-        use_cache=False,
         output_attentions=False,
     ):
         batch_size, seq_length = hidden_states.shape[:2]
-        real_seq_length = seq_length
-
-        if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
-            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         def shape(states):
             """projection"""
@@ -681,37 +670,10 @@ def unshape(states):
             """reshape"""
             return states.contiguous().view(batch_size, -1, self.inner_dim)
 
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states))
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-
-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, seq_length, n_heads, dim_per_head)
-                    hidden_states = torch.cat([past_key_value.transpose(1, 2), hidden_states], dim=2)
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value.transpose(1, 2)
-            return hidden_states
-
-        # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
+        # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
         query_states = shape(self.q(hidden_states))
-
-        # get key/value states
-        key_states = project(
-            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
-        )
-        value_states = project(
-            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
-        )
+        key_states = shape(self.k(hidden_states))
+        value_states = shape(self.v(hidden_states))
 
         # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
         query_states = _split_into_blocks(query_states, self.block_len, dim=1)
@@ -722,10 +684,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
         value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
 
-        # Compute scores
-        scores = torch.einsum(
-            "...qhd,...khd->...hqk", query_states, key_states
-        )  # (batch_size, num_block, n_heads, block_len, 3 * block_len)
+        # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len)
+        scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
 
         if position_bias is None:
             # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
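With the caching path removed, the local-attention forward reduces to: project queries/keys/values, split the sequence into fixed-size blocks, give each query block access to its own key block plus one neighbouring block on each side, and score everything with a single einsum. A toy sketch of those shapes, using simplified helpers (the real _split_into_blocks and _concatenate_3_blocks also handle padding of partial blocks and masking of out-of-range positions, which is omitted here):

import torch

batch, seq_len, n_heads, d_head, block_len = 2, 8, 4, 16, 4


def split_into_blocks(states, block_len):
    # (batch, seq_len, n_heads, d_head) -> (batch, num_blocks, block_len, n_heads, d_head)
    b, s, h, d = states.shape
    return states.view(b, s // block_len, block_len, h, d)


def concatenate_3_blocks(states):
    # pad the block axis with one all-zero block on each side, then stack
    # [previous, current, next] -> (batch, num_blocks, 3 * block_len, n_heads, d_head)
    num_blocks = states.shape[1]
    padded = torch.nn.functional.pad(states, (0, 0, 0, 0, 0, 0, 1, 1))
    return torch.cat([padded[:, i : i + num_blocks] for i in range(3)], dim=2)


query_states = split_into_blocks(torch.randn(batch, seq_len, n_heads, d_head), block_len)
key_states = concatenate_3_blocks(split_into_blocks(torch.randn(batch, seq_len, n_heads, d_head), block_len))

# (batch, num_blocks, n_heads, block_len, 3 * block_len), matching the new comment above
scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
print(scores.shape)  # torch.Size([2, 2, 4, 4, 12])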
@@ -737,10 +697,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(self.block_len)
-                # if key and values are already calculated
-                # we want only the last query position bias
-                if past_key_value is not None:
-                    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
             if mask is not None:
                 # Replace masked positions with -10_000 (according to the original implementation)
@@ -762,8 +718,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         attn_output = attn_output[:, :seq_length, :]
         attn_output = self.o(attn_output)
 
-        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
-        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+        outputs = (attn_output,) + (position_bias,)
 
         if output_attentions:
             outputs = outputs + (attn_weights,)
@@ -797,9 +752,8 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal
         self.pruned_heads = set()
         self.gradient_checkpointing = False
 
-        # Relativen attention bias & Layer norm for global attention
-        if self.has_relative_attention_bias:
-            self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        # Relative attention bias & Layer norm for global attention - global relative attention bias is always applied
+        self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
@@ -879,7 +833,7 @@ def compute_bias(self, block_length: int):
         # (block_length, 3 * block_length)
         relative_position = memory_position - context_position
         relative_position_bucket = self._relative_position_bucket(
-            relative_position,  # (block_length, 3 * block_length)
+            relative_position,
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
@@ -915,22 +869,11 @@ def forward(
         self,
         hidden_states,
         mask=None,
-        key_value_states=None,
         position_bias=None,
-        past_key_value=None,
         layer_head_mask=None,
-        query_length=None,
-        use_cache=False,
         output_attentions=False,
     ):
         batch_size, seq_length = hidden_states.shape[:2]
-        real_seq_length = seq_length
-
-        if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
-            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         def shape(states):
             """projection"""
@@ -940,27 +883,6 @@ def unshape(states):
             """reshape"""
             return states.contiguous().view(batch_size, -1, self.inner_dim)
 
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states))
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-
-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, seq_length, n_heads, dim_per_head)
-                    hidden_states = torch.cat([past_key_value.transpose(1, 2), hidden_states], dim=2)
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value.transpose(1, 2)
-            return hidden_states
-
         # Prepare components for transient-global attention
         # Obtain block_ids and global_segment_ids
         # global_seq_len := seq_len // self.global_block_size
@@ -974,20 +896,14 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
         global_inputs = self.global_input_layer_norm(global_inputs)
 
-        # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
+        # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
         query_states = shape(self.q(hidden_states))
-
-        # get key/value states
-        key_states = project(
-            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
-        )
-        value_states = project(
-            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
-        )
+        key_states = shape(self.k(hidden_states))
+        value_states = shape(self.v(hidden_states))
 
         # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
-        side_key_states = project(global_inputs, self.k, None, None)
-        side_value_states = project(global_inputs, self.v, None, None)
+        side_key_states = shape(self.k(global_inputs))
+        side_value_states = shape(self.v(global_inputs))
 
         # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
         query_states = _split_into_blocks(query_states, self.block_len, dim=1)
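The side (global) key/value states are now built the same way as the local ones: aggregate the hidden states per global block, normalize the aggregates, and push them through the same k/v projections. A toy sketch under simplified assumptions; a per-block mean stands in for the real _create_global_aggregates (which sums via block_ids and handles orphan tokens), and nn.LayerNorm stands in for LongT5LayerNorm:

import torch
from torch import nn

batch, seq_len, d_model, n_heads, d_head, global_block_size = 2, 8, 32, 4, 8, 4

hidden_states = torch.randn(batch, seq_len, d_model)
k_proj = nn.Linear(d_model, n_heads * d_head, bias=False)  # stands in for self.k
global_input_layer_norm = nn.LayerNorm(d_model)            # stands in for LongT5LayerNorm

# Aggregate each block of global_block_size tokens into one side token:
# (batch, global_seq_len, d_model) with global_seq_len = seq_len // global_block_size
global_inputs = hidden_states.view(batch, -1, global_block_size, d_model).mean(dim=2)
global_inputs = global_input_layer_norm(global_inputs)

# Same projection + reshape as the local tokens -> (batch, global_seq_len, n_heads, d_head)
side_key_states = k_proj(global_inputs).view(batch, -1, n_heads, d_head)
print(side_key_states.shape)  # torch.Size([2, 2, 4, 8])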
@@ -1033,10 +949,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(self.block_len)
-                # if key and values are already calculated
-                # we want only the last query position bias
-                if past_key_value is not None:
-                    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
             if local_attention_mask is not None:
                 # (batch_size, 1, n_heads, block_len, 3 * block_len)
@@ -1065,8 +977,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         attn_output = attn_output[:, :seq_length, :]
         attn_output = self.o(attn_output)
 
-        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
-        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+        outputs = (attn_output,) + (position_bias,)
 
         if output_attentions:
             outputs = outputs + (attn_weights,)
@@ -1121,18 +1032,15 @@ def forward(
         attention_mask=None,
         position_bias=None,
         layer_head_mask=None,
-        past_key_value=None,
-        use_cache=False,
         output_attentions=False,
+        **kwargs: Any,  # to accept past_key_value and use_cache kwargs
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.LocalSelfAttention(
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
             layer_head_mask=layer_head_mask,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
             output_attentions=output_attentions,
         )
         hidden_states = hidden_states + self.dropout(attention_output[0])
@@ -1157,18 +1065,15 @@ def forward(
         attention_mask=None,
         position_bias=None,
         layer_head_mask=None,
-        past_key_value=None,
-        use_cache=False,
         output_attentions=False,
+        **kwargs: Any,  # to accept past_key_value and use_cache kwargs
    ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.TransientGlobalSelfAttention(
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
             layer_head_mask=layer_head_mask,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
             output_attentions=output_attentions,
         )
         hidden_states = hidden_states + self.dropout(attention_output[0])
@@ -1402,10 +1307,8 @@ def _init_weights(self, module):
             module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
-                if isinstance(module, LongT5TransientGlobalAttention):
-                    module.global_relative_attention_bias.weight.data.normal_(
-                        mean=0.0, std=factor * ((d_model) ** -0.5)
-                    )
+            if isinstance(module, LongT5TransientGlobalAttention):
+                module.global_relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
 
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5
     def _set_gradient_checkpointing(self, module, value=False):
@@ -1644,17 +1547,19 @@ def custom_forward(*inputs):
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
             # (cross-attention position bias), (cross-attention weights)
-            position_bias = layer_outputs[2]
+            position_bias = layer_outputs[2] if self.is_decoder else layer_outputs[1]
             if self.is_decoder and encoder_hidden_states is not None:
                 encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
             # append next layer key value states
             if use_cache:
                 present_key_value_states = present_key_value_states + (present_key_value_state,)
 
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[3],)
                 if self.is_decoder:
+                    all_attentions = all_attentions + (layer_outputs[3],)
                     all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+                else:
+                    all_attentions = all_attentions + (layer_outputs[2],)
 
             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
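The index shift above follows directly from the new attention outputs: decoder layers still return a tuple that starts with (hidden_states, present_key_value, position_bias, ...), while encoder layers built on local or transient-global attention now return (hidden_states, position_bias, ...) without a present-key-value slot. An illustrative sketch with placeholder tuples (strings rather than real tensors; the decoder layout's cross-attention entries are omitted):

def pick_position_bias_and_attn(layer_outputs, is_decoder, output_attentions):
    # Mirrors the indexing in the hunk above: the encoder tuple is one slot shorter.
    position_bias = layer_outputs[2] if is_decoder else layer_outputs[1]
    attn_weights = None
    if output_attentions:
        attn_weights = layer_outputs[3] if is_decoder else layer_outputs[2]
    return position_bias, attn_weights


decoder_outputs = ("hidden", "present_kv", "pos_bias", "attn")  # layout unchanged
encoder_outputs = ("hidden", "pos_bias", "attn")                # no present_kv slot anymore

print(pick_position_bias_and_attn(decoder_outputs, True, True))   # ('pos_bias', 'attn')
print(pick_position_bias_and_attn(encoder_outputs, False, True))  # ('pos_bias', 'attn')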
