 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional

 import numpy as np
 import torch
@@ -234,25 +233,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 }


+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """

-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]

         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ def get_decay(self, optimization_step):
         return 1 - min(self.decay, value)

     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)

         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)

-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
+                s_param.copy_(param)

-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         torch.cuda.empty_cache()

+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+

 def main():
     args = parse_args()
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def collate_fn(examples):
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def collate_fn(examples):
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
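
For reference, a minimal sketch of how the refactored parameter-based EMAModel is meant to be driven from a training loop, using only the API introduced in this commit; the unet, optimizer, train_dataloader, and compute_loss names below are placeholders, not part of the diff:

# Hypothetical training-loop wiring for the new parameter-based EMA API.
ema_unet = EMAModel(unet.parameters())    # shadow copies of the current weights

for batch in train_dataloader:            # placeholder data loop
    loss = compute_loss(unet, batch)      # placeholder loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_unet.step(unet.parameters())      # fold the fresh weights into the shadow params

ema_unet.copy_to(unet.parameters())       # load the averaged weights back before saving

Keeping only flat shadow tensors (rather than a deep-copied model) avoids holding a second full module in memory and lets the EMA state follow whatever parameters the caller passes in.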