diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index cd4f1b3b1ad2..f5fdf74449bb 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -312,7 +312,7 @@ def test_device_and_dtype_assignment(self): _ = self.model_fp16.float() # Check that this does not throw an error - _ = self.model_fp16.cuda() + _ = self.model_fp16.to(torch_device) class Bnb8bitDeviceTests(Base8bitTests):