
Commit 72ef18a

[XPU] llama add xpu support
1 parent 0790824 commit 72ef18a

2 files changed: +97 -17 lines changed

llm/run_pretrain.py

Lines changed: 7 additions & 0 deletions
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device
 
 
 def add_start_docstrings(*docstr):
@@ -483,6 +484,12 @@ def main():
             config.num_attention_heads % config.sep_parallel_degree == 0
         ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
 
+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+        LinearConfig.enable_accumulate_steps_opt()
+        LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+
     print("Final pre-training config:", config)
 
     # Set the dtype for loading model
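
Note: the guard added to main() only takes effect when the detected device is XPU and gradient accumulation is actually used. The sketch below reproduces that control flow in isolation so it can be read without a paddle_xpu install; DummyArgs and the stubbed get_env_device are illustrative stand-ins, not part of the patch.

    # Standalone sketch of the gating logic above; everything named Dummy*/stub
    # is hypothetical and exists only so the branch can be exercised.
    from dataclasses import dataclass


    @dataclass
    class DummyArgs:  # stand-in for training_args
        gradient_accumulation_steps: int = 4


    def get_env_device() -> str:
        # Stub for paddlenlp.utils.tools.get_env_device; hard-coded so the
        # XPU branch below runs.
        return "xpu"


    args = DummyArgs()
    if get_env_device() == "xpu" and args.gradient_accumulation_steps > 1:
        # On a real XPU build this is where LinearConfig.enable_accumulate_steps_opt()
        # and LinearConfig.set_accumulate_steps(...) from paddle_xpu would be called.
        print("enable accumulate-steps optimization for", args.gradient_accumulation_steps, "steps")
    else:
        print("no XPU-specific linear optimization applied")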

paddlenlp/transformers/llama/modeling.py

Lines changed: 90 additions & 17 deletions
@@ -413,6 +413,10 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
+                import paddle_xpu_nn
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
         if paddle.in_dynamic_mode():
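
Note: for readers comparing backends, the sketch below is an unfused RMSNorm reference (plain Paddle, simplified dtype handling) describing the computation the npu/xpu/fused branches above are each expected to produce; it is not code from this patch.

    # Unfused RMSNorm reference for comparison with the fused branches above.
    import paddle


    def rms_norm_reference(hidden_states, weight, epsilon=1e-6):
        # Scale by the reciprocal root-mean-square of the last axis, then
        # apply the learned weight; compute in float32 for stability.
        h = hidden_states.astype("float32")
        variance = h.pow(2).mean(axis=-1, keepdim=True)
        normed = h * paddle.rsqrt(variance + epsilon)
        return (normed * weight.astype("float32")).astype(hidden_states.dtype)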
@@ -582,12 +586,33 @@ def __init__(self, config):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear
+        else:
+            Linear = nn.Linear
 
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
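
Note: the pattern in this hunk (and the matching one in LlamaAttention below) is to resolve the linear classes once in __init__ and then construct every projection through the Linear alias, which is why the later hunks rewrite nn.Linear(...) call sites to Linear(...). A condensed sketch of that design, with the optional paddle_xpu class represented only by a comment; TinyMLP is illustrative and not part of the patch.

    # Sketch of the alias-selection pattern used above.
    import paddle.nn as nn


    class TinyMLP(nn.Layer):
        def __init__(self, hidden_size, device="gpu"):
            super().__init__()
            if device == "xpu":
                # On an XPU build this would be: Linear = paddle_xpu.layers.nn.Linear
                Linear = nn.Linear
            else:
                Linear = nn.Linear
            # All call sites go through the alias, so swapping the backend
            # class touches only this selection block.
            self.up = Linear(hidden_size, 4 * hidden_size, bias_attr=False)
            self.down = Linear(4 * hidden_size, hidden_size, bias_attr=False)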
@@ -619,15 +644,24 @@ def __init__(self, config):
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
 
-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
 
     def forward(self, x):
         if self.fuse_attention_ffn:
+            # FIXME(yangjianbang): use paddle's native swiglu
+            if get_env_device() == "xpu":
+                import paddle_xpu_nn  # noqa: F821
+
+                out = self.gate_up_fused_proj(x)
+                out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
+                out = self.down_proj(out)
+                return out
+
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
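
Note: the non-XPU path still goes through swiglu on the fused gate_up projection. As a reference, the sketch below shows an unfused SwiGLU (split the doubled projection in half, SiLU one half, multiply); whether xpu_swiglu(..., turn=True) uses this exact half ordering is a property of the XPU kernel and is not asserted here.

    # Unfused SwiGLU reference; not code from this patch.
    import paddle
    import paddle.nn.functional as F


    def swiglu_reference(gate_up):
        # gate_up: [..., 2 * intermediate_size] from gate_up_fused_proj
        gate, up = paddle.chunk(gate_up, chunks=2, axis=-1)
        return F.silu(gate) * up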
@@ -689,7 +723,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
         self.use_fused_rope = config.use_fused_rope
         if self.use_fused_rope and get_env_device() != "npu":
-            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
+            if (
+                "gpu" not in paddle.device.get_device()
+                or "xpu" not in paddle.device.get_device()
+                or fused_rotary_position_embedding is None
+            ):
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
                     "Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -705,12 +743,33 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
         else:
-            ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
-            RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+            if get_env_device() == "xpu":
+                import paddle_xpu  # noqa: F821
+
+                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
+                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+            else:
+                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
+                RowParallelLinear = fleet.meta_parallel.RowParallelLinear
+
+        if get_env_device() == "xpu":
+            import paddle_xpu  # noqa: F821
+
+            Linear = paddle_xpu.layers.nn.Linear
+        else:
+            Linear = nn.Linear
 
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -741,36 +800,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                         gather_output=False,
                     )
                 else:
-                    self.k_proj = nn.Linear(
+                    self.k_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
-                    self.v_proj = nn.Linear(
+                    self.v_proj = Linear(
                         self.hidden_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
 
         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -784,7 +843,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1428,6 +1487,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
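
Note: the new XPU branch mirrors the NPU one but stays in the model dtype instead of float16: a boolean mask becomes an additive mask with 0.0 at visible positions and finfo(dtype).min at masked positions. A small runnable sketch of that conversion; the bfloat16 dtype is just an example.

    # Boolean mask -> additive mask, as in the XPU branch above.
    import paddle

    dtype = paddle.bfloat16  # illustrative model dtype
    bool_mask = paddle.to_tensor([[True, True, False]])
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
    additive_mask = paddle.where(bool_mask, x, y).astype(dtype)
    # additive_mask: [[0.0, 0.0, ~-3.4e38]] -- masked positions get a large
    # negative bias before softmax.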
@@ -1708,6 +1772,10 @@ def __init__(self, config: LlamaConfig):
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
         if self.weight.is_distributed:
             self.weight.split_axis = 1
+        if get_env_device() == "xpu":
+            import paddle_xpu
+
+            self.xpu_parallel_matmul = paddle_xpu.layers.nn.parallel_matmul()
 
     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1721,7 +1789,12 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
 
-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul(
+                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
+            )
+        else:
+            logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
 
 
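Note: on non-XPU devices the LM head is unchanged; on XPU it routes through the paddle_xpu parallel_matmul object created in __init__, with the same inputs plus training=self.training. For orientation, the single-device case reduces to a plain matmul over a [hidden, vocab] weight (layout assumed here), as in this sketch.

    # Single-device sketch of the LM head projection; no tensor parallelism.
    import paddle

    hidden_states = paddle.randn([2, 8, 16])  # [batch, seq, hidden]
    weight = paddle.randn([16, 32])           # [hidden, vocab] (assumed layout)
    logits = paddle.matmul(hidden_states, weight)
    # logits.shape == [2, 8, 32]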