Commit e66d4aa

apolinario authored and patrickvonplaten committed
Change fp16 error to warning (huggingface#764)
* Swap fp16 error to warning. Also remove the associated test.
* Formatting
* warn -> warning
* Update src/diffusers/pipeline_utils.py
* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 2d7d98e commit e66d4aa

File tree

2 files changed (+6, -15 lines)

src/diffusers/pipeline_utils.py

Lines changed: 6 additions & 4 deletions
@@ -172,10 +172,12 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
-                    raise ValueError(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
-                        "due to the lack of support for `float16` operations on those devices in PyTorch. "
-                        "Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
+                    logger.warning(
+                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
+                        " is not recommended to move them to `cpu` or `mps` as running them will fail, due to the"
+                        " lack of support for `float16` operations on those devices in PyTorch. Please remove the"
+                        " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
                     )
                 module.to(torch_device)
         return self
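
For context, a minimal sketch of how the new behavior surfaces to a user; the checkpoint id here is an illustrative assumption, not taken from this commit:

import torch
from diffusers import DDIMPipeline

# Load a pipeline in half precision (hypothetical example checkpoint).
pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", torch_dtype=torch.float16)

# Before this commit, the next call raised a ValueError on cpu/mps.
# After it, a warning is logged and the pipeline is moved anyway.
pipe.to("cpu")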

tests/test_pipelines.py

Lines changed: 0 additions & 11 deletions
@@ -247,17 +247,6 @@ def to(self, device):

         return extract

-    def test_pipeline_fp16_cpu_error(self):
-        model = self.dummy_uncond_unet
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        pipe = DDIMPipeline(model.half(), scheduler)
-
-        if str(torch_device) in ["cpu", "mps"]:
-            self.assertRaises(ValueError, pipe.to, torch_device)
-        else:
-            # moving the pipeline to GPU should work
-            pipe.to(torch_device)
-
     def test_ddim(self):
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
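
The deleted test asserted the old ValueError. Not part of this commit, but as a sketch, the new warning path could be asserted instead, reusing the same dummy fixtures and assuming the `diffusers` logger name:

    def test_pipeline_fp16_cpu_warning(self):
        model = self.dummy_uncond_unet
        scheduler = DDPMScheduler(num_train_timesteps=10)
        pipe = DDIMPipeline(model.half(), scheduler)

        if str(torch_device) in ["cpu", "mps"]:
            # assertLogs captures the logger.warning emitted by to().
            with self.assertLogs("diffusers", level="WARNING"):
                pipe.to(torch_device)
        else:
            # Moving the pipeline to GPU should still work without a warning.
            pipe.to(torch_device)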
