Skip to content

Commit 1f82403

Browse files
authored
[CustomDevice] fix loading rng state on custom device (#7894)
1 parent cdfa861 commit 1f82403

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,9 @@ def _load_rng_state(self, checkpoint):
15481548
if not len(checkpoint_rng_state["cuda"]) == core.get_custom_device_count(device):
15491549
raise ValueError("Length of custom device state list shoule be equal to the custom device count")
15501550
for i in range(core.get_custom_device_count(device)):
1551-
core.default_custom_device_generator(i).manual_seed(checkpoint_rng_state["cuda"][i])
1551+
core.default_custom_device_generator(paddle.CustomPlace(device, i)).manual_seed(
1552+
checkpoint_rng_state["cuda"][i]
1553+
)
15521554

15531555
if self.args.use_hybrid_parallel:
15541556
if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:

0 commit comments

Comments
 (0)