Skip to content

Commit a9fdb3d

Browse files
authored
Return Flax scheduler state (#601)
* Optionally return state in from_config. Useful for Flax schedulers. * has_state is now a property, make check more strict. I don't check the class is `SchedulerMixin` to prevent circular dependencies. It should be enough that the class name starts with "Flax" the object declares it "has_state" and the "create_state" exists too. * Use state in pipeline from_pretrained. * Make style
1 parent e72f1a8 commit a9fdb3d

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

src/diffusers/configuration_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,19 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
160160
if "dtype" in unused_kwargs:
161161
init_dict["dtype"] = unused_kwargs.pop("dtype")
162162

163+
# Return model and optionally state and/or unused_kwargs
163164
model = cls(**init_dict)
165+
return_tuple = (model,)
166+
167+
# Flax schedulers have a state, so return it.
168+
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
169+
state = model.create_state()
170+
return_tuple += (state,)
164171

165172
if return_unused_kwargs:
166-
return model, unused_kwargs
173+
return return_tuple + (unused_kwargs,)
167174
else:
168-
return model
175+
return return_tuple if len(return_tuple) > 1 else model
169176

170177
@classmethod
171178
def get_config_dict(

src/diffusers/pipeline_flax_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
437437
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
438438
params[name] = loaded_params
439439
elif issubclass(class_obj, SchedulerMixin):
440-
loaded_sub_model = load_method(loadable_folder)
441-
params[name] = loaded_sub_model.create_state()
440+
loaded_sub_model, scheduler_state = load_method(loadable_folder)
441+
params[name] = scheduler_state
442442
else:
443443
loaded_sub_model = load_method(loadable_folder)
444444

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
105105
stable diffusion.
106106
"""
107107

108+
@property
109+
def has_state(self):
110+
return True
111+
108112
@register_to_config
109113
def __init__(
110114
self,

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
113113
stable diffusion.
114114
"""
115115

116+
@property
117+
def has_state(self):
118+
return True
119+
116120
@register_to_config
117121
def __init__(
118122
self,

0 commit comments

Comments
 (0)