From 32cabdf8f46d0d8f0ff3cfd52b9924d8cc45e2b9 Mon Sep 17 00:00:00 2001
From: NotNANtoN
Date: Fri, 18 Nov 2022 16:01:57 +0100
Subject: [PATCH] Fix img2img speed with LMS-Discrete Scheduler (#896)

Casting `self.sigmas` to a different dtype (that of `original_samples`)
is not advisable, since it permanently mutates the scheduler's state.
In my img2img pipeline this leads to a long running time in the
`integrate.quad` call later on; by long I mean more than 10x slower.

Co-authored-by: Anton Lozhkov
---
 schedulers/scheduling_lms_discrete.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/schedulers/scheduling_lms_discrete.py b/schedulers/scheduling_lms_discrete.py
index 8a9aedb41bf1..cc9e8d72566a 100644
--- a/schedulers/scheduling_lms_discrete.py
+++ b/schedulers/scheduling_lms_discrete.py
@@ -243,19 +243,18 @@ def add_noise(
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)
 
-        schedule_timesteps = self.timesteps
         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
 
-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
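
For context, here is a minimal sketch of the `integrate.quad` hot path the
commit message refers to. The helper name `make_lms_derivative`, the schedule
values, and the chosen indices are illustrative assumptions, not the
repository's code; the integrand mirrors the shape of `get_lms_coefficient`
in this same file, which builds a Lagrange basis polynomial from
`self.sigmas` and integrates it with scipy's adaptive `integrate.quad`.

    # Illustrative sketch; helper name and schedule values are assumptions.
    import torch
    from scipy import integrate

    sigmas = torch.linspace(14.6, 0.03, 1000)  # illustrative sigma schedule

    def make_lms_derivative(sigmas, order, t, current_order):
        # Mirrors the integrand shape used by get_lms_coefficient() in this
        # file: a Lagrange basis polynomial evaluated over the sigma schedule.
        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
            return prod
        return lms_derivative

    t = 500
    deriv = make_lms_derivative(sigmas, order=4, t=t, current_order=0)

    # quad evaluates the integrand adaptively; over the fp32 schedule it
    # converges quickly.
    coeff = integrate.quad(deriv, float(sigmas[t]), float(sigmas[t + 1]), epsrel=1e-4)[0]

    # Before this patch, an fp16 img2img sample caused add_noise() to
    # overwrite self.sigmas with an fp16 copy, so every later coefficient was
    # integrated over a low-precision schedule; that is where the reported
    # more-than-10x slowdown showed up.
    deriv_fp16 = make_lms_derivative(sigmas.half(), order=4, t=t, current_order=0)
    coeff_fp16 = integrate.quad(deriv_fp16, float(sigmas[t]), float(sigmas[t + 1]), epsrel=1e-4)[0]

The patch avoids this by keeping the device/dtype cast in the local `sigmas`
and `schedule_timesteps` variables, leaving the scheduler's own fp32 schedule
untouched for later coefficient computations.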