Optimize attention output linear fp8 memory #10204

Merged
86 changes: 86 additions & 0 deletions paddlenlp/transformers/deepseek_v2/fp8_linear.py
@@ -267,6 +267,77 @@ def backward(ctx, dout):
        dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
        return dx, dweight

class LinearFP8KeepXFunc(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight):
        x_orig_shape = x.shape
        # deep_gemm only supports 2D inputs
        x = x.reshape([-1, x_orig_shape[-1]])
        # quant
        x_quant, x_scale = kitchen_quant(
            x, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
        )
        weight_t = weight.T.contiguous()
Review comment: Won't these transposes cause performance problems?

        w_quant, w_scale = kitchen_quant(
            weight_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
        )

        # compute out = mm(x, w_t)
        out = paddle.empty([x.shape[0], weight.shape[-1]], dtype=x.dtype)
        deep_gemm.gemm_fp8_fp8_bf16_nt((x_quant, x_scale), (w_quant, w_scale), out)
        out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

        # keep the original bf16 x and weight for backward; quantization is redone there
        ctx.save_for_backward(x, weight)
        return out

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensor()

        # padding
        x_t = x.T.contiguous()
        if x_t.shape[-1] % 8 != 0:
            x_t = paddle.concat([x_t, paddle.zeros([x_t.shape[0], 8 - (x_t.shape[-1] % 8)], dtype=x_t.dtype)], axis=-1)
        x_t_quant, x_t_scale = kitchen_quant(
            x_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
        )


        x_t_shape = x_t.shape
Review comment: x_t_shape is not used afterwards.

        # compute dx = mm(dout, w)
        dx = paddle.empty(x.shape, dout.dtype)
        dx_orig_shape = x.shape

        dout_quant, dout_scale = kitchen_quant(
            dout.reshape([-1, dout.shape[-1]]),
            backend=kitchen.ops.Backend.CUTLASS,
            is_1d_scaled=True,
            return_transpose=False,
        )
        w_quant, w_scale = kitchen_quant(
            weight, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
        )
        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_quant, dout_scale), (w_quant, w_scale), dx)
        dx = dx.reshape(dx_orig_shape)

        # compute dw = mm(x_t, dout_t)
        dout_t = dout.reshape([-1, dout.shape[-1]]).T.contiguous()
        # padding
        if dout_t.shape[-1] % 8 != 0:
            pad_size = 8 - (dout_t.shape[-1] % 8)
            dout_t = paddle.concat([dout_t, paddle.zeros([dout_t.shape[0], pad_size], dtype=dout_t.dtype)], axis=-1)

        dout_t_quant, dout_t_scale = kitchen_quant(
            dout_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
        )
        dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
        return dx, dweight




class FP8Linear(paddle.nn.Layer):
    def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
@@ -282,6 +353,21 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
    def forward(self, x):
        return LinearFP8Func.apply(x, self.weight)
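kitchen_quant, kitchen.ops.Backend, and kitchen_fp8_gemm come from an internal kernel package, so the following is only a rough mental model of what a 1D-scaled quantization step computes (an assumption, not the actual implementation): one scale per row, derived from the row's abs-max, with values clamped to the FP8 E4M3 range.

import paddle

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def quant_1d_scaled_sketch(x, eps=1e-4):
    # Per-row scale; a real kernel also casts to an FP8 storage dtype and may
    # tile the scaling differently (e.g. per 128-element block).
    amax = paddle.abs(x).max(axis=-1, keepdim=True).astype("float32")
    scale = paddle.clip(amax, min=eps) / FP8_E4M3_MAX
    x_q = paddle.clip(x.astype("float32") / scale, min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return x_q, scale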

class FP8KeepXLinear(paddle.nn.Layer):
    def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
        super().__init__()
        self._dtype = self._helper.get_default_dtype()

        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            dtype="bfloat16",
            is_bias=False,
        )

    def forward(self, x):
        return LinearFP8KeepXFunc.apply(x, self.weight)
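A minimal usage sketch for the new layer, assuming the kitchen and deep_gemm kernels are available and the (made-up) sizes meet their alignment requirements:

import paddle

layer = FP8KeepXLinear(in_features=2048, out_features=5120)
x = paddle.randn([2, 512, 2048]).astype("bfloat16")
x.stop_gradient = False

out = layer(x)        # forward quantizes x and the transposed weight on the fly
out.sum().backward()  # backward re-quantizes from the saved bf16 x and weight
print(out.shape, layer.weight.grad.shape)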



class Fuse_FFN_FP8_Func(paddle.autograd.PyLayer):
    @staticmethod
14 changes: 8 additions & 6 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -79,7 +79,7 @@
from ..utils import device_guard
from . import fp8_linear as linear_utils
from .configuration import DeepseekV2Config
-from .fp8_linear import FP8DeepseekV2MLP, FP8Linear, Linear
+from .fp8_linear import FP8DeepseekV2MLP, FP8Linear, Linear, FP8KeepXLinear

DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear
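Because DSV3_USE_FP8_GEMM is read with os.getenv when this module is imported, the flag has to be set in the environment before the import; a minimal sketch:

import os

# Opt in to the FP8 GEMM path; setting this after the import has no effect,
# since the flag is evaluated at module import time.
os.environ["DSV3_USE_FP8_GEMM"] = "true"

from paddlenlp.transformers.deepseek_v2 import modeling  # noqa: E402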
@@ -244,10 +244,9 @@ def scaled_dot_product_attention(
        )

        if sequence_parallel:
-            attn_output = outputs.reshape([bsz * q_len, v_head_dim * num_heads])
-        else:
-            attn_output = outputs.reshape([bsz, q_len, v_head_dim * num_heads])
-        return attn_output
+            outputs = outputs.reshape([bsz * q_len, v_head_dim * num_heads])
+
+        return outputs

    else:
        # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
@@ -1005,7 +1004,10 @@ def linear_dtype_gaurd():
        with linear_dtype_gaurd():
            self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
            self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
-            self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
+            if DSV3_USE_FP8_GEMM:
+                self.o_proj = FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
+            else:
+                self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
        self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank)

        # fmt: on