2 files changed: +7 −2 lines changed

@@ -1211,6 +1211,8 @@ def sample(
             probs = TopPProcess(probs, top_p, min_tokens_to_keep)
         if paddle.device.is_compiled_with_custom_device("gcu"):
             probs = paddle.cast(probs, "float32")
+        if paddle.device.is_compiled_with_xpu():
+            probs = paddle.cast(probs, "float32")
 
         # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
         next_tokens = paddle.multinomial(probs)
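The new branch mirrors the existing GCU handling: on an XPU build, the sampling probabilities are cast to float32 before paddle.multinomial. A minimal standalone sketch of that pattern, assuming (this rationale is not stated in the diff) that the multinomial kernel on these devices only handles float32 reliably; the helper name sample_next_tokens is illustrative, not from the source:

import paddle

def sample_next_tokens(probs):
    # probs typically arrives as float16/bfloat16 from the upstream softmax.
    # Mirror the diff: on GCU (and, after this change, XPU) builds, cast to
    # float32 before sampling with multinomial.
    if paddle.device.is_compiled_with_custom_device("gcu") or paddle.device.is_compiled_with_xpu():
        probs = paddle.cast(probs, "float32")
    return paddle.multinomial(probs)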
@@ -42,7 +42,7 @@
     load_state_dict,
 )
 from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
-from ...utils.distributed import distributed_gather
+from ...utils.distributed import distributed_allgather, distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
 from ...utils.tools import get_env_device
@@ -329,7 +329,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
         for key in trainable_state_dict:
             tensor = trainable_state_dict[key]
             if key in trainable_name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = trainable_name_action_mappings[key]
                 if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
                     ret = paddle.to_tensor(ret)
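In the tensor-parallel merge, gathering each trainable shard becomes device-dependent: XPU takes an all-gather path, every other device keeps the original gather. A minimal sketch of that branching, reusing the helper names that appear in the diff (distributed_allgather, distributed_gather, get_env_device); the paddlenlp.* import paths are my resolution of the relative imports, and the rationale that a plain gather collective is unavailable on XPU is an assumption:

from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.tools import get_env_device

def gather_shard(tensor, mp_group):
    # Illustrative helper mirroring the diff: XPU builds receive the shards on
    # every rank via all-gather (assumed workaround), other devices keep the
    # gather-to-destination-rank behaviour.
    if get_env_device() == "xpu":
        return distributed_allgather(tensor, group=mp_group, offload=True)
    return distributed_gather(tensor, group=mp_group, offload=True)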