@@ -279,14 +279,13 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
 
-    def test_device_map(self):
+    def test_offload(self):
         """
-        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
-        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
-        correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
+        that the device map is correctly set (in the `hf_device_map` attribute of the model).
         """
 
-        custom_device_map_dict = {
+        device_map_offload = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -295,50 +294,27 @@ def test_device_map(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
-        device_maps = ["auto", custom_device_map_dict]
 
         inputs = self.get_dummy_tensor_inputs(torch_device)
-        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-
-        for device_map in device_maps:
-            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
-
-            # Test non-sharded model
-            with tempfile.TemporaryDirectory() as offload_folder:
-                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                quantized_model = FluxTransformer2DModel.from_pretrained(
-                    "hf-internal-testing/tiny-flux-pipe",
-                    subfolder="transformer",
-                    quantization_config=quantization_config,
-                    device_map=device_map,
-                    torch_dtype=torch.bfloat16,
-                    offload_folder=offload_folder,
-                )
-
-                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-
-                output = quantized_model(**inputs)[0]
-                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
-
-            # Test sharded model
-            with tempfile.TemporaryDirectory() as offload_folder:
-                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                quantized_model = FluxTransformer2DModel.from_pretrained(
-                    "hf-internal-testing/tiny-flux-sharded",
-                    subfolder="transformer",
-                    quantization_config=quantization_config,
-                    device_map=device_map,
-                    torch_dtype=torch.bfloat16,
-                    offload_folder=offload_folder,
-                )
-
-                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-
-                output = quantized_model(**inputs)[0]
-                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+        with tempfile.TemporaryDirectory() as offload_folder:
+            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+            quantized_model = FluxTransformer2DModel.from_pretrained(
+                "hf-internal-testing/tiny-flux-pipe",
+                subfolder="transformer",
+                quantization_config=quantization_config,
+                device_map=device_map_offload,
+                torch_dtype=torch.bfloat16,
+                offload_folder=offload_folder,
+            )
+
+            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
+
+            output = quantized_model(**inputs)[0]
+            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
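For reference, the offload path the renamed `test_offload` covers can be reproduced outside the test harness. The sketch below is a minimal, hedged example, not part of this change: the `"cuda"` device string and the `transformer_blocks` / `single_transformer_blocks` entries (which are not visible in the hunk above) are assumptions based on the Flux transformer layout, and it presumes torchao is installed.

```python
# Minimal standalone sketch of the cpu offload pattern exercised by test_offload.
# Assumptions: a CUDA device is available, torchao is installed, and the module
# names below match the top-level parameterized submodules of the tiny Flux
# transformer (entries between "x_embedder" and "norm_out" are assumed here).
import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

device_map = {
    "time_text_embed": "cuda",
    "context_embedder": "cuda",
    "x_embedder": "cuda",
    "transformer_blocks": "cuda",         # assumed module name
    "single_transformer_blocks": "cuda",  # assumed module name
    "norm_out": "cuda",
    "proj_out": "cpu",  # offloaded; "disk" would spill to offload_folder instead
}

with tempfile.TemporaryDirectory() as offload_folder:
    transformer = FluxTransformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-flux-pipe",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int4_weight_only", group_size=64),
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        offload_folder=offload_folder,
    )
    # The resolved placement is recorded on the model; this is what the test asserts.
    print(transformer.hf_device_map)
```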