Skip to content

Commit ad31600

Browse files
authored
sd3推理优化——避免同步 (#695)
When s_churn == 0.0, gamma does not need to be computed, which avoids a CUDA synchronization and speeds up SD3 end-to-end inference performance.
1 parent f394358 commit ad31600

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
from ..utils.paddle_utils import randn_tensor
2525
from .scheduling_utils import SchedulerMixin
2626

27-
2827
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2928

3029

@@ -245,12 +244,13 @@ def step(
245244
sample = sample.cast(paddle.float32)
246245

247246
sigma = self.sigmas[self.step_index]
247+
# NOTE(changwenbin & zhoukangkang): when s_churn == 0.0, gamma does not need to be computed, which avoids a CUDA synchronization.
248+
if s_churn == 0.0:
249+
gamma = 0.0
250+
else:
251+
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
248252

249-
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
250-
251-
noise = randn_tensor(
252-
model_output.shape, dtype=model_output.dtype, generator=generator
253-
)
253+
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, generator=generator)
254254

255255
eps = noise * s_noise
256256
sigma_hat = sigma * (gamma + 1)
@@ -283,4 +283,4 @@ def step(
283283
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
284284

285285
def __len__(self):
286-
return self.config.num_train_timesteps
286+
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)