Optimize FeMaSR for inference (#207)
RunDevelopment committed Mar 23, 2024
1 parent 56b4fbc commit 17930c9
Showing 3 changed files with 17 additions and 488 deletions.
File 1 of 3
@@ -28,7 +28,14 @@ def _clean_state_dict(state_dict: StateDict):

keys = list(state_dict.keys())
for k in keys:
if k.startswith(("sft_fusion_group.", "multiscale_encoder.upsampler.")):
if k.startswith(
(
"sft_fusion_group.",
"multiscale_encoder.upsampler.",
"conv_semantic.",
"vgg_feat_extractor.",
)
):
del state_dict[k]
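The two new prefixes match the semantic-loss head (conv_semantic) and the VGG feature extractor (vgg_feat_extractor) that this commit removes from the architecture below, so their weights are stripped before the inference model is built. A minimal standalone sketch of the cleanup, using illustrative key names rather than real checkpoint keys:

# Sketch only, not repository code; key names are placeholders.
state_dict = {
    "multiscale_encoder.in_conv.weight": "kept",
    "conv_semantic.0.weight": "training-only (semantic-loss head)",
    "vgg_feat_extractor.mean": "training-only (loss network)",
}
for k in list(state_dict):
    if k.startswith(("conv_semantic.", "vgg_feat_extractor.")):
        del state_dict[k]
print(sorted(state_dict))  # ['multiscale_encoder.in_conv.weight']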


@@ -57,11 +64,9 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FeMaSR]:
act_type = "silu"
use_quantize = True # cannot be deduced from state_dict
# scale_factor = 4
# use_semantic_loss = False
use_residual = True # cannot be deduced from state_dict

in_channel = state_dict["multiscale_encoder.in_conv.weight"].shape[1]
use_semantic_loss = "conv_semantic.0.weight" in state_dict

# gt_resolution can be derived from the decoders
# we assume that gt_resolution is a power of 2
@@ -123,7 +128,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FeMaSR]:
act_type=act_type,
use_quantize=use_quantize,
scale_factor=scale_factor,
use_semantic_loss=use_semantic_loss,
use_residual=use_residual,
)

@@ -141,5 +145,4 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[FeMaSR]:
input_channels=in_channel,
output_channels=in_channel,
size_requirements=SizeRequirements(multiple_of=multiple_of),
call_fn=lambda model, image: model(image)[0],
)
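With forward() now returning only the restored image (see the architecture diff below), the descriptor no longer needs a call_fn to pick the image out of a tuple. A minimal usage sketch, assuming a FeMaSR checkpoint at a hypothetical path and spandrel's ModelLoader; the input just has to satisfy the descriptor's size_requirements:

# Sketch only; "FeMaSR_model.pth" is a hypothetical checkpoint path.
import torch
from spandrel import ModelLoader

descriptor = ModelLoader().load_from_file("FeMaSR_model.pth")
descriptor.model.eval()

lr = torch.rand(1, 3, 64, 64)  # dummy low-resolution batch
with torch.no_grad():
    sr = descriptor.model(lr)  # returns the image directly; previously model(lr)[0]
print(sr.shape)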
File 2 of 3
@@ -1,17 +1,13 @@
from __future__ import annotations

import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn

from spandrel.architectures.SwinIR.arch.SwinIR import RSTB
from spandrel.util import store_hyperparameters

from .fema_utils import CombineQuantBlock, ResBlock
from .vgg_arch import VGGFeatureExtractor


class VectorQuantizer(nn.Module):
@@ -26,12 +22,10 @@ class VectorQuantizer(nn.Module):
_____________________________________________
"""

def __init__(self, n_e, e_dim, beta=0.25, LQ_stage=False):
def __init__(self, n_e, e_dim):
super().__init__()
self.n_e = int(n_e)
self.e_dim = int(e_dim)
self.LQ_stage = LQ_stage
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

@@ -42,17 +36,7 @@ def dist(self, x, y):
- 2 * torch.matmul(x, y.t())
)

def gram_loss(self, x, y):
b, h, w, c = x.shape
x = x.reshape(b, h * w, c)
y = y.reshape(b, h * w, c)

gmx = x.transpose(1, 2) @ x / (h * w)
gmy = y.transpose(1, 2) @ y / (h * w)

return (gmx - gmy).square().mean()

def forward(self, z, gt_indices=None, current_iter=None):
def forward(self, z):
"""
Args:
z: input features to be quantized, z (continuous) -> z_q (discrete)
@@ -74,54 +58,16 @@ def forward(self, z, gt_indices=None, current_iter=None):
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)

if gt_indices is not None:
gt_indices = gt_indices.reshape(-1)

gt_min_indices = gt_indices.reshape_as(min_encoding_indices)
gt_min_onehot = torch.zeros(gt_min_indices.shape[0], codebook.shape[0]).to(
z
)
gt_min_onehot.scatter_(1, gt_min_indices, 1)

z_q_gt = torch.matmul(gt_min_onehot, codebook)
z_q_gt = z_q_gt.view(z.shape)

# get quantized latent vectors
z_q = torch.matmul(min_encodings, codebook)
z_q = z_q.view(z.shape)

e_latent_loss = torch.mean((z_q.detach() - z) ** 2)
q_latent_loss = torch.mean((z_q - z.detach()) ** 2)

if self.LQ_stage and gt_indices is not None:
codebook_loss = self.beta * ((z_q_gt.detach() - z) ** 2).mean() # type: ignore
texture_loss = self.gram_loss(z, z_q_gt.detach()) # type: ignore
codebook_loss = codebook_loss + texture_loss
else:
codebook_loss = q_latent_loss + e_latent_loss * self.beta

# preserve gradients
z_q = z + (z_q - z).detach()

# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()

return (
z_q,
codebook_loss,
min_encoding_indices.reshape(z_q.shape[0], 1, z_q.shape[2], z_q.shape[3]),
)

def get_codebook_entry(self, indices):
b, _, h, w = indices.shape

indices = indices.flatten().to(self.embedding.weight.device)
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:, None], 1)

# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
z_q = z_q.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
return z_q
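After this change, VectorQuantizer.forward keeps only the inference path: find the nearest codebook entry for every spatial feature vector and return the quantized tensor, with the loss terms and GT-index branch gone. A self-contained sketch of that nearest-neighbour lookup (illustrative sizes, not repository code):

# Sketch of the codebook lookup; sizes are illustrative.
import torch

n_e, e_dim = 1024, 256
codebook = torch.randn(n_e, e_dim)
z = torch.randn(2, 8, 8, e_dim)    # (b, h, w, c) features, channels last
z_flat = z.reshape(-1, e_dim)

# squared Euclidean distance, same expansion as dist(): |x|^2 + |y|^2 - 2*x.y
d = (
    z_flat.pow(2).sum(dim=1, keepdim=True)
    + codebook.pow(2).sum(dim=1)
    - 2 * z_flat @ codebook.t()
)
indices = d.argmin(dim=1)           # nearest codebook entry per feature vector
z_q = codebook[indices].view_as(z)  # quantized features, same shape as z
print(z_q.shape, indices.shape)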


@@ -211,11 +157,11 @@ def __init__(

self.LQ_stage = LQ_stage

def forward(self, input):
def forward(self, input: torch.Tensor):
outputs = []
x = self.in_conv(input)

for _idx, m in enumerate(self.blocks):
for m in self.blocks:
x = m(x)
outputs.append(x)

@@ -252,7 +198,6 @@ def __init__(
act_type="silu",
use_quantize=True,
scale_factor=1,
use_semantic_loss=False,
use_residual=True,
):
super().__init__()
@@ -313,7 +258,6 @@ def __init__(
quantize = VectorQuantizer(
codebook_emb_num[scale],
codebook_emb_dim[scale],
LQ_stage=self.LQ_stage,
)
self.quantize_group.append(quantize)

@@ -334,31 +278,13 @@ def __init__(
CombineQuantBlock(comb_quant_in_ch1, comb_quant_in_ch2, scale_in_ch)
)

# semantic loss for HQ pretrain stage
self.use_semantic_loss = use_semantic_loss
if use_semantic_loss:
self.conv_semantic = nn.Sequential(
nn.Conv2d(512, 512, 1, 1, 0),
nn.ReLU(),
)
self.vgg_feat_layer = "relu4_4"
self.vgg_feat_extractor = VGGFeatureExtractor([self.vgg_feat_layer])

def encode_and_decode(self, input, gt_indices=None, current_iter=None):
def encode_and_decode(self, input):
enc_feats = self.multiscale_encoder(input.detach())
if self.LQ_stage:
enc_feats = enc_feats[-3:]
else:
enc_feats = enc_feats[::-1]

if self.use_semantic_loss:
with torch.no_grad():
vgg_feat = self.vgg_feat_extractor(input)[self.vgg_feat_layer]

codebook_loss_list = []
indices_list = []
semantic_loss_list = []

quant_idx = 0
prev_dec_feat = None
prev_quant_feat = None
@@ -372,19 +298,7 @@ def encode_and_decode(self, input, gt_indices=None, current_iter=None):
before_quant_feat = enc_feats[i]
feat_to_quant = self.before_quant_group[quant_idx](before_quant_feat)

if gt_indices is not None:
z_quant, codebook_loss, indices = self.quantize_group[quant_idx](
feat_to_quant, gt_indices[quant_idx]
)
else:
z_quant, codebook_loss, indices = self.quantize_group[quant_idx](
feat_to_quant
)

if self.use_semantic_loss:
semantic_z_quant = self.conv_semantic(z_quant)
semantic_loss = F.mse_loss(semantic_z_quant, vgg_feat) # type: ignore
semantic_loss_list.append(semantic_loss)
z_quant = self.quantize_group[quant_idx](feat_to_quant)

if not self.use_quantize:
z_quant = feat_to_quant
@@ -393,9 +307,6 @@ def encode_and_decode(self, input, gt_indices=None, current_iter=None):
z_quant, prev_quant_feat
)

codebook_loss_list.append(codebook_loss)
indices_list.append(indices)

quant_idx += 1
prev_quant_feat = z_quant
x = after_quant_feat
@@ -410,135 +321,8 @@ def encode_and_decode(self, input, gt_indices=None, current_iter=None):

out_img = self.out_conv(x)

codebook_loss = sum(codebook_loss_list)
semantic_loss = (
sum(semantic_loss_list) if len(semantic_loss_list) else codebook_loss * 0
)

return out_img, codebook_loss, semantic_loss, indices_list

def decode_indices(self, indices):
assert (
len(indices.shape) == 4
), f"shape of indices must be (b, 1, h, w), but got {indices.shape}"

z_quant = self.quantize_group[0].get_codebook_entry(indices)
x = self.after_quant_group[0](z_quant)

for m in self.decoder_group:
x = m(x)
out_img = self.out_conv(x)
return out_img

@torch.no_grad() # type: ignore
def test_tile(self, input, tile_size=240, tile_pad=16):
# return self.test(input)
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
"""
batch, channel, height, width = input.shape
output_height = height * self.scale_factor
output_width = width * self.scale_factor
output_shape = (batch, channel, output_height, output_width)

# start with black image
output = input.new_zeros(output_shape)
tiles_x = math.ceil(width / tile_size)
tiles_y = math.ceil(height / tile_size)

# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * tile_size
ofs_y = y * tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + tile_size, height)

# input tile area on total image with padding
input_start_x_pad = max(input_start_x - tile_pad, 0)
input_end_x_pad = min(input_end_x + tile_pad, width)
input_start_y_pad = max(input_start_y - tile_pad, 0)
input_end_y_pad = min(input_end_y + tile_pad, height)

# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
_tile_idx = y * tiles_x + x + 1
input_tile = input[
:,
:,
input_start_y_pad:input_end_y_pad,
input_start_x_pad:input_end_x_pad,
]

# upscale tile
output_tile = self.test(input_tile)

# output tile area on total image
output_start_x = input_start_x * self.scale_factor
output_end_x = input_end_x * self.scale_factor
output_start_y = input_start_y * self.scale_factor
output_end_y = input_end_y * self.scale_factor

# output tile area without padding
output_start_x_tile = (
input_start_x - input_start_x_pad
) * self.scale_factor
output_end_x_tile = (
output_start_x_tile + input_tile_width * self.scale_factor
)
output_start_y_tile = (
input_start_y - input_start_y_pad
) * self.scale_factor
output_end_y_tile = (
output_start_y_tile + input_tile_height * self.scale_factor
)

# put tile into output image
output[
:, :, output_start_y:output_end_y, output_start_x:output_end_x
] = output_tile[
:,
:,
output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile,
]
return output

@torch.no_grad() # type: ignore
def test(self, input):
org_use_semantic_loss = self.use_semantic_loss
self.use_semantic_loss = False

# padding to multiple of window_size * 8
wsz = 8 // self.scale_factor * 8
_, _, h_old, w_old = input.shape
h_pad = (h_old // wsz + 1) * wsz - h_old
w_pad = (w_old // wsz + 1) * wsz - w_old
input = torch.cat([input, torch.flip(input, [2])], 2)[:, :, : h_old + h_pad, :]
input = torch.cat([input, torch.flip(input, [3])], 3)[:, :, :, : w_old + w_pad]

dec, _, _, _ = self.encode_and_decode(input)

output = dec
output = output[..., : h_old * self.scale_factor, : w_old * self.scale_factor]

self.use_semantic_loss = org_use_semantic_loss
return output

def forward(self, input, gt_indices=None):
if gt_indices is not None:
# in LQ training stage, need to pass GT indices for supervise.
dec, codebook_loss, semantic_loss, indices = self.encode_and_decode( # type: ignore
input, gt_indices
)
else:
# in HQ stage, or LQ test stage, no GT indices needed.
dec, codebook_loss, semantic_loss, indices = self.encode_and_decode(input) # type: ignore

return dec, codebook_loss, semantic_loss, indices
def forward(self, input):
# in HQ stage, or LQ test stage, no GT indices needed.
return self.encode_and_decode(input)
File 3 of 3 (diff not shown)
