[XPU] llama add xpu support #8282
@@ -413,6 +413,10 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
+                import paddle_xpu_nn  # noqa: F821
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

         if paddle.in_dynamic_mode():
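Note: for comparison with the fused rms_norm_npu / xpu_rms_norm calls added above, this is a minimal unfused sketch of what the normalization computes, assuming standard RMSNorm semantics rather than the exact Paddle kernel:

import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # mean of squares over the hidden dimension, computed in float32 for stability
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    normed = paddle.rsqrt(variance + variance_epsilon) * hidden_states.astype("float32")
    # scale by the learned weight and cast back to the input dtype
    return (normed * weight.astype("float32")).astype(hidden_states.dtype)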
@@ -582,12 +586,38 @@ def __init__(self, config):

                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    ColumnParallelLinear as XPUColumnParallelLinear,
+                )
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    RowParallelLinear as XPURowParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnParallelLinear
+                RowParallelLinear = XPURowParallelLinear
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401
+
+            Linear = XPULinear
+        else:
+            Linear = nn.Linear

         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
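Note: the hunk above uses a local-aliasing pattern: the device-specific classes are bound once to the local names ColumnParallelLinear, RowParallelLinear, and Linear, and the layer construction below stays device-agnostic. A standalone sketch of that pattern follows; pick_linear_cls is a hypothetical helper and only nn.Linear is guaranteed to exist outside an XPU environment:

import paddle.nn as nn

def pick_linear_cls(device: str):
    # paddle_xpu is assumed to be installed only on XPU machines, hence the guarded import
    if device == "xpu":
        from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401
        return XPULinear
    return nn.Linear

Linear = pick_linear_cls("gpu")  # resolves to paddle.nn.Linear here
gate_proj = Linear(1024, 2752, bias_attr=False)  # same call site for either backend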
@@ -619,15 +649,24 @@ def __init__(self, config):
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

     def forward(self, x):
         if self.fuse_attention_ffn:
             # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
+                import paddle_xpu_nn  # noqa: F821
+
+                out = self.gate_up_fused_proj(x)
+                out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                out = self.down_proj(out)
+                return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
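Note: for context, swiglu(gate, up) multiplies silu(gate) by up; in the fused path the single gate_up projection produces both halves, which are split along the last axis. A small reference sketch of standard SwiGLU follows; it is not the fused xpu_swiglu kernel, and the semantics of its turn argument are not documented here, so the [gate | up] split order is an assumption:

import paddle
import paddle.nn.functional as F

def swiglu_reference(gate, up=None):
    # fused layout assumed: one tensor holding [gate | up] along the last axis
    if up is None:
        gate, up = paddle.chunk(gate, chunks=2, axis=-1)
    return F.silu(gate) * up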
@@ -688,7 +727,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() != "npu":
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -705,12 +744,38 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    ColumnParallelLinear as XPUColumnParallelLinear,
+                )
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    RowParallelLinear as XPURowParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnParallelLinear
+                RowParallelLinear = XPURowParallelLinear
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401
+
+            Linear = XPULinear
+        else:
+            Linear = nn.Linear

         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -741,36 +806,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                     gather_output=False,
                 )
             else:
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )

         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -784,7 +849,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1428,6 +1493,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
Review thread on this hunk:
- "When the passed-in …"
- "This looks almost the same as the npu logic above; can the two branches be merged?"
- "In theory they could share one branch, but the npu path hardcodes the dtype to float16 …"
- "@SylarTiaNII could you take a look?"
- "Per @wuhuachaocoding's suggestion, keeping this as two separate if/elif branches."
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
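Note: the review exchange asks whether the npu and xpu branches could share code; the PR keeps them separate because the npu branch pins the working dtype to float16 while the xpu branch honors the incoming dtype. A hypothetical merged helper, sketching the reviewer's suggestion rather than what was merged, and mirroring the paddle.where usage from the branches above:

import paddle

def _mask_fill(expanded_attn_mask, dtype, device):
    # hypothetical merge of the npu/xpu branches: npu pins the working dtype to
    # float16, xpu keeps the caller's dtype (the reason the PR kept two branches)
    work_dtype = "float16" if device == "npu" else dtype
    x = paddle.to_tensor(0.0, dtype=work_dtype)
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=work_dtype)
    mask = expanded_attn_mask.astype(work_dtype)
    return paddle.where(mask, x, y).astype(dtype)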
@@ -1708,6 +1778,12 @@ def __init__(self, config: LlamaConfig):
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
             self.weight.split_axis = 1
+        if get_env_device() == "xpu":
+            from paddle_xpu.layers.nn import (  # noqa: F401
+                parallel_matmul as xpu_parallel_matmul,
+            )
+
+            self.xpu_parallel_matmul = xpu_parallel_matmul()

     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1721,7 +1797,12 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul(
+                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
Review thread on lines +1742 to +1743:
- "Is the training argument really required? If the arguments could stay the same, wouldn't it be enough to just swap in an xpu implementation of parallel_matmul?"
- "There are two reasons for this: …"
+            )
+        else:
+            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
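Note: the review question is whether the device check could live inside parallel_matmul itself instead of at the call site; the diff instead instantiates a callable xpu_parallel_matmul object in __init__ and passes an extra training flag. A sketch of the call-site-transparent alternative the reviewer floats; parallel_matmul_dispatch and _XPU_MATMUL are hypothetical names, and get_env_device / parallel_matmul are assumed to come from the surrounding modeling module:

def parallel_matmul_dispatch(hidden_states, weight, tensor_parallel_output=True, training=True):
    # hypothetical wrapper: keep the original signature for callers and hide the
    # device split here; _XPU_MATMUL would be a module-level xpu_parallel_matmul()
    # instance created once at import time on XPU machines
    if get_env_device() == "xpu":
        return _XPU_MATMUL(
            hidden_states, weight, tensor_parallel_output=tensor_parallel_output, training=training
        )
    return parallel_matmul(hidden_states, weight, tensor_parallel_output=tensor_parallel_output)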