Skip to content

Commit

Permalink
Rename GRLIR -> GRL (#120)
Browse files Browse the repository at this point in the history
* Rename GRLIR -> GRL

* simplify test
  • Loading branch information
RunDevelopment committed Jan 11, 2024
1 parent 1b560ce commit 68d43fd
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 68 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
- [SRFormer](https://github.com/HVision-NKU/SRFormer) | [Models](https://github.com/HVision-NKU/SRFormer#pretrain-models)
- [DAT](https://github.com/zhengchen1999/DAT) | [Models](https://github.com/zhengchen1999/DAT#testing)
- [FeMaSR](https://github.com/chaofengc/FeMaSR) | [Models](https://github.com/chaofengc/FeMaSR/releases/tag/v0.1-pretrain_models)
- [GRLIR](https://github.com/ofsoundof/GRL-Image-Restoration) | [Models](https://github.com/ofsoundof/GRL-Image-Restoration/releases/tag/v1.0.0)
- [GRL](https://github.com/ofsoundof/GRL-Image-Restoration) | [Models](https://github.com/ofsoundof/GRL-Image-Restoration/releases/tag/v1.0.0)
- [DITN](https://github.com/yongliuy/DITN) | [Models](https://drive.google.com/drive/folders/1XpHW27H5j2S4IH8t4lccgrgHkIjqrS-X)
- [MM-RealSR](https://github.com/TencentARC/MM-RealSR) | [Models](https://github.com/TencentARC/MM-RealSR/releases/tag/v1.0.0)
- [SPAN](https://github.com/hongyuanyu/SPAN) | [Models](https://drive.google.com/file/d/1iYUA2TzKuxI0vzmA-UXr_nB43XgPOXUg/view?usp=sharing)
Expand Down
6 changes: 3 additions & 3 deletions src/spandrel/__helpers/main_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ESRGAN,
FBCNN,
GFPGAN,
GRLIR,
GRL,
HAT,
MAT,
SPAN,
Expand Down Expand Up @@ -91,7 +91,7 @@ def _detect(state_dict: StateDict) -> bool:
load=HAT.load,
),
ArchSupport(
id="GRLIR",
id="GRL",
detect=lambda state: (
_has_keys(
"conv_first.weight",
Expand All @@ -115,7 +115,7 @@ def _detect(state_dict: StateDict) -> bool:
"model_g.layers.0.blocks.0.attn.stripe_attn.attn_transform1.logit_scale",
)(state)
),
load=GRLIR.load,
load=GRL.load,
),
ArchSupport(
id="Swin2SR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ...__helpers.canonicalize import remove_common_prefix
from ...__helpers.model_descriptor import ImageModelDescriptor, StateDict
from ..__arch_helpers.state import get_scale_and_output_channels, get_seq_len
from .arch.grl import GRL as GRLIR
from .arch.grl import GRL

_NON_PERSISTENT_BUFFERS = [
"table_w",
Expand Down Expand Up @@ -132,7 +132,7 @@ def _inv_div_add(a: int, d: int) -> int:
return round(a / (1 + 1 / d))


def load(state_dict: StateDict) -> ImageModelDescriptor[GRLIR]:
def load(state_dict: StateDict) -> ImageModelDescriptor[GRL]:
state_dict = _clean_up_checkpoint(state_dict)

img_size: int = 64
Expand Down Expand Up @@ -266,7 +266,7 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[GRLIR]:
if buffer_key in state_dict:
del state_dict[buffer_key]

model = GRLIR(
model = GRL(
img_size=img_size,
in_channels=in_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -305,7 +305,7 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[GRLIR]:
return ImageModelDescriptor(
model,
state_dict,
architecture="GRLIR",
architecture="GRL",
purpose="Restoration" if upscale == 1 else "SR",
tags=[
size_tag,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# serializer version: 1
# name: test_GRLIR_bsr_grl_base
# name: test_GRL_bsr_grl_base
ImageModelDescriptor(
architecture='GRLIR',
architecture='GRL',
input_channels=3,
output_channels=3,
purpose='SR',
Expand All @@ -18,9 +18,9 @@
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_GRLIR_sr_grl_tiny_c3x3
# name: test_GRL_sr_grl_tiny_c3x3
ImageModelDescriptor(
architecture='GRLIR',
architecture='GRL',
input_channels=3,
output_channels=3,
purpose='SR',
Expand All @@ -37,9 +37,9 @@
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_GRLIR_sr_grl_tiny_c3x4
# name: test_GRL_sr_grl_tiny_c3x4
ImageModelDescriptor(
architecture='GRLIR',
architecture='GRL',
input_channels=3,
output_channels=3,
purpose='SR',
Expand Down
108 changes: 54 additions & 54 deletions tests/test_GRLIR.py → tests/test_GRL.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spandrel.architectures.GRLIR import GRLIR, load
from spandrel.architectures.GRL import GRL, load

from .util import (
ModelFile,
Expand All @@ -9,57 +9,57 @@
)


def test_GRLIR_load():
def test_GRL_load():
assert_loads_correctly(
load,
lambda: GRLIR(),
lambda: GRLIR(in_channels=1, out_channels=3),
lambda: GRLIR(in_channels=4, out_channels=4),
lambda: GRLIR(embed_dim=16),
lambda: GRL(),
lambda: GRL(in_channels=1, out_channels=3),
lambda: GRL(in_channels=4, out_channels=4),
lambda: GRL(embed_dim=16),
# embed_dim=16 makes tests go faster
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffle", upscale=2),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffle", upscale=3),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffle", upscale=4),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffle", upscale=8),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffledirect", upscale=2),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffledirect", upscale=3),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffledirect", upscale=4),
lambda: GRLIR(embed_dim=16, upsampler="pixelshuffledirect", upscale=8),
lambda: GRLIR(embed_dim=16, upsampler="nearest+conv", upscale=4),
lambda: GRLIR(
lambda: GRL(embed_dim=16, upsampler="pixelshuffle", upscale=2),
lambda: GRL(embed_dim=16, upsampler="pixelshuffle", upscale=3),
lambda: GRL(embed_dim=16, upsampler="pixelshuffle", upscale=4),
lambda: GRL(embed_dim=16, upsampler="pixelshuffle", upscale=8),
lambda: GRL(embed_dim=16, upsampler="pixelshuffledirect", upscale=2),
lambda: GRL(embed_dim=16, upsampler="pixelshuffledirect", upscale=3),
lambda: GRL(embed_dim=16, upsampler="pixelshuffledirect", upscale=4),
lambda: GRL(embed_dim=16, upsampler="pixelshuffledirect", upscale=8),
lambda: GRL(embed_dim=16, upsampler="nearest+conv", upscale=4),
lambda: GRL(
embed_dim=16,
depths=[4, 5, 3, 2, 1],
num_heads_window=[2, 3, 5, 1, 3],
num_heads_stripe=[2, 4, 7, 1, 1],
),
lambda: GRLIR(mlp_ratio=2),
lambda: GRLIR(mlp_ratio=3),
lambda: GRLIR(qkv_proj_type="linear", qkv_bias=True),
lambda: GRLIR(qkv_proj_type="linear", qkv_bias=False),
lambda: GRLIR(qkv_proj_type="separable_conv", qkv_bias=True),
lambda: GRLIR(qkv_proj_type="separable_conv", qkv_bias=False),
lambda: GRLIR(conv_type="1conv"),
lambda: GRLIR(conv_type="1conv1x1"),
lambda: GRLIR(conv_type="linear"),
lambda: GRLIR(conv_type="3conv"),
lambda: GRL(mlp_ratio=2),
lambda: GRL(mlp_ratio=3),
lambda: GRL(qkv_proj_type="linear", qkv_bias=True),
lambda: GRL(qkv_proj_type="linear", qkv_bias=False),
lambda: GRL(qkv_proj_type="separable_conv", qkv_bias=True),
lambda: GRL(qkv_proj_type="separable_conv", qkv_bias=False),
lambda: GRL(conv_type="1conv"),
lambda: GRL(conv_type="1conv1x1"),
lambda: GRL(conv_type="linear"),
lambda: GRL(conv_type="3conv"),
# These require non-persistent buffers to be detected
# lambda: GRLIR(
# lambda: GRL(
# window_size=16,
# stripe_size=[32, 64],
# anchor_window_down_factor=1,
# ),
# lambda: GRLIR(
# lambda: GRL(
# window_size=16,
# stripe_size=[32, 64],
# anchor_window_down_factor=2,
# ),
# lambda: GRLIR(
# lambda: GRL(
# window_size=16,
# stripe_size=[32, 64],
# anchor_window_down_factor=4,
# ),
# some actual training configs
lambda: GRLIR(
lambda: GRL(
upscale=4,
img_size=64,
window_size=8,
Expand All @@ -75,7 +75,7 @@ def test_GRLIR_load():
conv_type="1conv",
upsampler="pixelshuffledirect",
),
lambda: GRLIR(
lambda: GRL(
upscale=4,
img_size=64,
window_size=8,
Expand All @@ -91,7 +91,7 @@ def test_GRLIR_load():
conv_type="1conv",
upsampler="pixelshuffle",
),
lambda: GRLIR(
lambda: GRL(
upscale=4,
img_size=64,
window_size=8,
Expand Down Expand Up @@ -128,138 +128,138 @@ def test_GRLIR_load():
)


# def test_GRLIR_dn_grl_tiny_c1(snapshot):
# def test_GRL_dn_grl_tiny_c1(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/dn_grl_tiny_c1.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# # this model is weird, so no inference test


# def test_GRLIR_dn_grl_base_c1s25(snapshot):
# def test_GRL_dn_grl_base_c1s25(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/dn_grl_base_c1s25.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# # we don't have grayscale images yet


# def test_GRLIR_jpeg_grl_small_c1q30(snapshot):
# def test_GRL_jpeg_grl_small_c1q30(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/jpeg_grl_small_c1q30.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# # we don't have grayscale images yet


# def test_GRLIR_dn_grl_small_c3s15(snapshot):
# def test_GRL_dn_grl_small_c3s15(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/dn_grl_small_c3s15.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# assert_image_inference(
# file,
# model,
# [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
# )


# def test_GRLIR_dn_grl_base_c3s50(snapshot):
# def test_GRL_dn_grl_base_c3s50(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/dn_grl_base_c3s50.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# assert_image_inference(
# file,
# model,
# [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
# )


# def test_GRLIR_db_motion_grl_base_gopro(snapshot):
# def test_GRL_db_motion_grl_base_gopro(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/db_motion_grl_base_gopro.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# assert_image_inference(
# file,
# model,
# [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
# )


# def test_GRLIR_jpeg_grl_small_c3(snapshot):
# def test_GRL_jpeg_grl_small_c3(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/jpeg_grl_small_c3.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# # this model is weird, so no inference test


# def test_GRLIR_jpeg_grl_small_c3q20(snapshot):
# def test_GRL_jpeg_grl_small_c3q20(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/jpeg_grl_small_c3q20.ckpt"
# )
# model = file.load_model()
# assert model == snapshot(exclude=disallowed_props)
# assert isinstance(model.model, GRLIR)
# assert isinstance(model.model, GRL)
# assert_image_inference(
# file,
# model,
# [TestImage.SR_64, TestImage.JPEG_15],
# )


def test_GRLIR_sr_grl_tiny_c3x3(snapshot):
def test_GRL_sr_grl_tiny_c3x3(snapshot):
file = ModelFile.from_url(
"https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/sr_grl_tiny_c3x3.ckpt"
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GRLIR)
assert isinstance(model.model, GRL)
assert_image_inference(
file,
model,
[TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
)


def test_GRLIR_sr_grl_tiny_c3x4(snapshot):
def test_GRL_sr_grl_tiny_c3x4(snapshot):
file = ModelFile.from_url(
"https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/sr_grl_tiny_c3x4.ckpt"
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GRLIR)
assert isinstance(model.model, GRL)
assert_image_inference(
file,
model,
[TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
)


def test_GRLIR_bsr_grl_base(snapshot):
def test_GRL_bsr_grl_base(snapshot):
file = ModelFile.from_url(
"https://drive.google.com/file/d/1JdzeTFiBVSia7PmSvr5VduwDdLnirxAG/view",
name="bsr_grl_base.safetensors",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, GRLIR)
assert isinstance(model.model, GRL)
assert_image_inference(
file,
model,
Expand Down

0 comments on commit 68d43fd

Please sign in to comment.