Skip to content

Commit 534ec7e

Browse files
authored
fix dit training (#752)
1 parent 5b7c997 commit 534ec7e

File tree

3 files changed

+405
-16
lines changed

3 files changed

+405
-16
lines changed

ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
from .dit import DiT
1717
from .dit_llama import DiT_Llama
1818
from .respace import SpacedDiffusion, space_timesteps
19-
from .trainer import LatentDiffusionTrainer, LatentDiffusionAutoTrainer
19+
from .trainer import LatentDiffusionTrainer
20+
try:
21+
from paddlenlp.trainer.auto_trainer import AutoTrainer
22+
from .trainer_auto import LatentDiffusionAutoTrainer
23+
except:
24+
print(f'please install paddlepaddle-gpu>=3.0.0b2 if using auto trainer')
25+
2026
from .trainer_args import (
2127
DataArguments,
2228
ModelArguments,

ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from paddle.distributed import fleet
2424
from paddle.io import get_worker_info
2525
from paddlenlp.trainer import Trainer
26-
from paddlenlp.trainer.auto_trainer import AutoTrainer
2726
from paddlenlp.trainer.integrations import (
2827
INTEGRATION_TO_CALLBACK,
2928
TrainerCallback,
@@ -295,20 +294,6 @@ def __impl__():
295294

296295
return __impl__
297296

298-
class LatentDiffusionAutoTrainer(AutoTrainer):
299-
def __init__(self, *args, **kwargs):
300-
super().__init__(*args, **kwargs)
301-
302-
def _get_meshes_for_loader(self):
303-
def _get_mesh(pp_idx=0):
304-
return fleet.auto.get_mesh().get_mesh_with_dim("pp")[pp_idx]
305-
306-
return _get_mesh(0) # label_id is not label
307-
308-
def _wrap_for_dist_loader(self, train_dataloader):
309-
dist_loader = super()._wrap_for_dist_loader(train_dataloader)
310-
dist_loader._input_keys = ["latents", "label_id"]
311-
return dist_loader
312297

313298
class LatentDiffusionTrainer(Trainer):
314299
def __init__(self, **kwargs):

0 commit comments

Comments
 (0)