File tree Expand file tree Collapse file tree 3 files changed +405
-16
lines changed
ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion Expand file tree Collapse file tree 3 files changed +405
-16
lines changed Original file line number Diff line number Diff line change 16
16
from .dit import DiT
17
17
from .dit_llama import DiT_Llama
18
18
from .respace import SpacedDiffusion , space_timesteps
19
- from .trainer import LatentDiffusionTrainer , LatentDiffusionAutoTrainer
19
+ from .trainer import LatentDiffusionTrainer
20
+ try :
21
+ from paddlenlp .trainer .auto_trainer import AutoTrainer
22
+ from .trainer_auto import LatentDiffusionAutoTrainer
23
+ except :
24
+ print (f'please install paddlepaddle-gpu>=3.0.0b2 if using auto trainer' )
25
+
20
26
from .trainer_args import (
21
27
DataArguments ,
22
28
ModelArguments ,
Original file line number Diff line number Diff line change 23
23
from paddle .distributed import fleet
24
24
from paddle .io import get_worker_info
25
25
from paddlenlp .trainer import Trainer
26
- from paddlenlp .trainer .auto_trainer import AutoTrainer
27
26
from paddlenlp .trainer .integrations import (
28
27
INTEGRATION_TO_CALLBACK ,
29
28
TrainerCallback ,
@@ -295,20 +294,6 @@ def __impl__():
295
294
296
295
return __impl__
297
296
298
- class LatentDiffusionAutoTrainer (AutoTrainer ):
299
- def __init__ (self , * args , ** kwargs ):
300
- super ().__init__ (* args , ** kwargs )
301
-
302
- def _get_meshes_for_loader (self ):
303
- def _get_mesh (pp_idx = 0 ):
304
- return fleet .auto .get_mesh ().get_mesh_with_dim ("pp" )[pp_idx ]
305
-
306
- return _get_mesh (0 ) # label_id is not label
307
-
308
- def _wrap_for_dist_loader (self , train_dataloader ):
309
- dist_loader = super ()._wrap_for_dist_loader (train_dataloader )
310
- dist_loader ._input_keys = ["latents" , "label_id" ]
311
- return dist_loader
312
297
313
298
class LatentDiffusionTrainer (Trainer ):
314
299
def __init__ (self , ** kwargs ):
You can’t perform that action at this time.
0 commit comments