diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 52ab289effec..74047e1006bd 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -157,6 +157,8 @@
     title: Getting Started
   - local: quantization/bitsandbytes
     title: bitsandbytes
+  - local: quantization/torchao
+    title: torchao
   title: Quantization Methods
 - sections:
   - local: optimization/fp16
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
index 2fbde9e707ea..18aadf3111bd 100644
--- a/docs/source/en/api/quantization.md
+++ b/docs/source/en/api/quantization.md
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 
 [[autodoc]] BitsAndBytesConfig
 
+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
 ## DiffusersQuantizer
 
 [[autodoc]] quantizers.base.DiffusersQuantizer
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index d8adbc85a259..151b22a607a4 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be
 
 ## When to use what?
 
-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
\ No newline at end of file
+Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
\ No newline at end of file
diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
new file mode 100644
index 000000000000..bd5c7697a0f7
--- /dev/null
+++ b/docs/source/en/quantization/torchao.md
@@ -0,0 +1,92 @@
+
+
+# torchao
+
+[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
+
+Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.
+
+```bash
+pip install -U torch torchao
+```
+
+
+Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+The example below only quantizes the weights to int8.
+
+```python
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+model_id = "black-forest-labs/Flux.1-Dev"
+dtype = torch.bfloat16
+
+quantization_config = TorchAoConfig("int8wo")
+transformer = FluxTransformer2DModel.from_pretrained(
+    model_id,
+    subfolder="transformer",
+    quantization_config=quantization_config,
+    torch_dtype=dtype,
+)
+pipe = FluxPipeline.from_pretrained(
+    model_id,
+    transformer=transformer,
+    torch_dtype=dtype,
+)
+pipe.to("cuda")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
+image.save("output.png")
+```
+
+TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
+
+```python
+# In the above code, add the following after initializing the transformer
+transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
+```
+
+For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find torchao [benchmark](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
+
+torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
+
+The `TorchAoConfig` class accepts three parameters:
+- `quant_type`: A string naming one of the quantization types below.
+- `modules_to_not_convert`: A list of full or partial module names for which quantization should not be performed. For example, to skip quantization of the [`FluxTransformer2DModel`]'s first block, specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
+- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`.
+
+## Supported quantization types
+
+torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
+
+Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
+
+Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
+
+The quantization methods supported are as follows:
+
+| **Category** | **Full Function Names** | **Shorthands** |
+|--------------|-------------------------|----------------|
+| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
+| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
+| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAmB` where `X` is the number of bits (3-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
+| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
+
+Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
+
+Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
+
+## Resources
+
+- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
+- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index ae4ef299abb3..57be6f853215 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -31,7 +31,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig"],
+    "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -562,7 +562,7 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig
+    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
 
     try:
         if not is_onnx_available():
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 751117f8f247..546c0eb4d840 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -25,7 +25,6 @@
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
-from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -182,7 +181,6 @@ def load_model_dict_into_meta(
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -215,12 +213,12 @@ def load_model_dict_into_meta(
         # bnb params are flattened.
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
                 hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 4fe457706473..ce5289e3dbfd 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -700,10 +700,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            if device_map is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                )
+
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
 
@@ -858,13 +860,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if device_map is None and not is_sharded:
                 # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
                 # It would error out during the `validate_environment()` call above in the absence of cuda.
-                is_quant_method_bnb = (
-                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-                )
                 if hf_quantizer is None:
                     param_device = "cpu"
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                elif is_quant_method_bnb:
+                else:
                     param_device = torch.device(torch.cuda.current_device())
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index 97cbcdc0e53f..098308ae0bdc 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -19,17 +19,20 @@
 from typing import Dict, Optional, Union
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
+from .torchao import TorchAoHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+    "torchao": TorchAoHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
+    "torchao": TorchAoConfig,
 }
 
 
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index f521c5d717d6..4aeb75ab704c 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -22,15 +22,17 @@
 
 import copy
 import importlib.metadata
+import inspect
 import json
 import os
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Union
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
 
 from packaging import version
 
-from ..utils import is_torch_available, logging
+from ..utils import is_torch_available, is_torchao_available, logging
 
 
 if is_torch_available():
@@ -41,6 +43,7 @@
 
 class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
+    TORCHAO = "torchao"
 
 
 @dataclass
@@ -389,3 +392,254 @@ def to_diff_dict(self) -> Dict[str, Any]:
                 serializable_config_dict[key] = value
 
         return serializable_config_dict
+
+
+@dataclass
+class TorchAoConfig(QuantizationConfigMixin):
+    """This is a config class for torchao quantization/sparsity techniques.
+
+    Args:
+        quant_type (`str`):
+            The type of quantization we want to use, currently supporting:
+                - **Integer quantization:**
+                    - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
+                      `int8_weight_only`, `int8_dynamic_activation_int8_weight`
+                    - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
+
+                - **Floating point 8-bit quantization:**
+                    - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
+                      `float8_static_activation_float8_weight`
+                    - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
+                      `float8dq_e4m3_tensor`, `float8dq_e4m3_row`
+
+                - **Floating point X-bit quantization:**
+                    - Full function names: `fpx_weight_only`
+                    - Shorthands: `fpX_eAmB`, where `X` is the number of bits (between `3` and `7`), `A` is the number
+                      of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
+                      be satisfied for a given shorthand notation.
+
+                - **Unsigned Integer quantization:**
+                    - Full function names: `uintx_weight_only`
+                    - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
+        modules_to_not_convert (`List[str]`, *optional*, defaults to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+            modules left in their original precision.
+        kwargs (`Dict[str, Any]`, *optional*):
+            The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
+            supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
+            documentation of arguments can be found in
+            https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
+
+    Example:
+        ```python
+        import torch
+        from diffusers import FluxTransformer2DModel, TorchAoConfig
+
+        quantization_config = TorchAoConfig("int8wo")
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "black-forest-labs/Flux.1-Dev",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        ```
+    """
+
+    def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
+        self.quant_method = QuantizationMethod.TORCHAO
+        self.quant_type = quant_type
+        self.modules_to_not_convert = modules_to_not_convert
+
+        # When we load from serialized config, "quant_type_kwargs" will be the key
+        if "quant_type_kwargs" in kwargs:
+            self.quant_type_kwargs = kwargs["quant_type_kwargs"]
+        else:
+            self.quant_type_kwargs = kwargs
+
+        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+        if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+            raise ValueError(
+                f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
+                f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+
+        method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
+        signature = inspect.signature(method)
+        all_kwargs = {
+            param.name
+            for param in signature.parameters.values()
+            if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+        }
+        unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+        if len(unsupported_kwargs) > 0:
+            raise ValueError(
+                f'The quantization method "{quant_type}" does not support the following keyword arguments: '
+                f"{unsupported_kwargs}. The following keyword arguments are supported: {all_kwargs}."
+            )
+
+    @classmethod
+    def _get_torchao_quant_type_to_method(cls):
+        r"""
+        Returns supported torchao quantization types with all commonly used notations.
+        """
+
+        if is_torchao_available():
+            # TODO(aryan): Support autoquant and sparsify
+            from torchao.quantization import (
+                float8_dynamic_activation_float8_weight,
+                float8_static_activation_float8_weight,
+                float8_weight_only,
+                fpx_weight_only,
+                int4_weight_only,
+                int8_dynamic_activation_int4_weight,
+                int8_dynamic_activation_int8_weight,
+                int8_weight_only,
+                uintx_weight_only,
+            )
+
+            # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
+            from torchao.quantization.observer import PerRow, PerTensor
+
+            def generate_float8dq_types(dtype: torch.dtype):
+                name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
+                types = {}
+
+                for granularity_cls in [PerTensor, PerRow]:
+                    # Note: Activation and Weights cannot have different granularities
+                    granularity_name = "tensor" if granularity_cls is PerTensor else "row"
+                    types[f"float8dq_{name}_{granularity_name}"] = partial(
+                        float8_dynamic_activation_float8_weight,
+                        activation_dtype=dtype,
+                        weight_dtype=dtype,
+                        granularity=(granularity_cls(), granularity_cls()),
+                    )
+
+                return types
+
+            def generate_fpx_quantization_types(bits: int):
+                types = {}
+
+                for ebits in range(1, bits):
+                    mbits = bits - ebits - 1
+                    types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+
+                non_sign_bits = bits - 1
+                default_ebits = (non_sign_bits + 1) // 2
+                default_mbits = non_sign_bits - default_ebits
+                types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
+
+                return types
+
+            INT4_QUANTIZATION_TYPES = {
+                # int4 weight + bfloat16/float16 activation
+                "int4wo": int4_weight_only,
+                "int4_weight_only": int4_weight_only,
+                # int4 weight + int8 activation
+                "int4dq": int8_dynamic_activation_int4_weight,
+                "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
+            }
+
+            INT8_QUANTIZATION_TYPES = {
+                # int8 weight + bfloat16/float16 activation
+                "int8wo": int8_weight_only,
+                "int8_weight_only": int8_weight_only,
+                # int8 weight + int8 activation
+                "int8dq": int8_dynamic_activation_int8_weight,
+                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+            }
+
+            # TODO(aryan): handle torch 2.2/2.3
+            FLOATX_QUANTIZATION_TYPES = {
+                # float8_e5m2 weight + bfloat16/float16 activation
+                "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+                "float8_weight_only": float8_weight_only,
+                "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+                # float8_e4m3 weight + bfloat16/float16 activation
+                "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
+                # float8_e5m2 weight + float8 activation (dynamic)
+                "float8dq": float8_dynamic_activation_float8_weight,
+                "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
+                # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
+                # However, changing activation_dtype=torch.float8_e4m3 might work here =====
+                # "float8dq_e5m2": partial(
+                #     float8_dynamic_activation_float8_weight,
+                #     activation_dtype=torch.float8_e5m2,
+                #     weight_dtype=torch.float8_e5m2,
+                # ),
+                # **generate_float8dq_types(torch.float8_e5m2),
+                # ===== =====
+                # float8_e4m3 weight + float8 activation (dynamic)
+                "float8dq_e4m3": partial(
+                    float8_dynamic_activation_float8_weight,
+                    activation_dtype=torch.float8_e4m3fn,
+                    weight_dtype=torch.float8_e4m3fn,
+                ),
+                **generate_float8dq_types(torch.float8_e4m3fn),
+                # float8 weight + float8 activation (static)
+                "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
+                # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
+                # fpx weight + bfloat16/float16 activation
+                **generate_fpx_quantization_types(3),
+                **generate_fpx_quantization_types(4),
+                **generate_fpx_quantization_types(5),
+                **generate_fpx_quantization_types(6),
+                **generate_fpx_quantization_types(7),
+            }
+
+            UINTX_QUANTIZATION_DTYPES = {
+                "uintx_weight_only": uintx_weight_only,
+                "uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
+                "uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
+                "uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
+                "uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
+                "uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
+                "uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
+                "uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
+                # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8),  # uint8 quantization is not supported
+            }
+
+            QUANTIZATION_TYPES = {}
+            QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
+            QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
+            QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
+
+            if cls._is_cuda_capability_atleast_8_9():
+                QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
+
+            return QUANTIZATION_TYPES
+        else:
+            raise ValueError(
+                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
+            )
+
+    @staticmethod
+    def _is_cuda_capability_atleast_8_9() -> bool:
+        if not torch.cuda.is_available():
+            raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
+
+        major, minor = torch.cuda.get_device_capability()
+        if major == 8:
+            return minor >= 9
+        return major >= 9
+
+    def get_apply_tensor_subclass(self):
+        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+        return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
+
+    def __repr__(self):
+        r"""
+        Example of how this looks for `TorchAoConfig("int4wo", group_size=32)`:
+
+        ```
+        TorchAoConfig {
+            "modules_to_not_convert": null,
+            "quant_method": "torchao",
+            "quant_type": "int4wo",
+            "quant_type_kwargs": {
+                "group_size": 32
+            }
+        }
+        ```
+        """
+        config_dict = self.to_dict()
+        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
diff --git a/src/diffusers/quantizers/torchao/__init__.py b/src/diffusers/quantizers/torchao/__init__.py
new file mode 100644
index 000000000000..09e6a19d4df0
--- /dev/null
+++ b/src/diffusers/quantizers/torchao/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .torchao_quantizer import TorchAoHfQuantizer
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
new file mode 100644
index 000000000000..8b28a403e6f0
--- /dev/null
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -0,0 +1,280 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py
+"""
+
+import importlib
+import types
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from packaging import version
+
+from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+    SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+        # At the moment, only int8 is supported for integer quantization dtypes.
+        # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
+        # to support more quantization methods, such as intx_weight_only.
+        torch.int8,
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.uint1,
+        torch.uint2,
+        torch.uint3,
+        torch.uint4,
+        torch.uint5,
+        torch.uint6,
+        torch.uint7,
+    )
+
+if is_torchao_available():
+    from torchao.quantization import quantize_
+
+
+logger = logging.get_logger(__name__)
+
+
+def _quantization_type(weight):
+    from torchao.dtypes import AffineQuantizedTensor
+    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+    if isinstance(weight, AffineQuantizedTensor):
+        return f"{weight.__class__.__name__}({weight._quantization_type()})"
+
+    if isinstance(weight, LinearActivationQuantizedTensor):
+        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
+
+
+def _linear_extra_repr(self):
+    weight = _quantization_type(self.weight)
+    if weight is None:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
+    else:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
+
+
+class TorchAoHfQuantizer(DiffusersQuantizer):
+    r"""
+    Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/.
+    """
+
+    requires_calibration = False
+    required_packages = ["torchao"]
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_torchao_available():
+            raise ImportError(
+                "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
+            )
+
+        self.offload = False
+
+        device_map = kwargs.get("device_map", None)
+        if isinstance(device_map, dict):
+            if "cpu" in device_map.values() or "disk" in device_map.values():
+                if self.pre_quantized:
+                    raise ValueError(
+                        "You are attempting to perform cpu/disk offload with a pre-quantized torchao model. "
+                        "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
+                    )
+                else:
+                    self.offload = True
+
+        if self.pre_quantized:
+            weights_only = kwargs.get("weights_only", None)
+            if weights_only:
+                torch_version = version.parse(importlib.metadata.version("torch"))
+                if torch_version < version.parse("2.5.0"):
+                    # TODO(aryan): TorchAO is compatible with PyTorch >= 2.2 for certain quantization types. Try to see if we can support it in future
+                    raise RuntimeError(
+                        f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
+                    )
+
+    def update_torch_dtype(self, torch_dtype):
+        quant_type = self.quantization_config.quant_type
+
+        if quant_type.startswith("int"):
+            if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                logger.warning(
+                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
+                    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
+                )
+
+        if torch_dtype is None:
+            # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
+            logger.warning(
+                "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
+                "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
+                "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
+            )
+            torch_dtype = torch.bfloat16
+
+        return torch_dtype
+
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        quant_type = self.quantization_config.quant_type
+
+        if quant_type.startswith("int8") or quant_type.startswith("int4"):
+            # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+            return torch.int8
+        elif quant_type == "uintx_weight_only":
+            return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+        elif quant_type.startswith("uint"):
+            return {
+                1: torch.uint1,
+                2: torch.uint2,
+                3: torch.uint3,
+                4: torch.uint4,
+                5: torch.uint5,
+                6: torch.uint6,
+                7: torch.uint7,
+            }[int(quant_type[4])]
+        elif quant_type.startswith("float") or quant_type.startswith("fp"):
+            return torch.bfloat16
+
+        # The tuple holds `torch.dtype` instances, not classes, so use a membership test rather than `isinstance`
+        if target_dtype in SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION:
+            return target_dtype
+
+        # We need one of the supported dtypes to be selected in order for accelerate to determine
+        # the total size of modules/parameters for auto device placement.
+        possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
+        raise ValueError(
+            f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
+            f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
+            f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+        )
+
+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        max_memory = {key: val * 0.9 for key, val in max_memory.items()}
+        return max_memory
+
+    def check_if_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ) -> bool:
+        param_device = kwargs.pop("param_device", None)
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        elif param_device == "cpu" and self.offload:
+            # We don't quantize weights that we offload
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, tensor_name = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+
+    def create_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        target_device: "torch.device",
+        state_dict: Dict[str, Any],
+        unexpected_keys: List[str],
+    ):
+        r"""
+        Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
+        then we move it to the target device. Finally, we quantize the module.
+        """
+        module, tensor_name = get_module_from_name(model, param_name)
+
+        if self.pre_quantized:
+            # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
+            # about AffineQuantizedTensor
+            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
+            if isinstance(module, nn.Linear):
+                module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        else:
+            # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
+            module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "ModelMixin",
+        device_map,
+        keep_in_fp32_modules: List[str] = [],
+        **kwargs,
+    ):
+        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+        if not isinstance(self.modules_to_not_convert, list):
+            self.modules_to_not_convert = [self.modules_to_not_convert]
+
+        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+            self.modules_to_not_convert.extend(keys_on_cpu)
+
+        # Purge `None`.
+        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+        # in case of diffusion transformer models. For language models and others alike, `lm_head`
+        # and tied modules are usually kept in FP32.
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "ModelMixin"): + return model + + def is_serializable(self, safe_serialization=None): + # TODO(aryan): needs to be tested + if safe_serialization: + logger.warning( + "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + ) + return False + + _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( + "0.25.0" + ) + + if not _is_torchao_serializable: + logger.warning("torchao quantized model is only serializable with huggingface_hub >= 0.25.0") + + if self.offload and self.quantization_config.modules_to_not_convert is None: + logger.warning( + "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them. " + "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
+ ) + return False + + return _is_torchao_serializable + + @property + def is_trainable(self): + return self.quantization_config.quant_type.startswith("int8") diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f91cee8113f2..9860ac849834 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -87,6 +87,7 @@ is_torch_version, is_torch_xla_available, is_torch_xla_version, + is_torchao_available, is_torchsde_available, is_torchvision_available, is_transformers_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e3b7655737a8..f325f36bddd3 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -340,6 +340,15 @@ def is_timm_available(): _imageio_available = False +_is_torchao_available = importlib.util.find_spec("torchao") is not None +if _is_torchao_available: + try: + _torchao_version = importlib_metadata.version("torchao") + logger.debug(f"Successfully imported torchao version {_torchao_version}") + except importlib_metadata.PackageNotFoundError: + _is_torchao_available = False + + def is_torch_available(): return _torch_available @@ -460,6 +469,10 @@ def is_imageio_available(): return _imageio_available +def is_torchao_available(): + return _is_torchao_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -593,6 +606,11 @@ def is_imageio_available(): {0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg` """ +# docstyle-ignore +TORCHAO_IMPORT_ERROR = """ +{0} requires the torchao library but it was not found in your environment.
You can install it with pip: `pip install torchao` +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -618,6 +636,7 @@ def is_imageio_available(): ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index b3e381f7d3fb..b4d3415de50e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -39,6 +39,7 @@ is_timm_available, is_torch_available, is_torch_version, + is_torchao_available, is_torchsde_available, is_transformers_available, ) @@ -476,6 +477,18 @@ def decorator(test_case): return decorator +def require_torchao_version_greater(torchao_version): + def decorator(test_case): + correct_torchao_version = is_torchao_available() and version.parse( + version.parse(importlib.metadata.version("torchao")).base_version + ) > version.parse(torchao_version) + return unittest.skipUnless( + correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}." + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md new file mode 100644 index 000000000000..fadc529e12fc --- /dev/null +++ b/tests/quantization/torchao/README.md @@ -0,0 +1,53 @@ +The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/). + +The benchmarks were run on a single H100. 
Below is `nvidia-smi`: + +```bash ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 | +| N/A 34C P0 69W / 700W | 2MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| +| No running processes found | ++---------------------------------------------------------------------------------------+ +``` + +The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR. + +The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. 
To run the slow tests, use the following command or an equivalent: + +```bash +HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests +``` + +`diffusers-cli`: + +```bash +- 🤗 Diffusers version: 0.32.0.dev0 +- Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31 +- Running on Google Colab?: No +- Python version: 3.10.14 +- PyTorch version (GPU?): 2.6.0.dev20241112+cu121 (False) +- Flax version (CPU?/GPU?/TPU?): not installed (NA) +- Jax version: not installed +- JaxLib version: not installed +- Huggingface_hub version: 0.26.2 +- Transformers version: 4.46.3 +- Accelerate version: 1.1.1 +- PEFT version: not installed +- Bitsandbytes version: not installed +- Safetensors version: 0.4.5 +- xFormers version: not installed +``` diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py new file mode 100644 index 000000000000..5c71fc4e0ae7 --- /dev/null +++ b/tests/quantization/torchao/test_torchao.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import tempfile +import unittest +from typing import List + +import numpy as np +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, + TorchAoConfig, +) +from diffusers.models.attention_processor import Attention +from diffusers.utils.testing_utils import ( + enable_full_determinism, + is_torch_available, + is_torchao_available, + nightly, + require_torch, + require_torch_gpu, + require_torchao_version_greater, + slow, + torch_device, +) + + +enable_full_determinism() + + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +if is_torchao_available(): + from torchao.dtypes import AffineQuantizedTensor + from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +class TorchAoConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the 
config format is properly set + """ + quantization_config = TorchAoConfig("int4_weight_only") + torchao_orig_config = quantization_config.to_dict() + + for key in torchao_orig_config: + self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key]) + + def test_post_init_check(self): + """ + Test kwargs validations in TorchAoConfig + """ + _ = TorchAoConfig("int4_weight_only") + with self.assertRaisesRegex(ValueError, "is not supported yet"): + _ = TorchAoConfig("uint8") + + with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"): + _ = TorchAoConfig("int4_weight_only", group_size1=32) + + def test_repr(self): + """ + Check that there is no error in the repr + """ + quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) + expected_repr = """TorchAoConfig { + "modules_to_not_convert": [ + "conv" + ], + "quant_method": "torchao", + "quant_type": "int4_weight_only", + "quant_type_kwargs": { + "group_size": 8 + } + }""".replace(" ", "").replace("\n", "") + quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") + self.assertEqual(quantization_repr, expected_repr) + + +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +class TorchAoTest(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components(self, quantization_config: TorchAoConfig): + model_id = "hf-internal-testing/tiny-flux-pipe" + transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(model_id, 
subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 32, + "width": 32, + "num_inference_steps": 2, + "output_type": "np", + "generator": generator, + } + + return inputs + + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + + torch.manual_seed(seed) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + 
} + + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device, dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + output_slice = output[-1, -1, -3:, -3:].flatten() + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), + # ===== + # The following lead to an internal torch error: + # RuntimeError: mat2 shape (32x4 must be divisible by 16 + # Skip these for now; TODO(aryan): investigate later + # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + # Cutlass fails to initialize for below + # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + ("fp4", 
np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quant_kwargs = {} + if quantization_name in ["uint4wo", "uint_a16w7"]: + # The dummy flux model that we use requires us to impose some restrictions on group_size here + quant_kwargs.update({"group_size": 16}) + quantization_config = TorchAoConfig( + quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs + ) + self._test_quant_type(quantization_config, expected_slice) + + def test_int4wo_quant_bfloat16_conversion(self): + """ + Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization. + """ + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + weight = quantized_model.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + self.assertEqual(weight.quant_min, 0) + self.assertEqual(weight.quant_max, 15) + self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) + + def test_offload(self): + """ + Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies + that the device map is correctly set (in the `hf_device_map` attribute of the model). 
+ """ + + device_map_offload = { + "time_text_embed": torch_device, + "context_embedder": torch_device, + "x_embedder": torch_device, + "transformer_blocks.0": "cpu", + "single_transformer_blocks.0": "disk", + "norm_out": torch_device, + "proj_out": "cpu", + } + + inputs = self.get_dummy_tensor_inputs(torch_device) + + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map_offload, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + self.assertTrue(quantized_model.hf_device_map == device_map_offload) + + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_modules_to_not_convert(self): + quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + + unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2] + self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) + self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) + self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) + + quantized_layer = quantized_model.proj_out + self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) + self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8) + + def test_training(self): + quantization_config = 
TorchAoConfig("int8_weight_only") + quantized_model = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + # freeze the model as only adapter layers will be trained + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in quantized_model.named_modules(): + if isinstance(module, Attention): + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + @nightly + def test_torch_compile(self): + r"""Test that verifies if torch.compile works with torchao quantization.""" + quantization_config = TorchAoConfig("int8_weight_only") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device, dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] + + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] + + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + + @staticmethod + def _get_memory_footprint(module): + quantized_param_memory = 0.0 + unquantized_param_memory = 0.0 + + for param in 
module.parameters(): + if param.__class__.__name__ == "AffineQuantizedTensor": + data, scale, zero_point = param.layout_tensor.get_plain() + quantized_param_memory += data.numel() + data.element_size() + quantized_param_memory += scale.numel() + scale.element_size() + quantized_param_memory += zero_point.numel() + zero_point.element_size() + else: + unquantized_param_memory += param.data.numel() * param.data.element_size() + + total_memory = quantized_param_memory + unquantized_param_memory + return total_memory, quantized_param_memory, unquantized_param_memory + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"] + transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] + transformer_bf16 = self.get_dummy_components(None)["transformer"] + + total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo) + total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint( + transformer_int4wo_gs32 + ) + total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo) + total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16) + + self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16) + # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points + self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32) + # int4 with default group size quantized very few linear layers compared to a smaller group size of 32 + self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and 
unquantized_int4wo > unquantized_int4wo_gs32) + # int8 quantizes more layers compared to int4 with default group size + self.assertTrue(quantized_int8wo < quantized_int4wo) + + def test_wrong_config(self): + with self.assertRaises(ValueError): + self.get_dummy_components(TorchAoConfig("int42")) + + +# This class is not to be run as a test by itself. See the tests that follow this class +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +class TorchAoSerializationTest(unittest.TestCase): + model_name = "hf-internal-testing/tiny-flux-pipe" + quant_method, quant_method_kwargs = None, None + device = "cuda" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_model(self, device=None): + quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs) + quantized_model = FluxTransformer2DModel.from_pretrained( + self.model_name, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + return quantized_model.to(device) + + def get_dummy_tensor_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def test_original_model_expected_slice(self): + quantized_model = self.get_dummy_model(torch_device) + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3)) + + def check_serialization_expected_slice(self, expected_slice): + quantized_model = self.get_dummy_model(self.device) + + with tempfile.TemporaryDirectory() as tmp_dir: + quantized_model.save_pretrained(tmp_dir, safe_serialization=False) + loaded_quantized_model = FluxTransformer2DModel.from_pretrained( + tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False + ) + + inputs = self.get_dummy_tensor_inputs(torch_device) + output = loaded_quantized_model(**inputs)[0] + + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + self.assertTrue( + isinstance( + loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor) + ) + ) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_serialization_expected_slice(self): + self.check_serialization_expected_slice(self.serialized_expected_slice) + + +class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + serialized_expected_slice = expected_slice + device = "cuda" + + +class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + serialized_expected_slice = expected_slice + 
device = "cuda" + + +class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + serialized_expected_slice = expected_slice + device = "cpu" + + +class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest): + quant_method, quant_method_kwargs = "int8_weight_only", {} + expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) + serialized_expected_slice = expected_slice + device = "cpu" + + +# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners +@require_torch +@require_torch_gpu +@require_torchao_version_greater("0.6.0") +@slow +@nightly +class SlowTorchAoTests(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components(self, quantization_config: TorchAoConfig): + model_id = "black-forest-labs/FLUX.1-dev" + transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ) + text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae") + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = 
torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 512, + "width": 512, + "num_inference_steps": 20, + "output_type": "np", + "generator": generator, + } + + return inputs + + def _test_quant_type(self, quantization_config, expected_slice): + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components).to(dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten() + output_slice = np.concatenate((output[:16], output[-16:])) + + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + def test_quantization(self): + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])), + ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), + ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 
0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"]) + self._test_quant_type(quantization_config, expected_slice) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize()
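
Note on `modules_to_not_convert`, which these tests set to `["x_embedder"]`: the skip rule in `check_if_quantized_param` is a substring test (`key + "." in param_name`), not a prefix test. The sketch below reproduces only that expression in plain Python to show what it matches. The helper name `is_skipped` and the example parameter names are illustrative assumptions, not part of the PR.

```python
from typing import List


def is_skipped(param_name: str, modules_to_not_convert: List[str]) -> bool:
    """Hypothetical helper: same boolean expression as the skip check in the quantizer."""
    return any((key + "." in param_name) or (key == param_name) for key in modules_to_not_convert)


skip = ["transformer_blocks.0", "x_embedder"]

# The key followed by "." occurs at the start of the name, so the weight is skipped.
direct = is_skipped("transformer_blocks.0.ff.net.2.weight", skip)

# "transformer_blocks.0." is also a substring of this name, so it is skipped as well.
substring = is_skipped("single_transformer_blocks.0.proj_out.weight", skip)

# Neither key (plus ".") occurs in these names, so they remain candidates for quantization.
other_block = is_skipped("transformer_blocks.1.ff.net.2.weight", skip)
top_level = is_skipped("proj_out.weight", skip)

print(direct, substring, other_block, top_level)
```

Under this rule, a key such as `transformer_blocks.0` would also exclude `single_transformer_blocks.0`. If only one of the two should stay unquantized, it may be worth confirming the final module list by inspecting the loaded model rather than assuming prefix semantics.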