Optimize attention output linear fp8 memory #10204

Conversation

phlrain (Collaborator) commented Mar 19, 2025

The input x of the output linear layer is the output of flash attention (FA). That tensor is already needed for FA's backward pass, so when the output linear runs in FP8, save_for_backward can keep x directly; saving x_fp8 instead would add extra GPU memory. This saves roughly 100 MB per decoder layer.
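
To make the pattern concrete, here is a minimal, self-contained sketch of the idea. It is not the PR's actual code: the real implementation uses kitchen_quant and an FP8 GEMM, while toy_quant and the plain float matmul below are stand-ins used only for illustration. The forward quantizes x for the GEMM but saves the original x (already retained for flash attention's backward), and the backward re-quantizes x on the fly instead of keeping a cached x_fp8.

import paddle
from paddle.autograd import PyLayer


def toy_quant(x):
    # Toy per-tensor quantization standing in for the real FP8 kernel
    # (kitchen_quant in this PR); int8 is used here only for illustration.
    scale = paddle.max(paddle.abs(x)) / 127.0
    q = paddle.clip(paddle.round(x / scale), -127.0, 127.0).astype('int8')
    return q, scale


class FP8OutputLinearSketch(PyLayer):
    @staticmethod
    def forward(ctx, x, weight):
        x_q, x_scale = toy_quant(x)
        # Toy "FP8" GEMM: dequantize and run a float matmul for illustration.
        out = paddle.matmul(x_q.astype('float32') * x_scale, weight)
        # Key point of this PR: x is the flash-attention output and is already
        # kept alive for flash-attention's backward, so saving x costs no extra
        # memory. Saving x_q instead would keep one more tensor alive
        # (roughly 100 MB per decoder layer, per the description above).
        ctx.save_for_backward(x, weight)
        return out

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensor()
        # Re-quantize x on the fly instead of caching the quantized copy.
        x_q, x_scale = toy_quant(x)
        dx = paddle.matmul(dout, weight, transpose_y=True)
        dweight = paddle.matmul(x_q.astype('float32') * x_scale, dout, transpose_x=True)
        return dx, dweight


# Usage sketch
x = paddle.randn([4, 8])
w = paddle.randn([8, 16])
x.stop_gradient = False
w.stop_gradient = False
FP8OutputLinearSketch.apply(x, w).sum().backward()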

paddle-bot (bot) commented Mar 19, 2025

Thanks for your contribution!

sneaxiy merged commit 0a77769 into PaddlePaddle:dsv3_dev on Mar 19, 2025
1 of 5 checks passed
Review thread on the diff:

x_quant, x_scale = kitchen_quant(
    x, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
)
weight_t = weight.T.contiguous()

Won't these transposes cause a performance problem?
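
For reference, weight.T.contiguous() materializes a fresh transposed copy of the weight on every call, which is presumably what the question is about. One common way to amortize that cost, sketched below with a hypothetical helper that is not part of this PR, is to cache the transposed weight between optimizer steps; note that the cache itself keeps an extra copy of the weight alive, so it trades memory back for speed and works against the memory goal of this PR.

import paddle


class CachedTransposedWeight:
    # Hypothetical helper (not from this PR): reuse one transposed copy of a
    # weight instead of recomputing weight.T.contiguous() on every forward.
    def __init__(self, weight):
        self.weight = weight
        self._weight_t = None

    def get(self):
        if self._weight_t is None:
            self._weight_t = self.weight.T.contiguous()
        return self._weight_t

    def invalidate(self):
        # Call after each optimizer step, once the weight values have changed.
        self._weight_t = None


# Usage sketch
w = paddle.randn([8, 16])
cache = CachedTransposedWeight(w)
w_t = cache.get()    # transposed and copied once
w_t = cache.get()    # reused, no new allocation
cache.invalidate()   # drop the cache after the weight is updated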

Another review thread on the diff:

x_t_shape = x_t_shape.numpy()

x_t_shape is not used anywhere after this line.
