Commit 03049aa

Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)"
This reverts commit 41ba8c0.
1 parent c1e7fd5 commit 03049aa

2 files changed: 24 additions, 48 deletions


src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -820,7 +820,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 revision=revision,
                 subfolder=subfolder or "",
             )
-            if hf_quantizer is not None and is_bnb_quantization_method:
+            if hf_quantizer is not None:
                 model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                 logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                 is_sharded = False
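For context, the restored condition merges checkpoint shards whenever any quantizer is set, rather than only for bitsandbytes quantization. The following is a minimal sketch of that control flow, not the actual diffusers implementation; the wrapper function is hypothetical, and only `_merge_sharded_checkpoints` comes from the code above.

# Hypothetical sketch (not the diffusers API): illustrates the effect of the
# restored `if hf_quantizer is not None:` branch shown above.
def resolve_checkpoint(hf_quantizer, is_sharded, sharded_folder, sharded_metadata, model_file):
    if hf_quantizer is not None and is_sharded:
        # Any quantizer now triggers the merge, so quantized loading always
        # proceeds from a single consolidated checkpoint file.
        # `_merge_sharded_checkpoints` is the private helper referenced in the diff.
        model_file = _merge_sharded_checkpoints(sharded_folder, sharded_metadata)
        is_sharded = False
    return model_file, is_sharded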

tests/quantization/torchao/test_torchao.py

Lines changed: 23 additions & 47 deletions
@@ -279,14 +279,13 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
 
-    def test_device_map(self):
+    def test_offload(self):
         """
-        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
-        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
-        correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
+        that the device map is correctly set (in the `hf_device_map` attribute of the model).
         """
 
-        custom_device_map_dict = {
+        device_map_offload = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -295,50 +294,27 @@ def test_device_map(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
-        device_maps = ["auto", custom_device_map_dict]
 
         inputs = self.get_dummy_tensor_inputs(torch_device)
-        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-
-        for device_map in device_maps:
-            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
-
-            # Test non-sharded model
-            with tempfile.TemporaryDirectory() as offload_folder:
-                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                quantized_model = FluxTransformer2DModel.from_pretrained(
-                    "hf-internal-testing/tiny-flux-pipe",
-                    subfolder="transformer",
-                    quantization_config=quantization_config,
-                    device_map=device_map,
-                    torch_dtype=torch.bfloat16,
-                    offload_folder=offload_folder,
-                )
-
-                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-
-                output = quantized_model(**inputs)[0]
-                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
-
-            # Test sharded model
-            with tempfile.TemporaryDirectory() as offload_folder:
-                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                quantized_model = FluxTransformer2DModel.from_pretrained(
-                    "hf-internal-testing/tiny-flux-sharded",
-                    subfolder="transformer",
-                    quantization_config=quantization_config,
-                    device_map=device_map,
-                    torch_dtype=torch.bfloat16,
-                    offload_folder=offload_folder,
-                )
-
-                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-
-                output = quantized_model(**inputs)[0]
-                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+        with tempfile.TemporaryDirectory() as offload_folder:
+            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+            quantized_model = FluxTransformer2DModel.from_pretrained(
+                "hf-internal-testing/tiny-flux-pipe",
+                subfolder="transformer",
+                quantization_config=quantization_config,
+                device_map=device_map_offload,
+                torch_dtype=torch.bfloat16,
+                offload_folder=offload_folder,
+            )
+
+            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
+
+            output = quantized_model(**inputs)[0]
+            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
