@@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
99
99
100
100
101
101
def get_parameter_dtype (parameter : torch .nn .Module ) -> torch .dtype :
102
- try :
103
- return next (parameter .parameters ()).dtype
104
- except StopIteration :
105
- try :
106
- return next (parameter .buffers ()).dtype
107
- except StopIteration :
108
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
109
-
110
- def find_tensor_attributes (module : torch .nn .Module ) -> List [Tuple [str , Tensor ]]:
111
- tuples = [(k , v ) for k , v in module .__dict__ .items () if torch .is_tensor (v )]
112
- return tuples
113
-
114
- gen = parameter ._named_members (get_members_fn = find_tensor_attributes )
115
- first_tuple = next (gen )
116
- return first_tuple [1 ].dtype
102
+ """
103
+ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
104
+ """
105
+ last_dtype = None
106
+ for param in parameter .parameters ():
107
+ last_dtype = param .dtype
108
+ if param .is_floating_point ():
109
+ return param .dtype
110
+
111
+ for buffer in parameter .buffers ():
112
+ last_dtype = buffer .dtype
113
+ if buffer .is_floating_point ():
114
+ return buffer .dtype
115
+
116
+ if last_dtype is not None :
117
+ # if no floating dtype was found return whatever the first dtype is
118
+ return last_dtype
119
+
120
+ # For nn.DataParallel compatibility in PyTorch > 1.5
121
+ def find_tensor_attributes (module : nn .Module ) -> List [Tuple [str , Tensor ]]:
122
+ tuples = [(k , v ) for k , v in module .__dict__ .items () if torch .is_tensor (v )]
123
+ return tuples
124
+
125
+ gen = parameter ._named_members (get_members_fn = find_tensor_attributes )
126
+ last_tuple = None
127
+ for tuple in gen :
128
+ last_tuple = tuple
129
+ if tuple [1 ].is_floating_point ():
130
+ return tuple [1 ].dtype
131
+
132
+ if last_tuple is not None :
133
+ # fallback to the last dtype
134
+ return last_tuple [1 ].dtype
117
135
118
136
119
137
class ModelMixin (torch .nn .Module , PushToHubMixin ):
0 commit comments