Commit 0a77769

Optimize attention output linear fp8 memory (#10204)

* optimize atten impl
* optimize_attention_output_linear_fp8_memory

1 parent cf5f6a5 commit 0a77769

2 files changed: 94 additions & 6 deletions

paddlenlp/transformers/deepseek_v2/fp8_linear.py

Lines changed: 86 additions & 0 deletions
@@ -267,6 +267,77 @@ def backward(ctx, dout):
         dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
         return dx, dweight
 
+class LinearFP8KeepXFunc(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight):
+        x_orig_shape = x.shape
+        # deep_gemm only support 2D
+        x = x.reshape([-1, x_orig_shape[-1]])
+        # quant
+        x_quant, x_scale = kitchen_quant(
+            x, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
+        )
+        weight_t = weight.T.contiguous()
+        w_quant, w_scale = kitchen_quant(
+            weight_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
+        )
+
+        # compute out = mm(x, w_t)
+        out = paddle.empty([x.shape[0], weight.shape[-1]], dtype=x.dtype)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_quant, x_scale), (w_quant, w_scale), out)
+        out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
+
+
+        ctx.save_for_backward(
+            x, weight
+        )
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight= ctx.saved_tensor()
+
+        # padding
+        x_t = x.T.contiguous()
+        if x_t.shape[-1] % 8 != 0:
+            x_t = paddle.concat([x_t, paddle.zeros([x_t.shape[0], 8 - (x_t.shape[-1] % 8)], dtype=x_t.dtype)], axis=-1)
+        x_t_quant, x_t_scale = kitchen_quant(
+            x_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
+        )
+
+
+        x_t_shape = x_t_shape.numpy()
+        # compute dx = mm(dout, w)
+        dx = paddle.empty(x.shape, dout.dtype)
+        dx_orig_shape = x.shape
+
+        dout_quant, dout_scale = kitchen_quant(
+            dout.reshape([-1, dout.shape[-1]]),
+            backend=kitchen.ops.Backend.CUTLASS,
+            is_1d_scaled=True,
+            return_transpose=False,
+        )
+        w_quant, w_scale = kitchen_quant(
+            weight, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
+        )
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_quant, dout_scale), (w_quant, w_scale), dx)
+        dx = dx.reshape(dx_orig_shape)
+
+        # compute dw = mm(x_t, dout_t)
+        dout_t = dout.reshape([-1, dout.shape[-1]]).T.contiguous()
+        # padding
+        if dout_t.shape[-1] % 8 != 0:
+            pad_size = 8 - (dout_t.shape[-1] % 8)
+            dout_t = paddle.concat([dout_t, paddle.zeros([dout_t.shape[0], pad_size], dtype=dout_t.dtype)], axis=-1)
+
+        dout_t_quant, dout_t_scale = kitchen_quant(
+            dout_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
+        )
+        dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
+        return dx, dweight
+
+
+
 
 class FP8Linear(paddle.nn.Layer):
     def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
@@ -282,6 +353,21 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
     def forward(self, x):
         return LinearFP8Func.apply(x, self.weight)
 
+class FP8KeepXLinear(paddle.nn.Layer):
+    def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
+        super().__init__()
+        self._dtype = self._helper.get_default_dtype()
+
+        self.weight = self.create_parameter(
+            shape=[in_features, out_features],
+            dtype="bfloat16",
+            is_bias=False,
+        )
+
+    def forward(self, x):
+        return LinearFP8KeepXFunc.apply(x, self.weight)
+
+
 
 class Fuse_FFN_FP8_Func(paddle.autograd.PyLayer):
     @staticmethod
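
A minimal sketch of the save/recompute pattern behind LinearFP8KeepXFunc (illustration only, not part of the diff): judging by ctx.save_for_backward(x, weight), forward stashes just the bf16 activation and weight, and backward re-derives the quantized tensors it needs rather than caching FP8 copies produced in forward. The names below are hypothetical and plain matmuls stand in for the kitchen_quant / deep_gemm kernels.

import paddle

class KeepXLinearSketch(paddle.autograd.PyLayer):
    # Illustrative stand-in: plain matmuls replace the FP8 quant + GEMM kernels.
    @staticmethod
    def forward(ctx, x, weight):
        # x: [tokens, in_features] (bf16), weight: [in_features, out_features] (bf16)
        out = paddle.matmul(x, weight)    # real code: quantize x and weight, run the FP8 GEMM
        ctx.save_for_backward(x, weight)  # keep only the bf16 tensors; no quantized copies cached
        return out

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensor()
        # real code: re-quantize x (transposed, padded) and dout here, then run FP8 GEMMs
        dx = paddle.matmul(dout, weight, transpose_y=True)  # dx = dout @ W^T
        dweight = paddle.matmul(x, dout, transpose_x=True)  # dW = x^T @ dout
        return dx, dweight

FP8KeepXLinear itself is then just a thin Layer wrapper that owns a bf16 weight parameter and calls the PyLayer in forward, mirroring FP8Linear.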

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 8 additions & 6 deletions
@@ -79,7 +79,7 @@
 from ..utils import device_guard
 from . import fp8_linear as linear_utils
 from .configuration import DeepseekV2Config
-from .fp8_linear import FP8DeepseekV2MLP, FP8Linear, Linear
+from .fp8_linear import FP8DeepseekV2MLP, FP8Linear, Linear, FP8KeepXLinear
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
 Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear
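
As the diff shows, DSV3_USE_FP8_GEMM is read from the environment when modeling.py is imported, so the flag must be set beforehand. A minimal sketch, assuming the paddlenlp package and the FP8 kernels are installed:

import os

# Set the gate before paddlenlp.transformers.deepseek_v2.modeling is imported;
# the module evaluates DSV3_USE_FP8_GEMM at import time.
os.environ["DSV3_USE_FP8_GEMM"] = "True"

from paddlenlp.transformers.deepseek_v2 import modeling  # Linear -> FP8Linear, o_proj -> FP8KeepXLinear
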
@@ -244,10 +244,9 @@ def scaled_dot_product_attention(
         )
 
         if sequence_parallel:
-            attn_output = outputs.reshape([bsz * q_len, v_head_dim * num_heads])
-        else:
-            attn_output = outputs.reshape([bsz, q_len, v_head_dim * num_heads])
-        return attn_output
+            outputs = outputs.reshape([bsz * q_len, v_head_dim * num_heads])
+
+        return outputs
 
     else:
         # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
@@ -1005,7 +1004,10 @@ def linear_dtype_gaurd():
         with linear_dtype_gaurd():
             self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
             self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
-            self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
+            if DSV3_USE_FP8_GEMM:
+                self.o_proj = FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
+            else:
+                self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
         self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank)
 
         # fmt: on
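
With the gate enabled, the attention output projection o_proj becomes an FP8KeepXLinear. A hedged usage sketch of the layer on its own (shapes are placeholders; it assumes the kitchen and deep_gemm FP8 kernels and compatible hardware are available):

import paddle
from paddlenlp.transformers.deepseek_v2.fp8_linear import FP8KeepXLinear

# o_proj maps num_heads * v_head_dim -> hidden_size; 2048 and 4096 are placeholder sizes.
o_proj = FP8KeepXLinear(in_features=2048, out_features=4096)

x = paddle.randn([2, 128, 2048]).astype("bfloat16")  # [batch, seq_len, num_heads * v_head_dim]
y = o_proj(x)  # quantizes x and the bf16 weight on the fly, runs the FP8 GEMM, returns [2, 128, 4096]

In backward, the PyLayer re-quantizes the saved bf16 x and the incoming gradient rather than reusing FP8 tensors cached in forward, which appears to be the memory saving the commit title refers to.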
