Skip to content

Commit e2cd893

Browse files
DN6sayakpaul
authored andcommitted
[Single File] Allow loading T5 encoder in mixed precision (#8778)
* update * update * update * update
1 parent fe85e9c commit e2cd893

8 files changed

+85
-3
lines changed

src/diffusers/loaders/single_file.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,4 @@ def load_module(name, value):
555555

556556
pipe = pipeline_class(**init_kwargs)
557557

558-
if torch_dtype is not None:
559-
pipe.to(dtype=torch_dtype)
560-
561558
return pipe

src/diffusers/loaders/single_file_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint(
18081808

18091809
else:
18101810
model.load_state_dict(diffusers_format_checkpoint)
1811+
1812+
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
1813+
if use_keep_in_fp32_modules:
1814+
keep_in_fp32_modules = model._keep_in_fp32_modules
1815+
else:
1816+
keep_in_fp32_modules = []
1817+
1818+
if keep_in_fp32_modules is not None:
1819+
for name, param in model.named_parameters():
1820+
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
1821+
# param = param.to(torch.float32) does not work here as only in the local scope.
1822+
param.data = param.data.to(torch.float32)
1823+
18111824
return model

tests/single_file/single_file_testing_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def test_single_file_components_with_diffusers_config_local_files_only(
201201

202202
self._compare_component_configs(pipe, single_file_pipe)
203203

204+
def test_single_file_setting_pipeline_dtype_to_fp16(
205+
self,
206+
single_file_pipe=None,
207+
):
208+
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
209+
self.ckpt_path, torch_dtype=torch.float16
210+
)
211+
212+
for component_name, component in single_file_pipe.components.items():
213+
if not isinstance(component, torch.nn.Module):
214+
continue
215+
216+
assert component.dtype == torch.float16
217+
204218

205219
class SDXLSingleFileTesterMixin:
206220
def _compare_component_configs(self, pipe, single_file_pipe):
@@ -378,3 +392,17 @@ def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_d
378392
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
379393

380394
assert max_diff < expected_max_diff
395+
396+
def test_single_file_setting_pipeline_dtype_to_fp16(
397+
self,
398+
single_file_pipe=None,
399+
):
400+
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
401+
self.ckpt_path, torch_dtype=torch.float16
402+
)
403+
404+
for component_name, component in single_file_pipe.components.items():
405+
if not isinstance(component, torch.nn.Module):
406+
continue
407+
408+
assert component.dtype == torch.float16

tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
180180
local_files_only=True,
181181
)
182182
super()._compare_component_configs(pipe, pipe_single_file)
183+
184+
def test_single_file_setting_pipeline_dtype_to_fp16(self):
185+
controlnet = ControlNetModel.from_pretrained(
186+
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
187+
)
188+
single_file_pipe = self.pipeline_class.from_single_file(
189+
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
190+
)
191+
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
181181
local_files_only=True,
182182
)
183183
super()._compare_component_configs(pipe, pipe_single_file)
184+
185+
def test_single_file_setting_pipeline_dtype_to_fp16(self):
186+
controlnet = ControlNetModel.from_pretrained(
187+
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
188+
)
189+
single_file_pipe = self.pipeline_class.from_single_file(
190+
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
191+
)
192+
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

tests/single_file/test_stable_diffusion_controlnet_single_file.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
169169
local_files_only=True,
170170
)
171171
super()._compare_component_configs(pipe, pipe_single_file)
172+
173+
def test_single_file_setting_pipeline_dtype_to_fp16(self):
174+
controlnet = ControlNetModel.from_pretrained(
175+
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
176+
)
177+
single_file_pipe = self.pipeline_class.from_single_file(
178+
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
179+
)
180+
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

tests/single_file/test_stable_diffusion_xl_adapter_single_file.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,11 @@ def test_single_file_components_with_original_config_local_files_only(self):
200200
local_files_only=True,
201201
)
202202
self._compare_component_configs(pipe, pipe_single_file)
203+
204+
def test_single_file_setting_pipeline_dtype_to_fp16(self):
205+
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
206+
207+
single_file_pipe = self.pipeline_class.from_single_file(
208+
self.ckpt_path, adapter=adapter, torch_dtype=torch.float16
209+
)
210+
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
195195
local_files_only=True,
196196
)
197197
super()._compare_component_configs(pipe, pipe_single_file)
198+
199+
def test_single_file_setting_pipeline_dtype_to_fp16(self):
200+
controlnet = ControlNetModel.from_pretrained(
201+
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
202+
)
203+
single_file_pipe = self.pipeline_class.from_single_file(
204+
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
205+
)
206+
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

0 commit comments

Comments
 (0)