We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cdfa861 commit 1f82403Copy full SHA for 1f82403
paddlenlp/trainer/trainer.py
@@ -1548,7 +1548,9 @@ def _load_rng_state(self, checkpoint):
1548
if not len(checkpoint_rng_state["cuda"]) == core.get_custom_device_count(device):
1549
raise ValueError("Length of custom device state list shoule be equal to the custom device count")
1550
for i in range(core.get_custom_device_count(device)):
1551
- core.default_custom_device_generator(i).manual_seed(checkpoint_rng_state["cuda"][i])
+ core.default_custom_device_generator(paddle.CustomPlace(device, i)).manual_seed(
1552
+ checkpoint_rng_state["cuda"][i]
1553
+ )
1554
1555
if self.args.use_hybrid_parallel:
1556
if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
0 commit comments