Skip to content

Commit 9f5ecb6

Browse files
YaYaBPrathik Rao
authored andcommitted
Fix push_to_hub for dreambooth and textual_inversion (huggingface#748)
* Fix push_to_hub for dreambooth and textual_inversion * Use repo.push_to_hub instead of push_to_hub
1 parent 81d39a5 commit 9f5ecb6

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,9 +575,7 @@ def collate_fn(examples):
575575
pipeline.save_pretrained(args.output_dir)
576576

577577
if args.push_to_hub:
578-
repo.push_to_hub(
579-
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
580-
)
578+
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
581579

582580
accelerator.end_training()
583581

examples/textual_inversion/textual_inversion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,9 +569,7 @@ def main():
569569
save_progress(text_encoder, placeholder_token_id, accelerator, args)
570570

571571
if args.push_to_hub:
572-
repo.push_to_hub(
573-
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
574-
)
572+
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
575573

576574
accelerator.end_training()
577575

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from accelerate.logging import get_logger
1010
from datasets import load_dataset
1111
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
12-
from diffusers.hub_utils import init_git_repo, push_to_hub
12+
from diffusers.hub_utils import init_git_repo
1313
from diffusers.optimization import get_scheduler
1414
from diffusers.training_utils import EMAModel
1515
from torchvision.transforms import (
@@ -190,7 +190,7 @@ def transforms(examples):
190190
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
191191
# save the model
192192
if args.push_to_hub:
193-
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
193+
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
194194
else:
195195
pipeline.save_pretrained(args.output_dir)
196196
accelerator.wait_for_everyone()

0 commit comments

Comments
 (0)