[XPU] llama add xpu support #8282


Merged: 9 commits, Apr 29, 2024
Changes from 2 commits
7 changes: 7 additions & 0 deletions llm/run_pretrain.py
@@ -46,6 +46,7 @@
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device


def add_start_docstrings(*docstr):
@@ -483,6 +484,12 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)

print("Final pre-training config:", config)

# Set the dtype for loading model
115 changes: 98 additions & 17 deletions paddlenlp/transformers/llama/modeling.py
@@ -413,6 +413,10 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
import paddle_xpu_nn # noqa: F821

return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if paddle.in_dynamic_mode():
@@ -582,12 +586,38 @@ def __init__(self, config):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if get_env_device() == "xpu":
from paddle_xpu.layers.nn import ( # noqa: F401
ColumnParallelLinear as XPUColumnParallelLinear,
)
from paddle_xpu.layers.nn import ( # noqa: F401
RowParallelLinear as XPURowParallelLinear,
)

ColumnParallelLinear = XPUColumnParallelLinear
RowParallelLinear = XPURowParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401

Linear = XPULinear
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
@@ -619,15 +649,24 @@ def __init__(self, config):
)
else:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
if self.fuse_attention_ffn:
# FIXME(yangjianbang): use paddle's native swiglu
if get_env_device() == "xpu":
import paddle_xpu_nn # noqa: F821

out = self.gate_up_fused_proj(x)
out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
out = self.down_proj(out)
return out

x = swiglu(self.gate_up_fused_proj(x))
else:
x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -688,7 +727,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() != "npu":
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -705,12 +744,38 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if get_env_device() == "xpu":
from paddle_xpu.layers.nn import ( # noqa: F401
ColumnParallelLinear as XPUColumnParallelLinear,
)
from paddle_xpu.layers.nn import ( # noqa: F401
RowParallelLinear as XPURowParallelLinear,
)

ColumnParallelLinear = XPUColumnParallelLinear
RowParallelLinear = XPURowParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401

Linear = XPULinear
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
@@ -741,36 +806,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
@@ -784,7 +849,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
@@ -1428,6 +1493,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
Contributor Author:

When the x and y passed in are integer Python scalars, paddle.where treats them as int64 tensors of shape [1] and performs a broadcast_add; see search.py for details.

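Note (illustrative only, not part of the diff): a minimal, self-contained sketch of the behavior described in the comment above, assuming a Paddle build with bfloat16 kernels for paddle.where. It builds x and y as tensors in the working dtype up front instead of passing raw scalars.

import paddle

# Integer Python scalars passed to paddle.where would be promoted to int64
# tensors of shape [1] and broadcast (per the comment above), so we build
# x and y explicitly in the working dtype.
dtype = "bfloat16"  # xpu jobs may run in either float16 or bfloat16
mask = paddle.to_tensor([[True, False], [False, True]])

x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(paddle.bfloat16).min, dtype=dtype)
out = paddle.where(mask, x, y)
print(out.dtype)  # stays in the working dtype, e.g. paddle.bfloat16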
Collaborator:

This looks almost the same as the npu logic above; can the two branches be shared?

Contributor Author:

In theory they could be shared, but the npu branch hard-codes the dtype to float16, while programs on xpu may run in either float16 or bfloat16. Do we need to modify the npu module?

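Note (illustrative only): a hypothetical dtype-parameterized helper, with invented names and not part of this PR, showing what a branch shared between npu and xpu might look like if the working dtype were passed in rather than hard-coded.

import paddle

def bool_mask_to_additive(expanded_attn_mask, dtype):
    # Hypothetical helper: same idea as the npu/xpu branches above, but
    # parameterized on dtype so float16 and bfloat16 runs share one path.
    # (The PR's xpu branch also casts the mask itself to `dtype` before the
    # select; this sketch keeps a plain bool condition instead.)
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
    return paddle.where(expanded_attn_mask.astype("bool"), x, y).astype(dtype)

# usage sketch: additive_mask = bool_mask_to_additive(bool_mask, paddle.bfloat16)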
Collaborator:

@SylarTiaNII could you take a look?

Contributor Author:

Following @wuhuachaocoding's suggestion, these are kept as two separate if/elif branches.

else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask
@@ -1708,6 +1778,12 @@ def __init__(self, config: LlamaConfig):
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
if get_env_device() == "xpu":
from paddle_xpu.layers.nn import ( # noqa: F401
parallel_matmul as xpu_parallel_matmul,
)

self.xpu_parallel_matmul = xpu_parallel_matmul()

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1721,7 +1797,12 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
if get_env_device() == "xpu":
logits = self.xpu_parallel_matmul(
hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
Comment on lines +1742 to +1743
Collaborator:

Is the training argument really necessary? If the arguments could be kept identical, wouldn't it be enough to just swap in the xpu implementation of parallel_matmul?

Contributor Author:

There are two reasons (a hypothetical illustration follows the diff below):

  • One of the XPU optimizations needs parallel_matmul to be an object so that it can store some state.
  • XPU needs the training flag in order to apply its optimizations.

)
else:
logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits

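Note (illustrative only): a purely hypothetical stand-in, with invented names, for the kind of stateful, training-aware object described in the discussion above; the actual paddle_xpu parallel_matmul implementation is not shown in this PR.

import paddle

class StatefulParallelMatmul:
    # Hypothetical sketch: being an object (rather than a free function)
    # lets it cache state across calls, and the extra `training` flag lets
    # it choose a training-only optimization path.

    def __init__(self):
        self._prepared_weight_id = None  # example of state the object might keep

    def __call__(self, x, weight, tensor_parallel_output=True, training=True):
        if self._prepared_weight_id != id(weight):
            # e.g. prepare a device-friendly layout for `weight` once, then reuse it
            self._prepared_weight_id = id(weight)
        if training:
            # a training-specific fast path could be selected here
            pass
        return paddle.matmul(x, weight)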
