
Commit d385e9f

support rms_norm_mlu
1 parent 6d464bf commit d385e9f

3 files changed: +12 -8 lines changed


paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,7 @@ def swiglu(x, y=None):
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:
         from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -124,6 +124,8 @@ def rms_norm_fused(x_in, w, eps):
 def fusion_rms_norm(hidden_states, weight, variance_epsilon):
     if get_env_device() == "npu":
         return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
+    if get_env_device() == "mlu":
+        return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "gcu":
         return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "xpu":
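With this change, fusion_rms_norm dispatches to a rms_norm_mlu custom op whenever get_env_device() reports "mlu". For reference, a minimal pure-Paddle sketch of what the fused rms_norm_* ops compute; the function name rms_norm_reference is illustrative and not part of PaddleNLP's API:

import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # Standard Llama RMSNorm: x / sqrt(mean(x^2) + eps) * weight,
    # computed in float32 for numerical stability.
    x = hidden_states.astype("float32")
    variance = x.pow(2).mean(axis=-1, keepdim=True)
    normalized = x * paddle.rsqrt(variance + variance_epsilon)
    return normalized.astype(hidden_states.dtype) * weight

On an MLU device the fused path should produce a result numerically close to this baseline, but in a single custom kernel invoked through core.eager._run_custom_op("rms_norm_mlu", ...).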

paddlenlp/transformers/llama/modeling.py

Lines changed: 7 additions & 7 deletions
@@ -79,7 +79,7 @@ def swiglu(x, y=None):
     )
 
 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
@@ -320,7 +320,7 @@ def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
 
-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")
     else:
         mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
@@ -340,7 +340,7 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     batch_size, src_length = mask.shape[0], mask.shape[-1]
     tgt_length = tgt_length if tgt_length is not None else src_length
 
-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = mask[:, None, None, :].astype(dtype)
     else:
         mask = mask[:, None, None, :].astype("bool")
@@ -667,7 +667,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )
 
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu"]:
            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -1429,7 +1429,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                     combined_attention_mask = _make_causal_mask(
                         input_shape, past_key_values_length=past_key_values_length
                     )
-                    if get_env_device() == "npu":
+                    if get_env_device() == "npu" or get_env_device() == "mlu":
                         expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                     else:
                         expanded_attn_mask = expanded_attn_mask & combined_attention_mask
@@ -1442,7 +1442,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        if get_env_device() == "npu":
+        if get_env_device() == "npu" or get_env_device() == "mlu":
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
@@ -1594,7 +1594,7 @@ def forward(
                 is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
-            if get_env_device() != "npu":
+            if get_env_device() != "npu" or get_env_device() != "mlu":
                 if is_casual and alibi is None:
                     attention_mask = None
         else:
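Most of the modeling.py edits extend existing "npu"-only branches to also cover "mlu", for example building the causal mask as int32 rather than bool. A minimal sketch of that branch, assuming a standalone helper; make_causal_mask below is an illustrative stand-in for modeling.py's _make_causal_mask, not the actual function:

import paddle

def make_causal_mask(target_length, device_type):
    if device_type in ("npu", "mlu"):
        # These custom devices take the lower-triangular mask as int32;
        # callers cast to bool later when combining it with the padding mask.
        return paddle.tril(paddle.ones((target_length, target_length))).astype("int32")
    # Other devices build the mask directly as a bool tensor.
    return paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))

As a general style note, a single membership test such as get_env_device() in ("npu", "mlu") expresses the repeated get_env_device() == "npu" or get_env_device() == "mlu" checks in one condition, and get_env_device() not in ("npu", "mlu") is the corresponding way to say "neither npu nor mlu".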

paddlenlp/utils/tools.py

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ def get_env_device():
         return "gpu"
     elif "npu" in paddle.device.get_all_custom_device_type():
         return "npu"
+    elif "mlu" in paddle.device.get_all_custom_device_type():
+        return "mlu"
     elif "gcu" in paddle.device.get_all_custom_device_type():
         return "gcu"
     elif paddle.is_compiled_with_rocm():
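get_env_device() is the single switch that the llama code paths query, so the new "mlu" branch here is what activates every device check above. A hedged usage sketch, assuming PaddlePaddle and the Cambricon MLU custom-device plugin are installed (the snippet is illustrative only):

from paddlenlp.utils.tools import get_env_device

device = get_env_device()
if device == "mlu":
    # Cambricon MLU registered as a custom device: the fused rms_norm_mlu
    # kernel and the int32 causal-mask path become active.
    print("Running on MLU")
elif device in ("npu", "gcu", "gpu"):
    print(f"Running on {device}")
else:
    print("No supported accelerator detected")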
