@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -656,22 +656,11 @@ def forward(
         self,
         hidden_states,
         mask=None,
-        key_value_states=None,
         position_bias=None,
-        past_key_value=None,
         layer_head_mask=None,
-        query_length=None,
-        use_cache=False,
         output_attentions=False,
     ):
         batch_size, seq_length = hidden_states.shape[:2]
-        real_seq_length = seq_length
-
-        if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
-            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         def shape(states):
             """projection"""
@@ -681,37 +670,10 @@ def unshape(states):
             """reshape"""
             return states.contiguous().view(batch_size, -1, self.inner_dim)
 
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states))
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-
-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, seq_length, n_heads, dim_per_head)
-                    hidden_states = torch.cat([past_key_value.transpose(1, 2), hidden_states], dim=2)
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value.transpose(1, 2)
-            return hidden_states
-
-        # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
+        # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
         query_states = shape(self.q(hidden_states))
-
-        # get key/value states
-        key_states = project(
-            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
-        )
-        value_states = project(
-            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
-        )
+        key_states = shape(self.k(hidden_states))
+        value_states = shape(self.v(hidden_states))
 
         # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
         query_states = _split_into_blocks(query_states, self.block_len, dim=1)
@@ -722,10 +684,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
         value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
 
-        # Compute scores
-        scores = torch.einsum(
-            "...qhd,...khd->...hqk", query_states, key_states
-        )  # (batch_size, num_block, n_heads, block_len, 3 * block_len)
+        # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len)
+        scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
 
         if position_bias is None:
             # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
@@ -737,10 +697,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(self.block_len)
-            # if key and values are already calculated
-            # we want only the last query position bias
-            if past_key_value is not None:
-                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
            if mask is not None:
                # Replace masked positions with -10_000 (according to the original implementation)
@@ -762,8 +718,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         attn_output = attn_output[:, :seq_length, :]
         attn_output = self.o(attn_output)
 
-        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
-        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+        outputs = (attn_output,) + (position_bias,)
 
         if output_attentions:
             outputs = outputs + (attn_weights,)
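The hunks above simplify how query/key/value states are produced but leave the block-local attention pattern itself untouched: states are split into fixed-size blocks and every block attends to itself plus its left and right neighbours. A standalone sketch of that tensor bookkeeping follows (shapes, the zero-padding, and variable names are illustrative only, not the library's _split_into_blocks/_concatenate_3_blocks helpers):

import torch

batch_size, seq_length, n_heads, dim_per_head = 2, 16, 4, 8
block_len = 4
num_blocks = seq_length // block_len

# (batch_size, seq_length, n_heads, dim_per_head) -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
query_states = torch.randn(batch_size, seq_length, n_heads, dim_per_head)
key_states = torch.randn(batch_size, seq_length, n_heads, dim_per_head)
query_blocks = query_states.view(batch_size, num_blocks, block_len, n_heads, dim_per_head)
key_blocks = key_states.view(batch_size, num_blocks, block_len, n_heads, dim_per_head)

# Pad one empty block on each side, then stack (previous, current, next) along the
# in-block sequence dim so each block sees 3 * block_len key positions.
padded = torch.cat([torch.zeros_like(key_blocks[:, :1]), key_blocks, torch.zeros_like(key_blocks[:, :1])], dim=1)
key_3blocks = torch.cat([padded[:, :-2], padded[:, 1:-1], padded[:, 2:]], dim=2)

# Same einsum as in the diff -> (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
scores = torch.einsum("...qhd,...khd->...hqk", query_blocks, key_3blocks)
print(scores.shape)  # torch.Size([2, 4, 4, 4, 12])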
@@ -797,9 +752,8 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal
         self.pruned_heads = set()
         self.gradient_checkpointing = False
 
-        # Relativen attention bias & Layer norm for global attention
-        if self.has_relative_attention_bias:
-            self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        # Relative attention bias & Layer norm for global attention - global relative attention bias is always applied
+        self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
@@ -879,7 +833,7 @@ def compute_bias(self, block_length: int):
         # (block_length, 3 * block_length)
         relative_position = memory_position - context_position
         relative_position_bucket = self._relative_position_bucket(
-            relative_position,  # (block_length, 3 * block_length)
+            relative_position,
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
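For context, compute_bias starts from a (block_length, 3 * block_length) grid of signed distances between the query positions of a block and the key positions of the surrounding three blocks, and only then buckets them. A rough sketch of that grid (the explicit broadcasting below is an assumption for clarity, not a copy of the library code):

import torch

block_length = 4
# Keys span the previous, current and next block; the queries sit in the middle block.
memory_position = torch.arange(3 * block_length, dtype=torch.long)
context_position = memory_position[block_length:-block_length]

# (block_length, 3 * block_length): signed distance from each query to each key position.
relative_position = memory_position[None, :] - context_position[:, None]
print(relative_position.shape)  # torch.Size([4, 12])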
@@ -915,22 +869,11 @@ def forward(
         self,
         hidden_states,
         mask=None,
-        key_value_states=None,
         position_bias=None,
-        past_key_value=None,
         layer_head_mask=None,
-        query_length=None,
-        use_cache=False,
         output_attentions=False,
     ):
         batch_size, seq_length = hidden_states.shape[:2]
-        real_seq_length = seq_length
-
-        if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
-            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         def shape(states):
             """projection"""
@@ -940,27 +883,6 @@ def unshape(states):
             """reshape"""
             return states.contiguous().view(batch_size, -1, self.inner_dim)
 
-        def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """projects hidden states correctly to key/query states"""
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(hidden_states))
-            elif past_key_value is None:
-                # cross-attn
-                # (batch_size, seq_length, n_heads, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-
-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, seq_length, n_heads, dim_per_head)
-                    hidden_states = torch.cat([past_key_value.transpose(1, 2), hidden_states], dim=2)
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value.transpose(1, 2)
-            return hidden_states
-
         # Prepare components for transient-global attention
         # Obtain block_ids and global_segment_ids
         # global_seq_len := seq_len // self.global_block_size
@@ -974,20 +896,14 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
         global_inputs = self.global_input_layer_norm(global_inputs)
 
-        # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
+        # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
         query_states = shape(self.q(hidden_states))
-
-        # get key/value states
-        key_states = project(
-            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
-        )
-        value_states = project(
-            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
-        )
+        key_states = shape(self.k(hidden_states))
+        value_states = shape(self.v(hidden_states))
 
         # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
-        side_key_states = project(global_inputs, self.k, None, None)
-        side_value_states = project(global_inputs, self.v, None, None)
+        side_key_states = shape(self.k(global_inputs))
+        side_value_states = shape(self.v(global_inputs))
 
         # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
         query_states = _split_into_blocks(query_states, self.block_len, dim=1)
@@ -1033,10 +949,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(self.block_len)
-            # if key and values are already calculated
-            # we want only the last query position bias
-            if past_key_value is not None:
-                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
            if local_attention_mask is not None:
                # (batch_size, 1, n_heads, block_len, 3 * block_len)
@@ -1065,8 +977,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         attn_output = attn_output[:, :seq_length, :]
         attn_output = self.o(attn_output)
 
-        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
-        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+        outputs = (attn_output,) + (position_bias,)
 
         if output_attentions:
             outputs = outputs + (attn_weights,)
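The transient-global branch also projects the block-level summaries (global_inputs) with the same k/v layers, which the hunk above now writes directly as shape(self.k(global_inputs)). A minimal sketch of the aggregation idea behind _create_global_aggregates, assuming sum-pooling over fixed-size blocks (names and the pooling choice are illustrative, not the exact library implementation):

import torch

batch_size, seq_length, d_model = 2, 16, 8
global_block_size = 4
global_seq_len = seq_length // global_block_size

hidden_states = torch.randn(batch_size, seq_length, d_model)
block_ids = torch.arange(seq_length) // global_block_size  # token -> global block index

# One-hot block assignment, then sum the tokens of each block into a single "global"
# vector that later serves as an extra side key/value for every query.
one_hot = torch.nn.functional.one_hot(block_ids, num_classes=global_seq_len).to(hidden_states.dtype)
global_inputs = torch.einsum("bsd,sg->bgd", hidden_states, one_hot)
print(global_inputs.shape)  # torch.Size([2, 4, 8])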
@@ -1121,18 +1032,15 @@ def forward(
         attention_mask=None,
         position_bias=None,
         layer_head_mask=None,
-        past_key_value=None,
-        use_cache=False,
         output_attentions=False,
+        **kwargs: Any,  # to accept past_key_value and use_cache kwargs
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.LocalSelfAttention(
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
             layer_head_mask=layer_head_mask,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
             output_attentions=output_attentions,
         )
         hidden_states = hidden_states + self.dropout(attention_output[0])
@@ -1157,18 +1065,15 @@ def forward(
         attention_mask=None,
         position_bias=None,
         layer_head_mask=None,
-        past_key_value=None,
-        use_cache=False,
         output_attentions=False,
+        **kwargs: Any,  # to accept past_key_value and use_cache kwargs
     ):
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.TransientGlobalSelfAttention(
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
             layer_head_mask=layer_head_mask,
-            past_key_value=past_key_value,
-            use_cache=use_cache,
             output_attentions=output_attentions,
         )
         hidden_states = hidden_states + self.dropout(attention_output[0])
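Both layer wrappers now absorb the arguments they no longer forward via **kwargs, so callers that still pass past_key_value or use_cache keep working unchanged. A minimal sketch of the pattern (hypothetical function, not the actual layer class):

from typing import Any


def forward(hidden_states, attention_mask=None, output_attentions=False, **kwargs: Any):
    # past_key_value and use_cache, if supplied, land in kwargs and are simply ignored.
    return hidden_states


forward("states", past_key_value=None, use_cache=False)  # accepted without a TypeError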
@@ -1402,10 +1307,8 @@ def _init_weights(self, module):
             module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
-                if isinstance(module, LongT5TransientGlobalAttention):
-                    module.global_relative_attention_bias.weight.data.normal_(
-                        mean=0.0, std=factor * ((d_model) ** -0.5)
-                    )
+            if isinstance(module, LongT5TransientGlobalAttention):
+                module.global_relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
 
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5
     def _set_gradient_checkpointing(self, module, value=False):
@@ -1644,17 +1547,19 @@ def custom_forward(*inputs):
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
             # (cross-attention position bias), (cross-attention weights)
-            position_bias = layer_outputs[2]
+            position_bias = layer_outputs[2] if self.is_decoder else layer_outputs[1]
             if self.is_decoder and encoder_hidden_states is not None:
                 encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
             # append next layer key value states
             if use_cache:
                 present_key_value_states = present_key_value_states + (present_key_value_state,)
 
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[3],)
                 if self.is_decoder:
+                    all_attentions = all_attentions + (layer_outputs[3],)
                     all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+                else:
+                    all_attentions = all_attentions + (layer_outputs[2],)
 
             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
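The final hunk adjusts the stack's tuple indexing because the simplified encoder-side layers no longer emit a present_key_value slot, while decoder layers still do (see the layer_outputs comment in the hunk). A small sketch of the resulting offsets, with illustrative tuple labels:

decoder_layer_outputs = ("hidden_states", "present_key_value", "position_bias", "attn_weights")
encoder_layer_outputs = ("hidden_states", "position_bias", "attn_weights")

for is_decoder, layer_outputs in ((True, decoder_layer_outputs), (False, encoder_layer_outputs)):
    position_bias = layer_outputs[2] if is_decoder else layer_outputs[1]
    attn_weights = layer_outputs[3] if is_decoder else layer_outputs[2]
    assert (position_bias, attn_weights) == ("position_bias", "attn_weights")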