Skip to content

Commit 88d4b19

Browse files
authored
[XPU] use allgather and fp32 multinomial for XPU (#8787)
1 parent e8e59d0 commit 88d4b19

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

paddlenlp/generation/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,8 @@ def sample(
12111211
probs = TopPProcess(probs, top_p, min_tokens_to_keep)
12121212
if paddle.device.is_compiled_with_custom_device("gcu"):
12131213
probs = paddle.cast(probs, "float32")
1214+
if paddle.device.is_compiled_with_xpu():
1215+
probs = paddle.cast(probs, "float32")
12141216

12151217
# multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
12161218
next_tokens = paddle.multinomial(probs)

paddlenlp/peft/lora/lora_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
load_state_dict,
4343
)
4444
from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
45-
from ...utils.distributed import distributed_gather
45+
from ...utils.distributed import distributed_allgather, distributed_gather
4646
from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
4747
from ...utils.log import logger
4848
from ...utils.tools import get_env_device
@@ -329,7 +329,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
329329
for key in trainable_state_dict:
330330
tensor = trainable_state_dict[key]
331331
if key in trainable_name_action_mappings:
332-
ret = distributed_gather(tensor, group=mp_group, offload=True)
332+
if get_env_device() == "xpu":
333+
ret = distributed_allgather(tensor, group=mp_group, offload=True)
334+
else:
335+
ret = distributed_gather(tensor, group=mp_group, offload=True)
333336
action = trainable_name_action_mappings[key]
334337
if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
335338
ret = paddle.to_tensor(ret)

0 commit comments

Comments (0)