Skip to content

[LoRA] fix: lora loading when using with a device_mapped model. #9449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dc1aee2
fix: lora loading when using with a device_mapped model.
sayakpaul Sep 17, 2024
949a929
better attibutung
sayakpaul Sep 17, 2024
64b3ad1
empty
sayakpaul Sep 17, 2024
6d03c12
Merge branch 'main' into lora-device-map
sayakpaul Sep 22, 2024
d4bd94b
Merge branch 'main' into lora-device-map
sayakpaul Sep 24, 2024
5479198
Apply suggestions from code review
sayakpaul Sep 24, 2024
2846549
Merge branch 'main' into lora-device-map
sayakpaul Sep 27, 2024
1ed0eb0
Merge branch 'main' into lora-device-map
sayakpaul Sep 28, 2024
d2d59c3
Merge branch 'main' into lora-device-map
sayakpaul Oct 2, 2024
5f3cae2
Merge branch 'main' into lora-device-map
sayakpaul Oct 6, 2024
8f670e2
Merge branch 'main' into lora-device-map
sayakpaul Oct 8, 2024
e42ec19
Merge branch 'main' into lora-device-map
sayakpaul Oct 10, 2024
f63b04c
Merge branch 'main' into lora-device-map
sayakpaul Oct 15, 2024
eefda54
Merge branch 'main' into lora-device-map
sayakpaul Oct 19, 2024
ea727a3
minors
sayakpaul Oct 19, 2024
71989e3
better error messages.
sayakpaul Oct 19, 2024
f62afac
fix-copies
sayakpaul Oct 19, 2024
2334f78
add: tests, docs.
sayakpaul Oct 19, 2024
5ea1173
add hardware note.
sayakpaul Oct 19, 2024
f64751e
Merge branch 'main' into lora-device-map
sayakpaul Oct 19, 2024
c0dee87
quality
sayakpaul Oct 19, 2024
4b6124a
Merge branch 'main' into lora-device-map
sayakpaul Oct 22, 2024
fe2cca8
Update docs/source/en/training/distributed_inference.md
sayakpaul Oct 23, 2024
2db5d48
Merge branch 'main' into lora-device-map
sayakpaul Oct 23, 2024
61903c8
Merge branch 'main' into lora-device-map
sayakpaul Oct 31, 2024
03377b7
fixes
sayakpaul Oct 31, 2024
0bd40cb
skip properly.
sayakpaul Oct 31, 2024
a61b754
fixes
sayakpaul Oct 31, 2024
ccd8d2a
resolve conflicts.
sayakpaul Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
delete_adapter_layers,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_transformers_available,
logging,
Expand Down Expand Up @@ -214,9 +215,18 @@ def _optionally_disable_offloading(cls, _pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
Expand Down
38 changes: 38 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,11 @@ def to(self, *args, **kwargs):

device = device or device_arg

def model_has_device_map(model):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 it would make sense to make this a separate utility instead of having redefine three times. WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, you can add as a util function inside pipeline_utils.

if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
Expand All @@ -403,6 +408,13 @@ def module_is_offloaded(module):

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

# device-mapped modules should not go through any device placements.
pipeline_has_device_mapped_modules = any(model_has_device_map(module) for _, module in self.components.items())
if pipeline_has_device_mapped_modules:
raise ValueError(
"It seems like you have device-mapped modules in the pipeline which doesn't allow explicit device placement using `to()`."
)

# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
Expand Down Expand Up @@ -994,6 +1006,19 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

# device-mapped modules should not go through any device placements.
pipeline_has_device_mapped_modules = any(model_has_device_map(module) for _, module in self.components.items())
if pipeline_has_device_mapped_modules:
raise ValueError(
"It seems like you have device-mapped modules in the pipeline which doesn't allow explicit device placement using `to()`."
)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
Expand Down Expand Up @@ -1087,6 +1112,19 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

# device-mapped modules should not go through any device placements.
pipeline_has_device_mapped_modules = any(model_has_device_map(module) for _, module in self.components.items())
if pipeline_has_device_mapped_modules:
raise ValueError(
"It seems like you have device-mapped modules in the pipeline which doesn't allow explicit device placement using `to()`."
)

if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
Expand Down
Loading