Skip to content

Fix EMA and make it compatible with deepspeed. #813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 12, 2022
Merged

Conversation

patil-suraj
Copy link
Contributor

@patil-suraj patil-suraj commented Oct 12, 2022

There's an issue with current EMA in multi-gpu and deepspeed. This PR updates the EMAModel to only keep the parameters instead of copying the model which doesn't seem to work with deepspeed.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 12, 2022

The documentation is not available anymore as the PR was closed or merged.

if device is not None:
self.averaged_model = self.averaged_model.to(device=device)
parameters = list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never heard them being called shadow, but pretty creative :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@anton-l anton-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic looks good to me!
@patil-suraj make sure it works with fp16 (if it's an option) for the whole training run. Subtracting the decayed parameters might become numerically unstable when getting close to 0.9999

@patil-suraj
Copy link
Contributor Author

Will do a run in fp16, but our fp16 is mixed precision, so params are always fp32

@pink-red
Copy link
Contributor

Maybe this would be helpful. I've got DeepSpeed working on my 12 GB 3060 by changing this line (like in https://github.com/huggingface/diffusers/pull/735/files#diff-8702f762e46a3b5363085930b0b045de554909d32560864031ca7b12ddd349d5R555):

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index e4a91ff..4481951 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -566,7 +566,7 @@ def main():
 
                 # Predict the noise residual and compute loss
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(noise_pred, noise, reduction="mean")
+                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()

By "working" I mean that it doesn't throw an exception "Found dtype Float but expected Half" at this line:

accelerator.backward(loss)

I should say that I'm still training the model on the pokemon dataset, so I don't know what the actual result would be yet.

The command I've used for training is almost identical to the one in readme, I've only added accelerate launch params:

accelerate launch --use_deepspeed --zero_stage=2 --gradient_accumulation_steps=1 --offload_param_device=cpu --offload_optimizer_device=cpu train_text_to_image.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --dataset_name="lambdalabs/pokemon-blip-captions" \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"

@patil-suraj
Copy link
Contributor Author

Thanks a lot @pink-red , would you like to open a PR for that once this is merged. Indeed, casting it float is required for deepspeed.

@pink-red
Copy link
Contributor

@patil-suraj No problem! 👌

@patrickvonplaten
Copy link
Contributor

Referring my review here to @anton-l as he knows EMA much better :-)

@patrickvonplaten
Copy link
Contributor

Also @patil-suraj let's maybe fix the code quality with make style no?

@patil-suraj patil-suraj merged commit 008b608 into main Oct 12, 2022
@patil-suraj patil-suraj deleted the txt2imag2-fix-eam branch October 12, 2022 17:13
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
…ggingface#813)

* fix ema

* style

* add comment about copy

* style

* quality
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants