diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 7c371deb7c9..a737fec3751 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -131,7 +131,7 @@ def load_lora(name, filename):
         with torch.no_grad():
             module.weight.copy_(weight)
 
-        module.to(device=devices.device, dtype=devices.dtype)
+        module.to(device=devices.cpu, dtype=devices.dtype)
 
         if lora_key == "lora_up.weight":
             lora_module.up = module
@@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
         loaded_loras.append(lora)
 
 
-def lora_forward(module, input, res):
-    input = devices.cond_cast_unet(input)
-    if len(loaded_loras) == 0:
-        return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+    """
+    Applies the currently selected set of Loras to the weight of torch layer self.
+    If weights already have this particular set of loras applied, does nothing.
+    If not, restores original weights from backup and alters weights according to loras.
+    """
 
-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
-        if module is not None:
-            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
-                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-            else:
-                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+    current_names = getattr(self, "lora_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+    weights_backup = getattr(self, "lora_weights_backup", None)
+    if weights_backup is None:
+        weights_backup = self.weight.to(devices.cpu, copy=True)
+        self.lora_weights_backup = weights_backup
+
+    if current_names != wanted_names:
+        if weights_backup is not None:
+            self.weight.copy_(weights_backup)
+
+        lora_layer_name = getattr(self, 'lora_layer_name', None)
+        for lora in loaded_loras:
+            module = lora.modules.get(lora_layer_name, None)
+            if module is None:
+                continue
 
-    return res
+            with torch.no_grad():
+                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+
+                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                else:
+                    updown = up @ down
+
+                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        setattr(self, "lora_current_names", wanted_names)
 
 
 def lora_Linear_forward(self, input):
-    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def lora_Conv2d_forward(self, input):
-    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def list_available_loras():
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 2e860160e94..dc329e81237 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -9,7 +9,9 @@
 
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
 
 
 def before_ui():
@@ -20,11 +22,19 @@ def before_ui():
 if not hasattr(torch.nn, 'Linear_forward_before_lora'):
     torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
 
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
 if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
     torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
 
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
@@ -33,6 +43,4 @@ def before_ui():
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
-    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
-
 }))
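
Note (not part of the patch): a minimal standalone sketch of the idea behind lora_apply_weights above — folding the LoRA delta up @ down, scaled by multiplier * alpha / rank, directly into a layer's weight gives the same output as the old lora_forward approach of adding up(down(input)) on every call. The layer sizes, multiplier and alpha values below are illustrative.

import torch

torch.manual_seed(0)

# Toy Linear layer plus a LoRA pair with the same shapes the extension uses:
# down maps in_features -> rank, up maps rank -> out_features.
in_features, out_features, rank = 16, 8, 4
layer = torch.nn.Linear(in_features, out_features, bias=False)
down = torch.nn.Linear(in_features, rank, bias=False)
up = torch.nn.Linear(rank, out_features, bias=False)

multiplier = 0.8
alpha = 2.0
scale = multiplier * (alpha / up.weight.shape[1])  # alpha / rank, as in the patch

x = torch.randn(3, in_features)

# Old approach: extra matmuls in every forward pass.
out_forward = layer(x) + up(down(x)) * scale

# New approach: back up the weight, fold the LoRA delta in once, run a plain forward.
with torch.no_grad():
    weights_backup = layer.weight.to("cpu", copy=True)  # analogue of lora_weights_backup
    layer.weight += (up.weight @ down.weight) * scale

out_merged = layer(x)
print(torch.allclose(out_forward, out_merged, atol=1e-5))  # True

# Copying the backup over the weight undoes the merge, which is how a different
# set of Loras (or none) can be applied later.
with torch.no_grad():
    layer.weight.copy_(weights_backup)

The Conv2d branch of the patch does the same thing for 1x1 kernels by squeezing the spatial dimensions before the matmul and unsqueezing them afterwards.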