@@ -79,7 +79,7 @@ def swiglu(x, y=None):
 )
 
 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
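For readers unfamiliar with the custom-device plumbing: the block this hunk extends scans `CUSTOM_DEVICE_ROOT` for compiled custom-operator libraries so that NPU/MLU/GCU fused kernels become callable from Python. A minimal sketch of that pattern, assuming Paddle's `load_op_meta_info_and_register_op` helper is the loader (the exact call in the file may differ):

```python
import os

# assumed loader; Paddle exposes it under paddle.utils.cpp_extension.extension_utils
from paddle.utils.cpp_extension.extension_utils import load_op_meta_info_and_register_op

custom_root = os.getenv("CUSTOM_DEVICE_ROOT", "")
if custom_root and os.path.isdir(custom_root):
    for lib in os.listdir(custom_root):
        if lib.endswith(".so"):
            # register the fused ops shipped with the npu/mlu/gcu toolkit
            load_op_meta_info_and_register_op(os.path.join(custom_root, lib))
```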
@@ -311,7 +311,7 @@ def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len
 
-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")
     else:
         mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
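As a quick illustration of the branch above: on NPU/MLU the lower-triangular mask is materialized as `int32` rather than `bool` (presumably to sidestep weak bool-tensor support on those backends), while other devices keep the `bool` form. A self-contained sketch with `target_length = 4`:

```python
import paddle

target_length = 4

# npu / mlu path: float ones -> tril -> int32
int_mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")

# default path: build the triangle directly as bool
bool_mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))

# both describe the same causal structure: row i may attend to columns 0..i
print(int_mask.numpy())
print(bool_mask.numpy())
```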
@@ -331,7 +331,7 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     batch_size, src_length = mask.shape[0], mask.shape[-1]
     tgt_length = tgt_length if tgt_length is not None else src_length
 
-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = mask[:, None, None, :].astype(dtype)
     else:
         mask = mask[:, None, None, :].astype("bool")
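For context, this helper lifts a 2-D padding mask of shape `[batch, src_len]` into a 4-D mask with singleton head and query axes so it can broadcast against the causal mask; the hunk only changes which dtype the NPU/MLU path casts to. A small sketch of the indexing step (example values are made up):

```python
import paddle

# 1 = real token, 0 = padding; shape [batch, src_len] = [2, 5]
pad_mask = paddle.to_tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])

# [batch, src_len] -> [batch, 1, 1, src_len]; the singleton axes broadcast over
# attention heads and query positions when combined with the causal mask
expanded = pad_mask[:, None, None, :].astype("bool")
print(expanded.shape)  # [2, 1, 1, 5]
```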
@@ -657,7 +657,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )
 
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -1399,7 +1399,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                     combined_attention_mask = _make_causal_mask(
                         input_shape, past_key_values_length=past_key_values_length
                     )
-                    if get_env_device() == "npu":
+                    if get_env_device() == "npu" or get_env_device() == "mlu":
                         expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                     else:
                         expanded_attn_mask = expanded_attn_mask & combined_attention_mask
@@ -1412,7 +1412,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        if get_env_device() == "npu":
+        if get_env_device() == "npu" or get_env_device() == "mlu":
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
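Both `_prepare_decoder_attention_mask` hunks feed the step the context comment describes: the combined boolean mask becomes an additive float mask (0 where attention is allowed, `finfo(dtype).min` where it is blocked) that is later added to the attention scores. A standalone sketch of the NPU/MLU-style float32 detour shown above, with `dtype` assumed to be `paddle.float16`:

```python
import paddle

dtype = paddle.float16

# combined mask: True = may attend, False = masked out
mask = paddle.tril(paddle.ones((4, 4), dtype="bool"))

# npu / mlu path: pick 0.0 / finfo(dtype).min in float32, then cast down
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
additive = paddle.where(mask.astype("float32") > 0.5, x, y).astype(dtype)

print(additive)  # 0 on and below the diagonal, a large negative value above it
```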
@@ -1549,7 +1549,7 @@ def forward(
         is_casual = False
         if self.config.use_flash_attention and get_env_device() != "gcu":
             is_casual = is_casual_mask(attention_mask)
-            if get_env_device() != "npu":
+            if get_env_device() != "npu" or get_env_device() != "mlu":
                 if is_casual and alibi is None:
                     attention_mask = None
         else:
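One detail worth a second look in this last hunk: `get_env_device() != "npu" or get_env_device() != "mlu"` is always true, because the device string can never equal both `"npu"` and `"mlu"` at once, so the rewritten guard no longer excludes either backend from dropping the mask. If the intent is to keep the attention mask on NPU and MLU, matching the membership style used in the other hunks, the check would presumably look like this hypothetical form:

```python
def maybe_drop_attention_mask(attention_mask, is_casual: bool, alibi, env_device: str):
    # hypothetical corrected guard, using the `not in [...]` form of the other hunks
    if env_device not in ["npu", "mlu"]:
        if is_casual and alibi is None:
            return None
    return attention_mask
```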