
Commit 582f109

zhangyuqin1998 authored and Mangodadada committed
[Auto Parallel] fix data stream bug of dist.to_static (PaddlePaddle#9077)
1 parent 549dcf8 commit 582f109

File tree

1 file changed: +6 −1 lines changed


paddlenlp/trainer/auto_trainer.py

Lines changed: 6 additions & 1 deletion
@@ -127,7 +127,12 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
-            model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy)
+            # dist.to_static() obtains the input spec information through next(dataloader), but this has side
+            # effects on the passed-in dataloader, altering the state of its sampler. In some cases, once the
+            # sampler state has been changed, it cannot be reverted. Therefore, a temporary dataloader is
+            # constructed here to avoid side effects on the dataloader used for actual training.
+            temp_loader = self._wrap_for_dist_loader(self.get_train_dataloader())
+            model = dist.to_static(model, temp_loader, self.criterion, self.optimizer, strategy=unified_strategy)
 
         self.model_wrapped = model
         return model, dist_loader
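
For readers outside the Paddle codebase, the failure mode the patch comment describes can be reproduced in a few lines of plain Python. The sketch below is illustrative only: OneShotLoader is a hypothetical stand-in for a dataloader whose sampler state cannot be rewound; it is not a class from PaddleNLP or Paddle.

    class OneShotLoader:
        """Mimics a dataloader whose sampler state cannot be reset."""

        def __init__(self, data):
            self._it = iter(data)  # the "sampler state" lives in this iterator

        def __iter__(self):
            return self._it  # re-iterating does NOT rewind the state

    loader = OneShotLoader([0, 1, 2, 3])

    # Spec inference peeks at one batch, as dist.to_static() does internally.
    peeked = next(iter(loader))  # consumes batch 0 as a side effect

    # The real training loop now silently starts at batch 1.
    print(peeked, list(loader))  # prints: 0 [1, 2, 3]

    # The fix in this commit: peek at a throwaway loader instead, so the
    # loader used for actual training still yields every batch.
    temp_loader = OneShotLoader([0, 1, 2, 3])  # analogous to temp_loader in the patch
    spec_batch = next(iter(temp_loader))       # the side effect hits only the throwaway

    training_loader = OneShotLoader([0, 1, 2, 3])
    print(list(training_loader))               # prints: [0, 1, 2, 3]

This is why the patch builds temp_loader from a fresh get_train_dataloader() call rather than reusing dist_loader: the peek for input specs lands on a loader that is discarded afterwards.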
