
text-to-image training: resuming from a checkpoint raises an error #3871

Closed
@yijinsheng

Description

Describe the bug

I am trying to train a text-to-image model with the script in diffusers/examples/text_to_image.
Training from scratch works,
but when I add the parameter --resume_from_checkpoint "checkpoint-10000",
I get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Reproduction

  1. Train a text-to-image model with the script:
export MODEL_NAME="/root/yjs/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c"
export TRAIN_DIR="/root/yjs/train_data"
export OUTPUT_DIR="./out_models"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=768 --center_crop --random_flip \
  --train_batch_size=8 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR} \
  --image_column="image" --caption_column='additional_feature'\
  --enable_xformers_memory_efficient_attention \
  --checkpointing_steps 1000 
  2. This produces many checkpoint folders in my out_models directory. I pick one, add the parameter --resume_from_checkpoint "checkpoint-10000", and run the following script:
export MODEL_NAME="/root/yjs/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c"
export TRAIN_DIR="/root/yjs/train_data"
export OUTPUT_DIR="./out_models"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=768 --center_crop --random_flip \
  --train_batch_size=8 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR} \
  --image_column="image" --caption_column='additional_feature'\
  --enable_xformers_memory_efficient_attention \
  --checkpointing_steps 1000 \
  --resume_from_checkpoint "checkpoint-10000"
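
For reference, the folder name passed to --resume_from_checkpoint encodes the optimizer step at which it was saved. The Python sketch below assumes folders are named checkpoint-<global_step> and written every --checkpointing_steps steps (inferred from the folder names in this report, not verified against the script). `expected_checkpoints`, `resolve_checkpoint` and `CKPT_RE` are hypothetical helpers, and treating "latest" as "highest step number" is an assumption about how that option behaves.

```python
import os
import re
import tempfile
from typing import List, Optional, Tuple

# Assumption: checkpoint folders are named "checkpoint-<global_step>", as in this report.
CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def expected_checkpoints(max_train_steps: int, checkpointing_steps: int) -> List[str]:
    """Hypothetical helper: folder names if a save happens every `checkpointing_steps` steps."""
    return [
        f"checkpoint-{step}"
        for step in range(checkpointing_steps, max_train_steps + 1, checkpointing_steps)
    ]


def resolve_checkpoint(output_dir: str, resume_from: str) -> Optional[Tuple[str, int]]:
    """Hypothetical helper: (folder name, global step) for a named checkpoint or "latest"."""
    if resume_from != "latest":
        # Accept either a bare folder name or a path ending in the folder name.
        name = os.path.basename(resume_from)
    else:
        found = [d for d in os.listdir(output_dir) if CKPT_RE.match(d)]
        if not found:
            return None
        # Sort numerically: a plain string sort would rank "checkpoint-9000" above "checkpoint-10000".
        name = max(found, key=lambda d: int(CKPT_RE.match(d).group(1)))
    match = CKPT_RE.match(name)
    if match is None:
        raise ValueError(f"not a checkpoint folder name: {name!r}")
    return name, int(match.group(1))


if __name__ == "__main__":
    names = expected_checkpoints(max_train_steps=15000, checkpointing_steps=1000)
    print(len(names), names[0], names[-1])  # 15 checkpoint-1000 checkpoint-15000

    with tempfile.TemporaryDirectory() as out:
        for n in ["checkpoint-9000", "checkpoint-10000", "logs"]:
            os.makedirs(os.path.join(out, n))
        print(resolve_checkpoint(out, "checkpoint-10000"))  # ('checkpoint-10000', 10000)
        print(resolve_checkpoint(out, "latest"))  # ('checkpoint-10000', 10000)
```

With the settings above (15000 steps, a save every 1000), this naming scheme yields 15 folders, so "checkpoint-10000" corresponds to global step 10000.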

Logs

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/yjs/diffusers/examples/text_to_image/train_text_to_image.py:792 in <module>                │
│                                                                                                  │
│   789                                                                                            │
│   790                                                                                            │
│   791 if __name__ == "__main__":                                                                 │
│ ❱ 792 │   main()                                                                                 │
│   793                                                                                            │
│                                                                                                  │
│ /root/yjs/diffusers/examples/text_to_image/train_text_to_image.py:751 in main                    │
│                                                                                                  │
│   748 │   │   │   # Checks if the accelerator has performed an optimization step behind the sc   │
│   749 │   │   │   if accelerator.sync_gradients:                                                 │
│   750 │   │   │   │   if args.use_ema:                                                           │
│ ❱ 751 │   │   │   │   │   ema_unet.step(unet.parameters())                                       │
│   752 │   │   │   │   progress_bar.update(1)                                                     │
│   753 │   │   │   │   global_step += 1                                                           │
│   754 │   │   │   │   accelerator.log({"train_loss": train_loss}, step=global_step)              │
│                                                                                                  │
│ /root/.local/conda/envs/sd/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in        │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /root/yjs/diffusers/examples/text_to_image/train_text_to_image.py:320 in step                    │
│                                                                                                  │
│   317 │   │                                                                                      │
│   318 │   │   for s_param, param in zip(self.shadow_params, parameters):                         │
│   319 │   │   │   if param.requires_grad:                                                        │
│ ❱ 320 │   │   │   │   s_param.sub_(one_minus_decay * (s_param - param))                          │
│   321 │   │   │   else:                                                                          │
│   322 │   │   │   │   s_param.copy_(param)                                                       │
│   323                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Steps:   1%|| 7/1000 [00:25<1:01:21,  3.71s/it, lr=1e-5, step_loss=0.372]
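
The failing line is an exponential-moving-average (EMA) update, s ← s − (1 − decay)·(s − p), which reads both the EMA shadow copy s and the live UNet parameter p, so both must be on the same device. One plausible reading of the traceback (an assumption, not confirmed in this report) is that after resuming, the EMA shadow parameters are restored on the CPU while the UNet sits on cuda:0, so the first EMA update mixes devices. The minimal PyTorch sketch below uses a hypothetical `SimpleEMA` class with a fixed decay (diffusers' own EMA class differs, e.g. in how the decay is scheduled) to show the update and the device-alignment step that avoids the error.

```python
from typing import Iterable, List

import torch


class SimpleEMA:
    """Hypothetical minimal EMA holder (fixed decay), mirroring the update in the traceback."""

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay: float = 0.9999):
        self.decay = decay
        # Detached copies: the "shadow" weights that track the live parameters.
        self.shadow_params: List[torch.Tensor] = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        one_minus_decay = 1.0 - self.decay
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                # s <- s - (1 - decay) * (s - p); raises a RuntimeError if s and p
                # live on different devices (e.g. cpu vs cuda:0).
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

    def to(self, device) -> "SimpleEMA":
        # Move every shadow tensor; moving the model alone does not touch these copies.
        self.shadow_params = [s.to(device) for s in self.shadow_params]
        return self


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(4, 2).to(device)
    ema = SimpleEMA(model.parameters(), decay=0.9)

    # Simulate a resume where the EMA state comes back on the CPU.
    ema.shadow_params = [s.cpu() for s in ema.shadow_params]

    # Align devices before the first EMA update after resuming.
    ema.to(next(model.parameters()).device)
    ema.step(model.parameters())
    print(all(s.device == p.device for s, p in zip(ema.shadow_params, model.parameters())))
```

If the EMA object in your copy of train_text_to_image.py exposes a similar `to()` method, calling it with accelerator.device right after the checkpoint state is loaded would be the analogous workaround; this is untested here.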

System Info

1. System info

  • Tesla V100 NVIDIA-SMI 450.51.05 Driver Version: 450.51.05 CUDA Version: 11.0
  • Linux dl-1626155627-pod-jupyter-67f8849f59-k8tff 3.10.0-1062.el7.bclinux.x86_64 #1 SMP Thu Mar 5 14:02:53 CST 2020 x86_64 x86_64 x86_64 GNU/Linux

2. Python environment

  • diffusers version: 0.16.1
  • Platform: Linux-3.10.0-1062.el7.bclinux.x86_64-x86_64-with-glibc2.27
  • Python version: 3.10.11
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Huggingface_hub version: 0.14.1
  • Transformers version: 4.29.2
  • Accelerate version: 0.19.0
  • xFormers version: 0.0.20
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Labels

  • bug: Something isn't working
  • stale: Issues that haven't received updates
