
Commit 92fe84b

support rms_norm_mlu
1 parent c1cfe63 · commit 92fe84b

File tree: 3 files changed, +12 -8 lines


paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,7 @@ def swiglu(x, y=None):
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:
         from paddle.base import core

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -110,6 +110,8 @@ def rms_norm_fused(x_in, w, eps):
 def fusion_rms_norm(hidden_states, weight, variance_epsilon):
     if get_env_device() == "npu":
         return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
+    if get_env_device() == "mlu":
+        return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "gcu":
         return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "xpu":
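
For context, the new MLU branch follows the same dispatch pattern as the existing NPU and GCU paths: fusion_rms_norm hands the tensors to a custom operator registered by the device plugin. Below is a minimal sketch of the unfused RMSNorm computation that such a fused op replaces, written against plain paddle (the name rms_norm_reference is an illustration for this page, not code from the commit):

import paddle


def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # Unfused RMSNorm: normalize by the root mean square over the last axis,
    # then scale by the learned weight. A fused op such as rms_norm_mlu is
    # expected to perform the same math in a single kernel.
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    normed = paddle.rsqrt(variance + variance_epsilon) * hidden_states.astype("float32")
    return (normed * weight.astype("float32")).astype(hidden_states.dtype)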

paddlenlp/transformers/llama/modeling.py

Lines changed: 7 additions & 7 deletions
@@ -79,7 +79,7 @@ def swiglu(x, y=None):
 )

 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
@@ -311,7 +311,7 @@ def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len

-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")
     else:
         mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
@@ -331,7 +331,7 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     batch_size, src_length = mask.shape[0], mask.shape[-1]
     tgt_length = tgt_length if tgt_length is not None else src_length

-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = mask[:, None, None, :].astype(dtype)
     else:
         mask = mask[:, None, None, :].astype("bool")
@@ -657,7 +657,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -1399,7 +1399,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                 combined_attention_mask = _make_causal_mask(
                     input_shape, past_key_values_length=past_key_values_length
                 )
-                if get_env_device() == "npu":
+                if get_env_device() == "npu" or get_env_device() == "mlu":
                     expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                 else:
                     expanded_attn_mask = expanded_attn_mask & combined_attention_mask
@@ -1412,7 +1412,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        if get_env_device() == "npu":
+        if get_env_device() == "npu" or get_env_device() == "mlu":
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
@@ -1549,7 +1549,7 @@ def forward(
         is_casual = False
         if self.config.use_flash_attention and get_env_device() != "gcu":
             is_casual = is_casual_mask(attention_mask)
-            if get_env_device() != "npu":
+            if get_env_device() != "npu" or get_env_device() != "mlu":
                 if is_casual and alibi is None:
                     attention_mask = None
             else:
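
For context, the npu/mlu branches in _prepare_decoder_attention_mask avoid boolean masks and instead build an additive float mask that is later added to the attention scores. The sketch below illustrates that conversion pattern in isolation (the helper name bool_mask_to_additive and the float16 default are assumptions for the example, not code from this commit):

import paddle


def bool_mask_to_additive(mask_bool, dtype=paddle.float16):
    # Allowed positions become 0.0 and masked positions become the most
    # negative representable value, mirroring the paddle.where-based
    # conversion used on npu/mlu above.
    zero = paddle.to_tensor(0.0, dtype="float32")
    neg_inf = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
    return paddle.where(mask_bool.astype("float32") > 0.5, zero, neg_inf).astype(dtype)


# Example: additive causal mask for a length-4 sequence.
causal = paddle.tril(paddle.ones((4, 4), dtype="bool"))
additive = bool_mask_to_additive(causal)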

paddlenlp/utils/tools.py

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ def get_env_device():
         return "gpu"
     elif "npu" in paddle.device.get_all_custom_device_type():
         return "npu"
+    elif "mlu" in paddle.device.get_all_custom_device_type():
+        return "mlu"
     elif "gcu" in paddle.device.get_all_custom_device_type():
         return "gcu"
     elif paddle.is_compiled_with_rocm():
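
For context, get_env_device is the switch that all of the branches above key off; this change simply makes it report "mlu" when an MLU custom device is registered with Paddle. A minimal usage sketch of the dispatch pattern this enables (the _RMS_NORM_OPS mapping is illustrative only; the commit itself only touches the rms_norm_* call sites shown above):

from paddlenlp.utils.tools import get_env_device

# Map the detected backend to the name of its fused RMSNorm custom op;
# backends without a fused kernel fall back to None.
_RMS_NORM_OPS = {
    "npu": "rms_norm_npu",
    "mlu": "rms_norm_mlu",
    "gcu": "rms_norm_gcu",
}

op_name = _RMS_NORM_OPS.get(get_env_device())
print(op_name)  # e.g. "rms_norm_mlu" when an MLU device is detected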
