Commit
apply Lora by altering layer's weights instead of adding more calculations in forward()
AUTOMATIC1111 committed Mar 25, 2023
1 parent 69eb2a9 commit 80b26d2
Showing 2 changed files with 66 additions and 18 deletions.
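
Editor's note: the core idea of the commit is that instead of computing the low-rank LoRA term on every forward() call, the delta up @ down is baked into the layer's weight once, and forward() itself stays untouched. Below is a minimal standalone sketch (not part of the commit; the dimensions, multiplier and alpha values are made up) showing that the two approaches give the same output for a Linear layer, up to floating-point rounding:

import torch

torch.manual_seed(0)

in_dim, out_dim, rank = 16, 8, 4                      # made-up sizes
layer = torch.nn.Linear(in_dim, out_dim, bias=False)
down = torch.nn.Linear(in_dim, rank, bias=False)      # stands in for module.down
up = torch.nn.Linear(rank, out_dim, bias=False)       # stands in for module.up
multiplier, alpha = 0.8, 4.0
scale = alpha / up.weight.shape[1]                    # alpha / rank, as in the diff

x = torch.randn(2, in_dim)

# Old approach: extra low-rank computation on every forward() call.
y_old = layer(x) + up(down(x)) * multiplier * scale

# New approach: merge the delta into the weight once, then call plain forward().
with torch.no_grad():
    layer.weight += (up.weight @ down.weight) * multiplier * scale
y_new = layer(x)

print(torch.allclose(y_old, y_new, atol=1e-5))        # True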
72 changes: 56 additions & 16 deletions extensions-builtin/Lora/lora.py
@@ -131,7 +131,7 @@ def load_lora(name, filename):
         with torch.no_grad():
             module.weight.copy_(weight)

-        module.to(device=devices.device, dtype=devices.dtype)
+        module.to(device=devices.cpu, dtype=devices.dtype)

         if lora_key == "lora_up.weight":
             lora_module.up = module
@@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
         loaded_loras.append(lora)


-def lora_forward(module, input, res):
-    input = devices.cond_cast_unet(input)
-    if len(loaded_loras) == 0:
-        return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+    """
+    Applies the currently selected set of Loras to the weight of torch layer self.
+    If weights already have this particular set of loras applied, does nothing.
+    If not, restores orginal weights from backup and alters weights according to loras.
+    """

-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
-        if module is not None:
-            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
-                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-            else:
-                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+    current_names = getattr(self, "lora_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

-    return res
+    weights_backup = getattr(self, "lora_weights_backup", None)
+    if weights_backup is None:
+        weights_backup = self.weight.to(devices.cpu, copy=True)
+        self.lora_weights_backup = weights_backup
+
+    if current_names != wanted_names:
+        if weights_backup is not None:
+            self.weight.copy_(weights_backup)
+
+        lora_layer_name = getattr(self, 'lora_layer_name', None)
+        for lora in loaded_loras:
+            module = lora.modules.get(lora_layer_name, None)
+            if module is None:
+                continue
+
+            with torch.no_grad():
+                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+
+                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                else:
+                    updown = up @ down
+
+                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        setattr(self, "lora_current_names", wanted_names)


 def lora_Linear_forward(self, input):
-    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)


 def lora_Conv2d_forward(self, input):
-    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


 def list_available_loras():
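
Editor's note: the new lora_apply_weights backs up the original weight on CPU, restores it whenever the wanted set of (name, multiplier) pairs changes, and then adds up @ down for each loaded Lora. For Conv2d layers with 1x1 kernels the factors are 4-D conv weights, so the diff squeezes them to matrices, multiplies, and unsqueezes back to the conv-weight shape. A standalone sketch of just that branch (the shapes are illustrative, not taken from the commit):

import torch

out_ch, in_ch, rank = 8, 16, 4                 # made-up channel counts
up = torch.randn(out_ch, rank, 1, 1)           # like module.up.weight for a 1x1 Conv2d
down = torch.randn(rank, in_ch, 1, 1)          # like module.down.weight for a 1x1 Conv2d

if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
    # Drop the trailing singleton dims, multiply as matrices, restore the conv shape.
    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
    updown = up @ down

print(updown.shape)                            # torch.Size([8, 16, 1, 1]), same as the conv weight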
12 changes: 10 additions & 2 deletions extensions-builtin/Lora/scripts/lora_script.py
@@ -9,7 +9,9 @@

 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora


 def before_ui():
@@ -20,11 +22,19 @@ def before_ui():
 if not hasattr(torch.nn, 'Linear_forward_before_lora'):
     torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
 if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
     torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict

 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
@@ -33,6 +43,4 @@ def before_ui():

 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
-    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
-
 }))
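
Editor's note: lora_script.py stashes the original torch.nn.Linear/Conv2d methods once, installs the Lora wrappers, and restores the originals in unload(). Hooking _load_from_state_dict matters because loading a different checkpoint overwrites the layer weights, so the cached lora_weights_backup and lora_current_names must be reset, which is what the new lora_*_load_state_dict functions do before delegating to the stashed originals. A minimal sketch of the stash/patch/restore pattern only (my_Linear_forward is a hypothetical stand-in for lora.lora_Linear_forward):

import torch

def my_Linear_forward(self, input):
    # A wrapper could do per-layer bookkeeping here (the extension merges Lora weights first).
    return torch.nn.Linear_forward_before_lora(self, input)

# Stash the original exactly once, as lora_script.py does.
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

torch.nn.Linear.forward = my_Linear_forward                      # patch

layer = torch.nn.Linear(4, 2)
print(layer(torch.randn(1, 4)).shape)                            # torch.Size([1, 2]), via the wrapper

torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora    # restore, as in unload()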

1 comment on commit 80b26d2

@ClashSAN
Collaborator


This change seems to break LoRA reproducibility (#9207 (comment)) for img2img mode (highres fix).

An entry has been added here for now: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Seed-breaking-changes
