
Dreambooth broken, possibly because of ADAM optimizer, possibly more. #712

Closed
@affableroots


I think Huggingface's Dreambooth is the only popular SD implementation that also uses Prior Preservation Loss, so I've been motivated to get it working, but the results have been terrible: the entire model degrades regardless of the number of timesteps, the learning rate, PPL on/off, the number of instance samples, the number of class regularization samples, etc. From reading the paper, they actually unfreeze everything, including the text embedder (and the VAE? I'm not sure, so I leave it frozen), so I implemented textual inversion within the dreambooth example (a new token, with only that single row of the embedder unfrozen; rough sketch below). That improves results considerably, but the whole model still degrades no matter what.
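For concreteness, the "new token, unfreeze a single row" part looks roughly like this. This is only a sketch, not the example script's actual code; `tokenizer` and `text_encoder` stand in for however the script names them, and `<sks>` is just a placeholder token:

```python
import torch

# Sketch of "textual inversion inside dreambooth": add a new token and let
# only its embedding row receive gradient updates. `tokenizer` / `text_encoder`
# are whatever the training script already built (CLIP tokenizer + text encoder).

placeholder_token = "<sks>"  # hypothetical placeholder token
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)

embedding = text_encoder.get_input_embeddings()  # the token-embedding nn.Embedding

def keep_only_new_row(grad):
    # After backward, zero the gradient of every row except the new token's,
    # so only that one embedding vector actually moves.
    mask = torch.zeros_like(grad)
    mask[placeholder_id] = 1.0
    return grad * mask

embedding.weight.register_hook(keep_only_new_row)
```

Note that even with the gradients masked like this, AdamW's weight decay still shrinks every other embedding row a little on each step, which ties into the optimizer hypothesis below.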

Someone smarter than me can confirm, but I think the culprit is ADAM:

optimizer_class = torch.optim.AdamW

My hypothesis is that AdamW's weight decay drags every weight of the unet (etc.) toward 0, so the parts of the model that aren't getting a useful gradient signal during the finetuning slowly get ruined anyway.
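A tiny toy check of that hypothesis, completely separate from the training script, just plain torch:

```python
import torch

# Toy check: AdamW's decoupled weight decay shrinks a parameter toward 0
# even when the data provides no gradient signal for it at all.
p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.AdamW([p], lr=1e-3, weight_decay=1e-2)

for _ in range(1000):
    p.grad = torch.zeros_like(p)  # pretend this weight never sees a useful gradient
    opt.step()

# Each step multiplies p by (1 - lr * weight_decay) regardless of the data,
# so over a long finetune the "untouched" parts of the model quietly shrink.
print(p.data)
```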

I've tested with weight_decay set to 0, and results seem considerably better, but I think the entire model is still degrading. I'm trying SGD next (sketch of both variants below), so, fingers crossed, but there may still be some dragon lurking in the depths even after swapping out AdamW.
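Concretely, the two variants I've been comparing are just optimizer swaps like this, with everything else in the script untouched (the toy `model` below stands in for the unet / text-encoder params the script actually passes):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the unet (+ text encoder) parameters
params_to_optimize = model.parameters()
learning_rate = 5e-5  # placeholder; use whatever the script is configured with

# Pick one of the two; the rest of the training loop stays the same.

# Variant 1: keep AdamW but turn off the decoupled weight decay.
optimizer = torch.optim.AdamW(params_to_optimize, lr=learning_rate, weight_decay=0.0)

# Variant 2: plain SGD (the lr / momentum values I ended up trying in the edits below).
optimizer = torch.optim.SGD(params_to_optimize, lr=1e-3, momentum=0.9)
```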

A point of reference on this journey is the JoePenna "Dreambooth" repo, which doesn't implement PPL and yet preserves priors much better than this example; it also learns the instance better, is far more editable, and preserves out-of-class concepts rather well. I expect at least as much from this huggingface dreambooth example, and I'm trying to find out why it isn't delivering.

Any thoughts or guidance?

EDIT1A: SGD didn't learn the instance at 1000 steps + lr=5e-5, but it definitely preserved the priors way better (on visual inspection; the loss really doesn't decrease much in any of my inversion/dreambooth experiments).

EDIT1B: Another test with SGD at 1500 steps + lr=1e-3 + momentum=0.9 also failed to learn the instance. It might be learning a little, but not much. Priors were still nicely preserved, though.

EDIT1C: 1500 steps + lr=5e2 learned great, was editable, and didn't destroy other priors!!!

EDIT2: JoePenna seems to use AdamW, so I'm not sure what's up anymore, but I'm still getting quite poor results training with this library's (huggingface's) DB example.
