Commit 2ed4ff2

update flax scheduler API (huggingface#822)
* update flax scheduler API
* remove set format
* fix call to scale_model_input
* update flax pndm
* use int32
* update docstr
1 parent ad8c9a6 commit 2ed4ff2

File tree

3 files changed: +43 −1

* pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
* schedulers/scheduling_ddim_flax.py
* schedulers/scheduling_pndm_flax.py

pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 5 additions & 0 deletions
@@ -170,6 +170,8 @@ def loop_body(step, args):
             t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
             timestep = jnp.broadcast_to(t, latents_input.shape[0])
 
+            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
             # predict the noise residual
             noise_pred = self.unet.apply(
                 {"params": params["unet"]},
@@ -189,6 +191,9 @@ def loop_body(step, args):
             params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
         if debug:
             # run with python for loop
             for i in range(num_inference_steps):
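
Below is a minimal, self-contained sketch (not the pipeline code itself) of the caller-side contract this change introduces: scale the initial noise once by `init_noise_sigma`, then pass each latent batch through `scale_model_input` before the UNet call. `DummyScheduler` and its identity behaviour are illustrative stand-ins; only the attribute and method names come from the diff.

import jax
import jax.numpy as jnp


class DummyScheduler:
    """Hypothetical stand-in exposing only the two members the diff relies on."""

    init_noise_sigma = 1.0  # matches the value added for Flax DDIM/PNDM in this commit

    def scale_model_input(self, state, sample, timestep=None):
        # identity, as in the Flax DDIM/PNDM implementations in this commit
        return sample


scheduler = DummyScheduler()
scheduler_state = None  # real Flax schedulers keep their mutable data in a state object

# initial noise, scaled once by the scheduler's standard deviation
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64))
latents = latents * scheduler.init_noise_sigma

# inside the denoising loop, every model input passes through scale_model_input
t = jnp.asarray(999, dtype=jnp.int32)
latents_input = scheduler.scale_model_input(scheduler_state, latents, t)
# latents_input would then be fed to the UNet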

schedulers/scheduling_ddim_flax.py

Lines changed: 17 additions & 0 deletions
@@ -141,6 +141,23 @@ def __init__(
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+    def scale_model_input(
+        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+    ) -> jnp.ndarray:
+        """
+        Args:
+            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
+            sample (`jnp.ndarray`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `jnp.ndarray`: scaled input sample
+        """
+        return sample
+
     def create_state(self):
         return DDIMSchedulerState.create(
             num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
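
The DDIM method above is a no-op, so why add it at all? A hedged sketch of the design motivation: a sigma-based scheduler (Karras-style) would override `scale_model_input` with real scaling, and pipeline code written against the new API keeps working unchanged. The class, state, and sigma values below are hypothetical illustrations, not part of this commit or the library.

from typing import Optional

import jax.numpy as jnp
from flax import struct


@struct.dataclass
class HypotheticalSigmaState:
    sigmas: jnp.ndarray  # per-timestep noise levels, illustrative only


class HypotheticalSigmaScheduler:
    # a sigma-based scheduler would expose a large init_noise_sigma instead of 1.0
    init_noise_sigma = 14.6

    def scale_model_input(
        self, state: HypotheticalSigmaState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        # divide by sqrt(sigma^2 + 1) so the UNet sees roughly unit-variance inputs
        sigma = state.sigmas[timestep]
        return sample / jnp.sqrt(sigma**2 + 1)


state = HypotheticalSigmaState(sigmas=jnp.linspace(14.6, 0.03, 1000))
scaled = HypotheticalSigmaScheduler().scale_model_input(state, jnp.ones((1, 4, 64, 64)), 0)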

schedulers/scheduling_pndm_flax.py

Lines changed: 21 additions & 1 deletion
@@ -153,6 +153,9 @@ def __init__(
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
     def create_state(self):
         return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
 
@@ -196,14 +199,31 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
         )
 
         return state.replace(
-            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
+            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
             counter=0,
             # Reserve space for the state variables
             cur_model_output=jnp.zeros(shape),
             cur_sample=jnp.zeros(shape),
             ets=jnp.zeros((4,) + shape),
         )
 
+    def scale_model_input(
+        self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+    ) -> jnp.ndarray:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+            sample (`jnp.ndarray`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `jnp.ndarray`: scaled input sample
+        """
+        return sample
+
     def step(
         self,
         state: PNDMSchedulerState,
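
A short illustration (not from the repo) of why the timestep dtype moves from int64 to int32: JAX disables 64-bit types by default, so an int64 cast is downgraded to int32 anyway (typically with a warning) unless `jax_enable_x64` is set; requesting int32 explicitly keeps the dtype stable across configurations.

import jax.numpy as jnp

timesteps = jnp.arange(0, 1000, 20)[::-1]
print(timesteps.dtype)  # int32: the default JAX integer width

# casting to int64 is downgraded when x64 mode is off
print(timesteps.astype(jnp.int64).dtype)  # still int32 unless jax_enable_x64 is enabled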
