Skip to content

[npu]sd3 dreambooth adapt for npu #726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion ppdiffusers/examples/dreambooth/README_sd3.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,37 @@ pipeline.load_lora_weights('your-lora-checkpoint')

image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog_dreambooth_lora.png")
```
```

## NPU硬件训练
1. 请先参照[PaddleCustomDevice](https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/npu/README_cn.md)安装NPU硬件Paddle
2. 使用NPU进行LoRA训练和推理时参考如下命令设置相应的环境变量,训练和推理运行命令可直接参照上述LoRA训练和推理命令。

使用NPU进行LoRA训练和推理时参考如下命令设置相应的环境变量,训练和推理运行命令可直接参照上述LoRA训练和推理命令。
```bash
export FLAGS_npu_storage_format=0
export FLAGS_use_stride_kernel=0
export FLAGS_npu_scale_aclnn=True
export FLAGS_allocator_strategy=auto_growth
```
训练(DreamBooth微调)时如果显存不够,可以尝试添加参数(训练完成后不进行评测)`not_validation_final`, 并去除`validation_prompt`,具体命令如下所示
```
python train_dreambooth_sd3.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=50 \
--validation_epochs=25 \
--seed="0" \
--checkpointing_steps=250 \
--not_validation_final
```
63 changes: 35 additions & 28 deletions ppdiffusers/examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ def parse_args(input_args=None):
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--not_validation_final",
default=False,
action="store_true",
help="Flag to not validation when train finish in order to save memory.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
Expand Down Expand Up @@ -1525,36 +1531,37 @@ def get_state_dict(model):
save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers
)

pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
paddle_dtype=weight_dtype,
)
if not args.not_validation_final:
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
paddle_dtype=weight_dtype,
)

# Final inference
# Load previous pipeline
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
paddle_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
# Final inference
# Load previous pipeline
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
paddle_dtype=weight_dtype,
)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
)

accelerator.end_training()

Expand Down
80 changes: 44 additions & 36 deletions ppdiffusers/examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,12 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--not_validation_final",
default=False,
action="store_true",
help="Flag to not validation when train finish in order to save memory.",
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -1583,48 +1589,50 @@ def get_state_dict(model):

accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
if not args.not_validation_final:
transformer = unwrap_model(transformer)

if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_two = unwrap_model(text_encoder_two)
text_encoder_three = unwrap_model(text_encoder_three)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=transformer,
text_encoder=text_encoder_one,
text_encoder_2=text_encoder_two,
text_encoder_3=text_encoder_three,
)
else:
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path, transformer=transformer
)
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_two = unwrap_model(text_encoder_two)
text_encoder_three = unwrap_model(text_encoder_three)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=transformer,
text_encoder=text_encoder_one,
text_encoder_2=text_encoder_two,
text_encoder_3=text_encoder_three,
)
else:
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path, transformer=transformer
)

# save the pipeline
pipeline.save_pretrained(args.output_dir)
# save the pipeline
pipeline.save_pretrained(args.output_dir)

# Final inference
# Load previous pipeline
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.output_dir,
revision=args.revision,
variant=args.variant,
paddle_dtype=weight_dtype,
)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
# if not args.not_validation_final:
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.output_dir,
revision=args.revision,
variant=args.variant,
paddle_dtype=weight_dtype,
)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
)

accelerator.end_training()

Expand Down