Skip to content

Commit daf2f3a

Browse files
authored
fix load rng compatibility. (#8450)
1 parent ebe397e commit daf2f3a

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,16 +1607,13 @@ def _load_rng_state(self, checkpoint):
16071607
if os.path.isfile(rng_file):
16081608
rng_file_list = paddle.load(rng_file, return_numpy=True)
16091609
paddle.distributed.broadcast_object_list(rng_file_list, src=0)
1610-
# if rng_file_list still empty, then use old style rng_state
1610+
# if rng_file_list still empty, not log rng state.
16111611
if rng_file_list[0] is None:
1612-
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
1613-
if not os.path.isfile(rng_file):
1614-
logger.info(
1615-
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
1616-
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
1617-
)
1618-
return
1619-
checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
1612+
logger.info(
1613+
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
1614+
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
1615+
)
1616+
return
16201617
else:
16211618
checkpoint_rng_state = rng_file_list[process_index]
16221619
else:

0 commit comments

Comments
 (0)