Skip to content

refine fp8 Linear #10191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions paddlenlp/transformers/deepseek_v2/fp8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ def forward(ctx, x, weight):
x_quant, x_scale = kitchen_quant(
x, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
)

w_quant, w_sacle, w_t_quant, w_t_scale = kitchen_quant(
weight, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=True
weight_t = weight.T.contiguous()
w_quant, w_scale = kitchen_quant(
weight_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
)

# compute out = mm(x, w_t)
out = paddle.empty([x.shape[0], weight.shape[-1]], dtype=x.dtype)
deep_gemm.gemm_fp8_fp8_bf16_nt((x_quant, x_scale), (w_t_quant, w_t_scale), out)
deep_gemm.gemm_fp8_fp8_bf16_nt((x_quant, x_scale), (w_quant, w_scale), out)
out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

# save for bwd
Expand All @@ -230,13 +230,13 @@ def forward(ctx, x, weight):
x_t.contiguous(), backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
)
ctx.save_for_backward(
x_t_quant, x_t_scale, w_quant, w_sacle, paddle.to_tensor(x_t_shape, dtype="int64", place=paddle.CPUPlace())
x_t_quant, x_t_scale, weight, paddle.to_tensor(x_t_shape, dtype="int64", place=paddle.CPUPlace())
)
return out

@staticmethod
def backward(ctx, dout):
x_t_quant, x_t_scale, w_quant, w_sacle, x_t_shape = ctx.saved_tensor()
x_t_quant, x_t_scale, weight, x_t_shape = ctx.saved_tensor()
x_t_shape = x_t_shape.numpy()
# compute dx = mm(dout, w)
dx = paddle.empty([x_t_shape[1], x_t_shape[0]], dout.dtype)
Expand All @@ -248,7 +248,10 @@ def backward(ctx, dout):
is_1d_scaled=True,
return_transpose=False,
)
deep_gemm.gemm_fp8_fp8_bf16_nt((dout_quant, dout_scale), (w_quant, w_sacle), dx)
w_quant, w_scale = kitchen_quant(
weight, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
)
deep_gemm.gemm_fp8_fp8_bf16_nt((dout_quant, dout_scale), (w_quant, w_scale), dx)
dx = dx.reshape(dx_orig_shape)

# compute dw = mm(x_t, dout_t)
Expand Down
Loading