@@ -79,7 +79,7 @@ def swiglu(x, y=None):
     )

 try:
-    if get_env_device() in ["npu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu"]:

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
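For context, this hunk widens the device check that gates the loading of vendor-specific custom operators. A minimal sketch of what such a loop typically does, assuming the body (cut off in the hunk above) registers each shared library through Paddle's C++ extension utilities; the registration call is an assumption, not quoted from the PR:

```python
import os

from paddle.utils.cpp_extension.extension_utils import load_op_meta_info_and_register_op

# Sketch only: scan the PaddleCustomDevice install root for compiled operator
# libraries and register each one with Paddle. CUSTOM_DEVICE_ROOT is set by
# custom-device packages such as the npu/mlu/gcu backends.
custom_root = os.getenv("CUSTOM_DEVICE_ROOT", "")
if custom_root:
    for lib in os.listdir(custom_root):
        if lib.endswith(".so"):
            # Assumed registration call; the actual loop body is not shown in the hunk.
            load_op_meta_info_and_register_op(os.path.join(custom_root, lib))
```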
@@ -320,7 +320,7 @@ def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     batch_size, target_length = input_ids_shape  # target_length: seq_len

-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")
     else:
         mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
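As a quick illustration of the branch above (not part of the PR): the causal mask is a lower-triangular matrix, and on the custom devices it is materialized as int32 rather than bool, presumably because some bool kernels are not available there. A toy example:

```python
import paddle

target_length = 4

# Default path: a bool lower-triangular mask.
bool_mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))

# npu/mlu path: the same mask materialized as int32.
int_mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32")

print(int_mask.numpy())
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```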
@@ -340,7 +340,7 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     batch_size, src_length = mask.shape[0], mask.shape[-1]
     tgt_length = tgt_length if tgt_length is not None else src_length

-    if get_env_device() == "npu":
+    if get_env_device() == "npu" or get_env_device() == "mlu":
         mask = mask[:, None, None, :].astype(dtype)
     else:
         mask = mask[:, None, None, :].astype("bool")
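Similarly, `_expand_2d_mask` broadcasts a per-token padding mask of shape `[batch_size, src_length]` into the 4-D attention layout; on npu/mlu it keeps a numeric dtype instead of bool. An illustrative sketch:

```python
import paddle

# A [batch_size, src_length] padding mask: 1 = real token, 0 = padding.
mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")

# [batch, src_len] -> [batch, 1, 1, src_len], ready to broadcast against
# [batch, num_heads, tgt_len, src_len] attention scores.
expanded_bool = mask[:, None, None, :].astype("bool")      # default path
expanded_float = mask[:, None, None, :].astype("float32")  # npu/mlu path keeps a numeric dtype

print(expanded_bool.shape)  # [1, 1, 1, 4]
```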
@@ -667,7 +667,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -1429,7 +1429,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                     combined_attention_mask = _make_causal_mask(
                         input_shape, past_key_values_length=past_key_values_length
                     )
-                    if get_env_device() == "npu":
+                    if get_env_device() == "npu" or get_env_device() == "mlu":
                         expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                     else:
                         expanded_attn_mask = expanded_attn_mask & combined_attention_mask
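For reference, the padding mask and the causal mask are merged with a logical AND; on npu/mlu both operands are first cast to bool because the masks were built with numeric dtypes there. A toy sketch:

```python
import paddle

# Padding mask for one sequence of length 4 whose last token is padding,
# already expanded to broadcast against [batch, heads, tgt_len, src_len].
padding = paddle.to_tensor([[1, 1, 1, 0]], dtype="int32")[:, None, None, :]

# Lower-triangular causal mask, built as int32 as on npu/mlu.
causal = paddle.tril(paddle.ones((4, 4))).astype("int32")

# npu/mlu path: cast both operands to bool before the logical AND.
combined = padding.astype("bool") & causal.astype("bool")
print(combined.astype("int32").numpy()[0, 0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
```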
@@ -1442,7 +1442,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        if get_env_device() == "npu":
+        if get_env_device() == "npu" or get_env_device() == "mlu":
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
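The branch above turns a keep/ignore mask into an additive float mask: kept positions contribute 0.0 and masked positions get the most negative representable value of the score dtype, so they vanish after softmax. A small CPU-runnable sketch, with float32 standing in for the real float16/bfloat16 compute dtype:

```python
import paddle

dtype = paddle.float32  # stand-in for the attention compute dtype

keep = paddle.to_tensor([[True, True, False]])  # True = attend, False = masked out

x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")

# Masked positions become a huge negative number, so they contribute ~0 after softmax.
additive_mask = paddle.where(keep, x, y).astype(dtype)
print(additive_mask.numpy())  # [[ 0.  0.  -3.4028235e+38]]
```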
@@ -1594,7 +1594,7 @@ def forward(
                 is_casual = True
             else:
                 is_casual = is_casual_mask(attention_mask)
-            if get_env_device() != "npu":
+            if get_env_device() != "npu" and get_env_device() != "mlu":
                 if is_casual and alibi is None:
                     attention_mask = None
         else:
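The final hunk decides whether the explicit attention mask can be dropped when flash attention is used: if the mask is purely causal and there is no alibi bias, the kernel's built-in causal mode is enough, except on npu/mlu where the mask is kept. A hedged reimplementation of that check; `looks_causal` is an illustrative stand-in for `is_casual_mask`:

```python
import paddle


def looks_causal(mask: paddle.Tensor) -> bool:
    """Illustrative stand-in for is_casual_mask: True if the [seq, seq] mask
    keeps exactly the lower-triangular positions."""
    seq_len = mask.shape[-1]
    causal = paddle.tril(paddle.ones((seq_len, seq_len), dtype=mask.dtype))
    return bool(paddle.all(mask == causal))


attention_mask = paddle.tril(paddle.ones((4, 4), dtype="int32"))
alibi = None
env_device = "gpu"  # hypothetical; on npu/mlu the mask is kept as-is

if env_device not in ("npu", "mlu"):
    if looks_causal(attention_mask) and alibi is None:
        # Flash attention applies causal masking itself, so skip the explicit mask.
        attention_mask = None

print(attention_mask)  # None on gpu in this toy case
```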