@@ -114,8 +114,9 @@ def _compute_cos_sin_cache(self) -> paddle.Tensor:
114
114
inv_freq = self ._compute_inv_freq (self .scaling_factor )
115
115
t = paddle .arange (self .max_position_embeddings * self .scaling_factor , dtype = paddle .float32 )
116
116
freqs = paddle .einsum ("i,j -> ij" , t , inv_freq )
117
- cos = freqs .cos () * self .mscale
118
- sin = freqs .sin () * self .mscale
117
+ emb = paddle .concat ((freqs , freqs ), axis = - 1 )
118
+ cos = emb .cos () * self .mscale
119
+ sin = emb .sin () * self .mscale
119
120
cache = paddle .concat ((cos , sin ), axis = - 1 )
120
121
return cache
121
122
@@ -125,28 +126,28 @@ def forward(
125
126
query : paddle .Tensor ,
126
127
key : paddle .Tensor ,
127
128
) -> Tuple [paddle .Tensor , paddle .Tensor ]:
128
- query_rot = query [..., : self .rotary_dim ]
129
- key_rot = key [..., : self .rotary_dim ]
129
+ q = query [..., : self .rotary_dim ]
130
+ k = key [..., : self .rotary_dim ]
130
131
if self .rotary_dim < self .head_size :
131
132
query_pass = query [..., self .rotary_dim :]
132
133
key_pass = key [..., self .rotary_dim :]
133
-
134
- cos_sin = self .cos_sin_cache [position_ids ]
134
+ cos_sin = self .cos_sin_cache [position_ids ].unsqueeze (1 )
135
135
cos , sin = cos_sin .chunk (2 , axis = - 1 )
136
136
137
- cos = cos . repeat_interleave ( 2 , axis = - 1 ). unsqueeze ( - 2 )
138
- sin = sin . repeat_interleave ( 2 , axis = - 1 ). unsqueeze ( - 2 )
137
+ s , h , d = q . shape
138
+ q = q . reshape ([ s , h , d // 2 , 2 ]). transpose ([ 0 , 1 , 3 , 2 ]). reshape ([ s , h , d ] )
139
139
140
- def _rotate_gptj (x : paddle .Tensor ) -> paddle .Tensor :
141
- x1 = x [..., ::2 ]
142
- x2 = x [..., 1 ::2 ]
143
- x = paddle .stack ((- x2 , x1 ), axis = - 1 )
144
- return x .flatten (- 2 )
140
+ s , h , d = k .shape
141
+ k = k .reshape ([s , h , d // 2 , 2 ]).transpose ([0 , 1 , 3 , 2 ]).reshape ([s , h , d ])
145
142
146
- rotate_fn = _rotate_gptj
143
+ def rotate_half (x ):
144
+ """Rotates half the hidden dims of the input."""
145
+ x1 = x [..., : x .shape [- 1 ] // 2 ]
146
+ x2 = x [..., x .shape [- 1 ] // 2 :]
147
+ return paddle .concat ([- x2 , x1 ], axis = - 1 ) # shape is the same as x
147
148
148
- query_rot = query_rot * cos + rotate_fn ( query_rot ) * sin
149
- key_rot = key_rot * cos + rotate_fn ( key_rot ) * sin
149
+ query_rot = ( q * cos ) + ( rotate_half ( q ) * sin )
150
+ key_rot = ( k * cos ) + ( rotate_half ( k ) * sin )
150
151
151
152
if self .rotary_dim < self .head_size :
152
153
query = paddle .concat ((query_rot , query_pass ), axis = - 1 )
0 commit comments