Skip to content

Commit 7c2f0af

Browse files
authored
update get_parameter_dtype (#10342)
add: q
1 parent f615f00 commit 7c2f0af

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
9999

100100

101101
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
102-
try:
103-
return next(parameter.parameters()).dtype
104-
except StopIteration:
105-
try:
106-
return next(parameter.buffers()).dtype
107-
except StopIteration:
108-
# For torch.nn.DataParallel compatibility in PyTorch 1.5
109-
110-
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
111-
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
112-
return tuples
113-
114-
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
115-
first_tuple = next(gen)
116-
return first_tuple[1].dtype
102+
"""
103+
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
104+
"""
105+
last_dtype = None
106+
for param in parameter.parameters():
107+
last_dtype = param.dtype
108+
if param.is_floating_point():
109+
return param.dtype
110+
111+
for buffer in parameter.buffers():
112+
last_dtype = buffer.dtype
113+
if buffer.is_floating_point():
114+
return buffer.dtype
115+
116+
if last_dtype is not None:
117+
# if no floating dtype was found return whatever the first dtype is
118+
return last_dtype
119+
120+
# For nn.DataParallel compatibility in PyTorch > 1.5
121+
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
122+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
123+
return tuples
124+
125+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
126+
last_tuple = None
127+
for tuple in gen:
128+
last_tuple = tuple
129+
if tuple[1].is_floating_point():
130+
return tuple[1].dtype
131+
132+
if last_tuple is not None:
133+
# fallback to the last dtype
134+
return last_tuple[1].dtype
117135

118136

119137
class ModelMixin(torch.nn.Module, PushToHubMixin):

0 commit comments

Comments
 (0)