
Commit 6f79f16

Browse files
cherry-pick 7876 7895 7894 from develop (#8009)
1 parent 724b524 commit 6f79f16

File tree

2 files changed: +5 -3 lines changed

paddlenlp/trainer/trainer.py

Lines changed: 4 additions & 2 deletions
@@ -631,7 +631,7 @@ def train(
         # The resume_from_checkpoint could be None in some machine node.
         # Here we reset None to temp directory.
         if args.world_size > 1:
-            is_resume_from_checkpoint = paddle.to_tensor([resume_from_checkpoint is not None])
+            is_resume_from_checkpoint = paddle.to_tensor([resume_from_checkpoint is not None], dtype="int32")
             paddle.distributed.all_reduce(is_resume_from_checkpoint)
             is_resume_from_checkpoint = is_resume_from_checkpoint.item()
             if is_resume_from_checkpoint > 0 and is_resume_from_checkpoint < paddle.distributed.get_world_size():
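
A minimal standalone sketch of the consensus check this hunk fixes (the helper name is hypothetical; it assumes paddle.distributed has already been initialized). Creating the flag explicitly as int32 avoids relying on all_reduce support for bool tensors:

import paddle
import paddle.distributed as dist

def all_ranks_agree_on_resume(resume_from_checkpoint):
    # 1 on ranks that found a checkpoint, 0 elsewhere; int32 so that
    # all_reduce (a sum across ranks) is well defined on every backend
    flag = paddle.to_tensor([resume_from_checkpoint is not None], dtype="int32")
    dist.all_reduce(flag)
    found = flag.item()
    world_size = dist.get_world_size()
    # 0 < found < world_size means only some ranks saw a checkpoint,
    # which is the mixed case the trainer has to repair
    return not (0 < found < world_size)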
@@ -1556,7 +1556,9 @@ def _load_rng_state(self, checkpoint):
             if not len(checkpoint_rng_state["cuda"]) == core.get_custom_device_count(device):
                 raise ValueError("Length of custom device state list shoule be equal to the custom device count")
             for i in range(core.get_custom_device_count(device)):
-                core.default_custom_device_generator(i).manual_seed(checkpoint_rng_state["cuda"][i])
+                core.default_custom_device_generator(paddle.CustomPlace(device, i)).manual_seed(
+                    checkpoint_rng_state["cuda"][i]
+                )

         if self.args.use_hybrid_parallel:
             if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:

paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -1524,7 +1524,7 @@ def forward(self, prediction_scores, masked_lm_labels):
                 _hcg = fleet.get_hybrid_communicate_group()
                 masked_lm_loss = ConcatSePMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group())
             # skip ignore_index which loss == 0
-            masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
+            masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
             loss = paddle.mean(masked_lm_loss)

         return loss
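
This change drops the float32 upcast when masking out ignore_index positions, so the selected loss keeps the dtype it was computed in. A minimal sketch of the resulting reduction, with toy values rather than model output:

import paddle

masked_lm_loss = paddle.to_tensor([0.0, 1.5, 0.0, 2.5])
# boolean-mask selection drops the ignore_index positions, whose loss is 0
masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
loss = paddle.mean(masked_lm_loss)  # mean over kept tokens only -> 2.0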
