@@ -267,6 +267,77 @@ def backward(ctx, dout):
         dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
         return dx, dweight
 
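+# LinearFP8KeepXFunc quantizes x and weight to FP8 for the forward GEMM, but saves
+# the unquantized x and weight in ctx and re-quantizes them during backward.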
+class LinearFP8KeepXFunc(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight):
+        x_orig_shape = x.shape
+        # deep_gemm only supports 2D
+        x = x.reshape([-1, x_orig_shape[-1]])
+        # quant
+        x_quant, x_scale = kitchen_quant(
+            x, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
+        )
+        weight_t = weight.T.contiguous()
+        w_quant, w_scale = kitchen_quant(
+            weight_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
+        )
+
+        # compute out = mm(x, w_t)
+        out = paddle.empty([x.shape[0], weight.shape[-1]], dtype=x.dtype)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_quant, x_scale), (w_quant, w_scale), out)
+        out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
+
+        ctx.save_for_backward(x, weight)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight = ctx.saved_tensor()
+
+        # pad the last (token) dim of x_t to a multiple of 8 before quantization
+        x_t = x.T.contiguous()
+        if x_t.shape[-1] % 8 != 0:
+            x_t = paddle.concat([x_t, paddle.zeros([x_t.shape[0], 8 - (x_t.shape[-1] % 8)], dtype=x_t.dtype)], axis=-1)
+        x_t_quant, x_t_scale = kitchen_quant(
+            x_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
+        )
+
+        # compute dx = mm(dout, w)
+        dx = paddle.empty(x.shape, dout.dtype)
+        dx_orig_shape = x.shape
+
+        dout_quant, dout_scale = kitchen_quant(
+            dout.reshape([-1, dout.shape[-1]]),
+            backend=kitchen.ops.Backend.CUTLASS,
+            is_1d_scaled=True,
+            return_transpose=False,
+        )
+        w_quant, w_scale = kitchen_quant(
+            weight, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
+        )
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_quant, dout_scale), (w_quant, w_scale), dx)
+        dx = dx.reshape(dx_orig_shape)
+
+        # compute dw = mm(x_t, dout_t)
+        dout_t = dout.reshape([-1, dout.shape[-1]]).T.contiguous()
+        # pad dout_t the same way before quantization
+        if dout_t.shape[-1] % 8 != 0:
+            pad_size = 8 - (dout_t.shape[-1] % 8)
+            dout_t = paddle.concat([dout_t, paddle.zeros([dout_t.shape[0], pad_size], dtype=dout_t.dtype)], axis=-1)
+
+        dout_t_quant, dout_t_scale = kitchen_quant(
+            dout_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
+        )
+        dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
+        return dx, dweight
+
+
 
 class FP8Linear(paddle.nn.Layer):
     def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
@@ -282,6 +353,21 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
     def forward(self, x):
         return LinearFP8Func.apply(x, self.weight)
 
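+# FP8KeepXLinear: thin nn.Layer wrapper around LinearFP8KeepXFunc with a bias-free
+# bfloat16 weight of shape [in_features, out_features].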
+class FP8KeepXLinear(paddle.nn.Layer):
+    def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
+        super().__init__()
+        self._dtype = self._helper.get_default_dtype()
+
+        self.weight = self.create_parameter(
+            shape=[in_features, out_features],
+            dtype="bfloat16",
+            is_bias=False,
+        )
+
+    def forward(self, x):
+        return LinearFP8KeepXFunc.apply(x, self.weight)
+
+
 
 class Fuse_FFN_FP8_Func(paddle.autograd.PyLayer):
     @staticmethod
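A minimal usage sketch of the new layer (illustrative only; it assumes a GPU build of Paddle with the kitchen and deep_gemm FP8 kernels available, and the shapes are placeholders):

    layer = FP8KeepXLinear(in_features=4096, out_features=4096)
    x = paddle.randn([2, 128, 4096]).cast("bfloat16")  # [batch, seq, hidden], matching the bf16 weight
    x.stop_gradient = False
    out = layer(x)           # forward: FP8 quantization + deep_gemm, returns bf16 [2, 128, 4096]
    out.mean().backward()    # backward: re-quantizes x and dout to produce dx and dweight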