Skip to content

Commit 6dfaec3

Browse files
yiyixuxu and baymax591 authored
make style for #10368 (#10370)
* fix bug for torch.uint1-7 not support in torch<2.6 * up --------- Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn>
1 parent c1e7fd5 commit 6dfaec3

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from packaging import version
2525

26-
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
26+
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
2727
from ..base import DiffusersQuantizer
2828

2929

@@ -35,21 +35,28 @@
3535
import torch
3636
import torch.nn as nn
3737

38-
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
39-
# At the moment, only int8 is supported for integer quantization dtypes.
40-
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
41-
# to support more quantization methods, such as intx_weight_only.
42-
torch.int8,
43-
torch.float8_e4m3fn,
44-
torch.float8_e5m2,
45-
torch.uint1,
46-
torch.uint2,
47-
torch.uint3,
48-
torch.uint4,
49-
torch.uint5,
50-
torch.uint6,
51-
torch.uint7,
52-
)
38+
if is_torch_version(">=", "2.5"):
39+
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
40+
# At the moment, only int8 is supported for integer quantization dtypes.
41+
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
42+
# to support more quantization methods, such as intx_weight_only.
43+
torch.int8,
44+
torch.float8_e4m3fn,
45+
torch.float8_e5m2,
46+
torch.uint1,
47+
torch.uint2,
48+
torch.uint3,
49+
torch.uint4,
50+
torch.uint5,
51+
torch.uint6,
52+
torch.uint7,
53+
)
54+
else:
55+
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
56+
torch.int8,
57+
torch.float8_e4m3fn,
58+
torch.float8_e5m2,
59+
)
5360

5461
if is_torchao_available():
5562
from torchao.quantization import quantize_

0 commit comments

Comments (0)