Commit e7c84ee

Change fp16 error to warning (huggingface#764)

* Swap fp16 error to warning; also remove the associated test
* Formatting
* warn -> warning
* Update src/diffusers/pipeline_utils.py
* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent ae44c99 commit e7c84ee

1 file changed (+6, -4 lines)


src/diffusers/pipeline_utils.py

Lines changed: 6 additions & 4 deletions
@@ -169,10 +169,12 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
-                    raise ValueError(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
-                        "due to the lack of support for `float16` operations on those devices in PyTorch. "
-                        "Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
+                    logger.warning(
+                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
+                        " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
+                        " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
+                        " `float16` operations on those devices in PyTorch. Please remove the"
+                        " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
                     )
                 module.to(torch_device)
         return self
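In effect, `DiffusionPipeline.to()` now logs a warning and still performs the move, where it previously raised and aborted. A minimal sketch of how the change surfaces to callers; the checkpoint id and the use of StableDiffusionPipeline are illustrative assumptions, not part of this commit:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint id (an assumption); any pipeline loaded in
# fp16 behaves the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)

# Before this commit: `to("cpu")` raised ValueError and the modules stayed put.
# After this commit: the warning above is logged, `module.to(torch_device)`
# still runs, and the pipeline lands on CPU (fp16 inference there will
# still fail, as the warning says).
pipe = pipe.to("cpu")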
