
Commit 9f8b13f

patil-suraj authored and Prathik Rao committed
[train_text2image] Fix EMA and make it compatible with deepspeed. (huggingface#813)
* fix ema
* style
* add comment about copy
* style
* quality
1 parent 1b6749e commit 9f8b13f

File tree

1 file changed: +46 -40 lines changed


examples/text_to_image/train_text_to_image.py

Lines changed: 46 additions & 40 deletions
@@ -1,11 +1,10 @@
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional

 import numpy as np
 import torch
@@ -234,25 +233,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 }


+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """

-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]

         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ def get_decay(self, optimization_step):
         return 1 - min(self.decay, value)

     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)

         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)

-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
+                s_param.copy_(param)

-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         torch.cuda.empty_cache()

+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+

 def main():
     args = parse_args()
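For readers skimming the class change above: the EMA now tracks a flat list of shadow parameter tensors instead of deep-copying the whole model, which is presumably what the deepspeed compatibility in the commit title refers to, since the wrapped unet no longer needs to be copied or have its state dict walked by key. A minimal sketch of how the new parameter-based API is driven, outside the training script (ToyNet, the optimizer, and the loop are illustrative stand-ins, and the sketch assumes the EMAModel class added above is defined or imported in the same module):

import torch

# Assumes the EMAModel class added in this commit is defined or imported in this module.


class ToyNet(torch.nn.Module):  # hypothetical stand-in for the UNet
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = ToyNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# The EMA snapshots a flat list of parameter tensors, not a deep copy of the model.
ema = EMAModel(model.parameters(), decay=0.9999)

for _ in range(10):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())  # update the shadow parameters after each optimizer step

# Before saving or evaluating, overwrite the live weights with the averaged ones.
ema.copy_to(model.parameters())

In the training script below this maps onto ema_unet = EMAModel(unet.parameters()), ema_unet.step(unet.parameters()) after each synced optimizer step, and ema_unet.copy_to(unet.parameters()) on the unwrapped unet right before the pipeline is saved.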
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def collate_fn(examples):
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def collate_fn(examples):
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
             tokenizer=tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
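One more API note, hedged since the script never exercises it: the rewritten class also gains a to(device, dtype) method for relocating the shadow parameters. The script does not need it because the shadow copies are cloned from wherever unet.parameters() already live, but a sketch of the call, reusing the ema_unet and accelerator names from the script above:

# Hypothetical usage, not part of this commit: relocate the EMA shadow parameters.
ema_unet.to(device="cpu")                # e.g. keep the averaged weights in CPU memory
ema_unet.to(device=accelerator.device)   # or move them back alongside the training unet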
