Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MPRNet #210

Merged
merged 2 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
- [Restormer](https://github.com/swz30/Restormer) (+) | [Models](https://github.com/swz30/Restormer/releases/tag/v1.0)
- [FFTformer](https://github.com/kkkls/FFTformer) | [Models](https://github.com/kkkls/FFTformer/releases/tag/pretrain_model)
- [M3SNet](https://github.com/Tombs98/M3SNet) (+) | [Models](https://drive.google.com/drive/folders/1y4BEX7LagtXVO98ZItSbJJl7WWM3gnbD)
- [MPRNet](https://github.com/swz30/MPRNet) (+) | [Deblurring](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing), [Deraining](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing), [Denoising](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing)

#### DeJPEG

Expand Down
2 changes: 2 additions & 0 deletions libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DDColor,
FeMaSR,
M3SNet,
MPRNet,
Restormer,
SRFormer,
)
Expand All @@ -22,4 +23,5 @@
ArchSupport.from_architecture(FeMaSR.FeMaSRArch()),
ArchSupport.from_architecture(M3SNet.M3SNetArch()),
ArchSupport.from_architecture(Restormer.RestormerArch()),
ArchSupport.from_architecture(MPRNet.MPRNetArch()),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

from typing_extensions import override

from spandrel import (
Architecture,
ImageModelDescriptor,
SizeRequirements,
StateDict,
)
from spandrel.util import KeyCondition, get_seq_len

from .arch.MPRNet import MPRNet


class MPRNetArch(Architecture[MPRNet]):
def __init__(self) -> None:
super().__init__(
id="MPRNet",
detect=KeyCondition.has_all(
"shallow_feat1.0.weight",
"shallow_feat1.1.CA.conv_du.0.weight",
"shallow_feat1.1.CA.conv_du.2.weight",
"shallow_feat1.1.body.0.weight",
"shallow_feat1.1.body.2.weight",
"shallow_feat2.0.weight",
"shallow_feat3.0.weight",
"stage1_encoder.encoder_level1.0.CA.conv_du.0.weight",
"stage1_encoder.encoder_level1.0.CA.conv_du.2.weight",
"stage1_encoder.encoder_level1.0.body.2.weight",
"stage1_encoder.encoder_level1.1.body.2.weight",
"stage1_encoder.encoder_level2.1.body.2.weight",
"stage1_encoder.encoder_level3.0.CA.conv_du.0.weight",
"stage1_decoder.decoder_level1.0.CA.conv_du.0.weight",
"stage1_decoder.decoder_level1.0.body.0.weight",
"stage1_decoder.decoder_level2.0.CA.conv_du.0.weight",
"stage1_decoder.decoder_level3.0.CA.conv_du.0.weight",
"stage1_decoder.skip_attn1.CA.conv_du.0.weight",
"stage1_decoder.skip_attn2.CA.conv_du.0.weight",
"stage1_decoder.up32.up.1.weight",
"stage2_encoder.encoder_level1.0.CA.conv_du.0.weight",
"stage2_decoder.decoder_level1.0.CA.conv_du.0.weight",
"sam12.conv1.weight",
"sam12.conv3.weight",
"sam23.conv3.weight",
"concat12.weight",
"concat23.weight",
"tail.weight",
"stage3_orsnet.orb1.body.0.CA.conv_du.0.weight",
"stage3_orsnet.orb1.body.0.CA.conv_du.2.weight",
"stage3_orsnet.orb1.body.0.body.0.weight",
"stage3_orsnet.orb1.body.0.body.2.weight",
),
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[MPRNet]:
# in_c: int = 3
# out_c: int = 3
# n_feat: int = 40
# scale_unetfeats: int = 20
# scale_orsnetfeats: int = 16
# num_cab: int = 8
# kernel_size: int = 3
# reduction = 4
# bias = False

in_c = state_dict["shallow_feat1.0.weight"].shape[1]
n_feat = state_dict["shallow_feat1.0.weight"].shape[0]
kernel_size = state_dict["shallow_feat1.0.weight"].shape[2]
bias = "shallow_feat1.0.bias" in state_dict
reduction = n_feat // state_dict["shallow_feat1.1.CA.conv_du.0.weight"].shape[0]

out_c = state_dict["tail.weight"].shape[0]
scale_orsnetfeats = state_dict["tail.weight"].shape[1] - n_feat
scale_unetfeats = (
state_dict["stage1_encoder.encoder_level2.0.CA.conv_du.0.weight"].shape[1]
- n_feat
)

num_cab = get_seq_len(state_dict, "stage3_orsnet.orb1.body") - 1

model = MPRNet(
in_c=in_c,
out_c=out_c,
n_feat=n_feat,
scale_unetfeats=scale_unetfeats,
scale_orsnetfeats=scale_orsnetfeats,
num_cab=num_cab,
kernel_size=kernel_size,
reduction=reduction,
bias=bias,
)

return ImageModelDescriptor(
model,
state_dict,
architecture=self,
purpose="Restoration",
tags=[f"{n_feat}nf"],
supports_half=False, # TODO: verify
supports_bfloat16=True,
scale=1,
input_channels=in_c,
output_channels=out_c,
size_requirements=SizeRequirements(multiple_of=8),
call_fn=lambda model, x: model(x)[0],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## ACADEMIC PUBLIC LICENSE

### Permissions
:heavy_check_mark: Non-Commercial use
:heavy_check_mark: Modification
:heavy_check_mark: Distribution
:heavy_check_mark: Private use

### Limitations
:x: Commercial Use
:x: Liability
:x: Warranty

### Conditions
:information_source: License and copyright notice
:information_source: Same License

MPRNet is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations.
You can use MPRNet in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately.

You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software.
If you distribute verbatim or modified copies of this software, they must be distributed under this license.
This license guarantees that you're safe when using MPRNet in your work, for teaching or research.
This license guarantees that MPRNet will remain available free of charge for nonprofit use.
You can modify MPRNet to your purposes, and you can also share your modifications.

If you would like to use MPRNet in commercial settings, contact us so we can discuss options. Send an email to waqas.zamir@inceptioniai.org


Loading
Loading