Fix img2img speed with LMS-Discrete Scheduler (#896)
Casting `self.sigmas` in place to a different dtype (that of `original_samples`) is not advisable. In my img2img pipeline this leads to a long running time in the `integrate.quad` call later on; by long I mean more than 10x slower.

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
NotNANtoN and anton-l authored Nov 18, 2022
1 parent 81fa2d6 commit aa2ce41
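
For context, a simplified sketch of why the persisted cast is costly. The standalone `lms_coefficient` helper and the sigma values below are illustrative, not the library's exact code; it paraphrases the scheduler's `get_lms_coefficient`, which builds each LMS coefficient by numerically integrating a function of `self.sigmas` with `scipy.integrate.quad`, so whatever dtype and device `add_noise` leaves `self.sigmas` in is what every `quad` evaluation has to work with:

```python
# Simplified sketch of how the scheduler consumes self.sigmas after add_noise.
# Paraphrased for illustration; see LMSDiscreteScheduler.get_lms_coefficient for the real code.
import torch
from scipy import integrate

sigmas = torch.linspace(14.6, 0.0, 51)  # stand-in for self.sigmas, kept in float32 on CPU


def lms_coefficient(sigmas, order, t, current_order):
    # Integrate the Lagrange basis polynomial over [sigmas[t], sigmas[t + 1]].
    def lms_derivative(tau):
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]


# quad calls lms_derivative many times per step; if add_noise had replaced self.sigmas
# with a low-precision (or GPU) copy, every one of those evaluations would index that
# tensor instead, which is where the reported >10x slowdown in img2img showed up.
coeff = lms_coefficient(sigmas, order=4, t=10, current_order=0)
```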
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
```diff
@@ -243,19 +243,18 @@ def add_noise(
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)
 
-        schedule_timesteps = self.timesteps
         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
 
-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
```
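
A minimal usage sketch of the fixed behavior; the scheduler config, tensor shapes, and timestep choice below are illustrative stand-ins for the img2img setup, not taken from the pipeline code:

```python
import torch
from diffusers import LMSDiscreteScheduler

# Illustrative Stable-Diffusion-like config; any LMSDiscreteScheduler behaves the same way here.
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # e.g. fp16 init latents in img2img
noise = torch.randn_like(latents)
t = scheduler.timesteps[10:11]                             # one timestep from the schedule

dtype_before = scheduler.sigmas.dtype
noisy = scheduler.add_noise(latents, noise, t)

# add_noise now casts into local copies, so the scheduler's own sigmas/timesteps are
# untouched and the later integrate.quad calls in get_lms_coefficient stay fast.
assert scheduler.sigmas.dtype == dtype_before
```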
