diff --git a/.gitignore b/.gitignore index 3cb6fca..45df388 100755 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ experiments/* # results/* tb_logger*/* wandb/* -tmp/* +tmp*/* *.sh .vscode* .github diff --git a/README.md b/README.md index 7a01fe9..ad5a03a 100755 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -# QuanTexSR +# FeMaSR This is the official PyTorch codes for the paper -[Blind Image Super Resolution with Semantic-Aware Quantized Texture Prior](https://arxiv.org/abs/2202.13142) +[Real-World Blind Super-Resolution via Feature Matching with Implicit High-Resolution Priors](https://arxiv.org/abs/2202.13142) [Chaofeng Chen\*](https://chaofengc.github.io), [Xinyu Shi\*](https://github.com/Xinyu-Shi), [Yipeng Qin](http://yipengqin.github.io/), [Xiaoming Li](https://csxmli2016.github.io/), [Xiaoguang Han](https://mypage.cuhk.edu.cn/academics/hanxiaoguang/), [Tao Yang](https://github.com/yangxy), [Shihui Guo](http://guoshihui.net/) (\* indicates equal contribution) @@ -9,7 +9,9 @@ This is the official PyTorch codes for the paper ### Update -- **2022.03.02**: Add onedrive download link for pretrained weights. +- **2022.07.02** + - Update codes of the new version FeMaSR + - Please find the old QuanTexSR in the `quantexsr` branch Here are some example results on test images from [BSRGAN](https://github.com/cszn/BSRGAN) and [RealESRGAN](https://github.com/xinntao/Real-ESRGAN). @@ -37,12 +39,12 @@ Here are some example results on test images from [BSRGAN](https://github.com/cs - Other required packages in `requirements.txt` ``` # git clone this repository -git clone https://github.com/chaofengc/QuanTexSR.git -cd QuanTexSR +git clone https://github.com/chaofengc/FeMaSR.git +cd FeMaSR # create new anaconda env -conda create -n quantexsr python=3.8 -source activate quantexsr +conda create -n femasr python=3.8 +source activate femasr # install python dependencies pip3 install -r requirements.txt @@ -51,13 +53,9 @@ python setup.py develop ## Quick Inference -Download pretrained model (**only provide x4 model now**) from -- [BaiduNetdisk](https://pan.baidu.com/s/1H_9TIJUHEgAe75VToknbIA ), extract code `qtsr` . -- [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/chaofeng_chen_staff_main_ntu_edu_sg/EuqbHtP9-f9OjzLpyIftKH0Bp8WVlT-8FNX6-boTeqE47w) - -Test the model with the following script ``` -python inference_quantexsr.py -w ./path/to/model/weight -i ./path/to/test/image[or folder] +python inference_quantexsr.py -s 4 -i ./testset -o results_x4/ +python inference_quantexsr.py -s 2 -i ./testset -o results_x2/ ``` ## Train the model @@ -70,36 +68,23 @@ Please prepare the training and testing data following descriptions in the main paper. #### Model preparation -Before training, you need to put the following pretrained models in `experiments/pretrained_models` and specify their path in the corresponding option file. - -- HQ pretrain stage: pretrained semantic cluster codebook -- LQ stage (SR model training): pretrained semantic aware vqgan, pretrained PSNR oriented RRDB model -- lpips weight for validation - -The above models can be downloaded from the BaiduNetDisk. +Before training, you need to +- Download the pretrained HRP model [here]() +- Put the pretrained models in `experiments/pretrained_models` +- Specify their path in the corresponding option file.
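To make the last step concrete, here is a minimal sketch (not the shipped option file) of how the `path` block in `options/train_FeMaSR_LQ_stage.yml` could point at the downloaded HRP weights; the filename below is illustrative only, so use whatever name the downloaded checkpoint actually has:

```
# illustrative snippet for options/train_FeMaSR_LQ_stage.yml
path:
  # downloaded HRP (HQ-stage) weights placed under experiments/pretrained_models
  pretrain_network_hq: ./experiments/pretrained_models/FeMaSR_HRP_net_g.pth  # hypothetical filename
  pretrain_network_g: ~
  pretrain_network_d: ~
  strict_load: false
```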
### Train SR model ``` -python basicsr/train.py -opt options/train_QuanTexSR_LQ_stage.yml +python basicsr/train.py -opt options/train_FeMaSR_LQ_stage.yml ``` ### Model pretrain -In case you want to pretrain your own VQGAN prior, we also provide the training instructions below. - -#### Pretrain semantic codebook - -The semantic-aware codebook is obtained with VGG19 features using a mini-batch version of K-means, optimized with Adam. This script will give three levels of codebooks from `relu3_4`, `relu4_4` and `relu5_4` features. We use `relu4_4` for this project. - -``` -python basicsr/train.py -opt options/train_QuanTexSR_semantic_cluster_stage.yml -``` - -#### Pretrain of semantic-aware VQGAN +In case you want to pretrain your own HRP model, we also provide the training option file: ``` -python basicsr/train.py -opt options/train_QuanTexSR_HQ_pretrain_stage.yml +python basicsr/train.py -opt options/train_FeMaSR_HQ_pretrain_stage.yml ``` ## Citation diff --git a/basicsr/archs/femasr_arch.py b/basicsr/archs/femasr_arch.py index 50104d3..163f39b 100644 --- a/basicsr/archs/femasr_arch.py +++ b/basicsr/archs/femasr_arch.py @@ -87,6 +87,8 @@ def forward(self, z, gt_indices=None, current_iter=None): q_latent_loss = torch.mean((z_q - z.detach())**2) if self.LQ_stage and gt_indices is not None: + # codebook_loss = self.dist(z_q, z_q_gt.detach()).mean() \ + # + self.beta * self.dist(z_q_gt.detach(), z) codebook_loss = self.beta * self.dist(z_q_gt.detach(), z) texture_loss = self.gram_loss(z, z_q_gt.detach()) codebook_loss = codebook_loss + texture_loss diff --git a/basicsr/data/bsrgan_train_dataset.py b/basicsr/data/bsrgan_train_dataset.py index 0b33bc4..e8bcbca 100755 --- a/basicsr/data/bsrgan_train_dataset.py +++ b/basicsr/data/bsrgan_train_dataset.py @@ -6,32 +6,11 @@ from basicsr.utils import FileClient, img2tensor from basicsr.utils.registry import DATASET_REGISTRY -import os +from .data_util import make_dataset + import cv2 import random -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', - '.tif', '.TIF', '.tiff', '.TIFF', -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset(dir, max_dataset_size=float("inf")): - images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - return images[:min(max_dataset_size, len(images))] - def random_resize(img, scale_factor=1.): return cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC) diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py index 328c3cb..091f4e7 100755 --- a/basicsr/data/data_util.py +++ b/basicsr/data/data_util.py @@ -1,6 +1,7 @@ import cv2 import numpy as np import torch +import os from os import path as osp from torch.nn import functional as F @@ -8,6 +9,29 @@ from basicsr.utils import img2tensor, scandir +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)): + for 
fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): """Read a sequence of images from a given folder path. diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 0d76238..f65edd6 100755 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -8,34 +8,12 @@ from basicsr.utils import FileClient, img2tensor from basicsr.utils.registry import DATASET_REGISTRY - -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', - '.tif', '.TIF', '.tiff', '.TIFF', -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset(dir, max_dataset_size=float("inf")): - images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - return images[:min(max_dataset_size, len(images))] +from .data_util import make_dataset def random_resize(img, scale_factor=1.): return cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC) - @DATASET_REGISTRY.register() class PairedImageDataset(data.Dataset): """Paired image dataset for image restoration. diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index aa8c821..f58b5bf 100755 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -206,7 +206,7 @@ def update_learning_rate(self, current_iter, warmup_iter=-1): self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): - return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + return [optim.param_groups[0]['lr'] for optim in self.optimizers] @master_only def save_network(self, net, net_label, current_iter, param_key='params'): diff --git a/basicsr/models/femasr_model.py b/basicsr/models/femasr_model.py index 348282c..5f5976f 100755 --- a/basicsr/models/femasr_model.py +++ b/basicsr/models/femasr_model.py @@ -77,7 +77,6 @@ def init_training_settings(self): # define network net_d self.net_d = build_network(self.opt['network_d']) self.net_d = self.model_to_device(self.net_d) - # self.print_network(self.net_d) # load pretrained d models load_path = self.opt['path'].get('pretrain_network_d', None) # print(load_path) @@ -118,23 +117,16 @@ def setup_optimizers(self): logger = get_root_logger() logger.warning(f'Params {k} will not be optimized.') + # optimizer g optim_type = train_opt['optim_g'].pop('type') - if optim_type == 'Adam': - self.optimizer_g = torch.optim.Adam(optim_params, - **train_opt['optim_g']) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') + optim_class = getattr(torch.optim, optim_type) + self.optimizer_g = optim_class(optim_params, **train_opt['optim_g']) self.optimizers.append(self.optimizer_g) # optimizer d optim_type = train_opt['optim_d'].pop('type') - if optim_type == 'Adam': - self.optimizer_d = torch.optim.Adam(self.net_d.parameters(), - **train_opt['optim_d']) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') + optim_class = getattr(torch.optim, optim_type) + self.optimizer_d = optim_class(self.net_d.parameters(), **train_opt['optim_d']) self.optimizers.append(self.optimizer_d) def 
feed_data(self, data): diff --git a/framework_overview.png b/framework_overview.png index 2fbd180..1de9947 100644 Binary files a/framework_overview.png and b/framework_overview.png differ diff --git a/generate_dataset.py b/generate_dataset.py new file mode 100755 index 0000000..628975e --- /dev/null +++ b/generate_dataset.py @@ -0,0 +1,68 @@ +import os +import cv2 +import numpy as np +import random +from tqdm import tqdm +from multiprocessing import Pool + +from basicsr.data.bsrgan_util import degradation_bsrgan + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + +def degrade_img(hr_path, save_path): + img_gt = cv2.imread(hr_path).astype(np.float32) / 255. + img_gt = img_gt[:, :, [2, 1, 0]] # BGR to RGB + img_lq, img_gt = degradation_bsrgan(img_gt, sf=scale, use_crop=False) + img_lq = (img_lq[:, :, [2, 1, 0]] * 255).astype(np.uint8) + print(f'Save {save_path}') + cv2.imwrite(save_path, img_lq) + + +seed = 123 +random.seed(seed) +np.random.seed(seed) + +# scale = 2 +scale = 4 +hr_img_list = make_dataset('../datasets/HQ_sub') +pool = Pool(processes=40) + +# hr_img_list = ['../datasets/HQ_sub_samename/DIV8K_train_HR_sub/div8k_1383_s021.png'] + +# scale = 2 +# hr_img_list = ['../datasets/HQ_sub_samename/DIV8K_train_HR_sub/div8k_0903_s056.png'] + +# scale = 4 +# hr_img_list = make_dataset('../datasets/LQ_sub_samename_X4') + +for hr_path in hr_img_list: + save_path = hr_path.replace('HQ_sub', f'LQ_sub_X{scale}') + save_path = save_path.replace('HR', 'LR') + save_dir = os.path.dirname(save_path) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + pool.apply_async(degrade_img(hr_path, save_path)) + +pool.close() +pool.join() + diff --git a/options/train_FeMaSR_HQ_pretrain_stage.yml b/options/train_FeMaSR_HQ_pretrain_stage.yml index 9a4eed9..23e5b09 100755 --- a/options/train_FeMaSR_HQ_pretrain_stage.yml +++ b/options/train_FeMaSR_HQ_pretrain_stage.yml @@ -1,5 +1,5 @@ # general settings -name: 004_FeMaSR_HQ_stage +name: 008_FeMaSR_HQ_stage # name: debug_FeMaSR model_type: FeMaSRModel scale: 4 @@ -22,11 +22,12 @@ datasets: # data loader use_shuffle: true - num_worker_per_gpu: 2 - batch_size_per_gpu: 12 + batch_size_per_gpu: &bsz 12 + num_worker_per_gpu: *bsz dataset_enlarge_ratio: 1 - prefetch_mode: ~ + prefetch_mode: cpu + num_prefetch_queue: *bsz val: name: General_Image_Valid @@ -58,8 +59,8 @@ network_d: path: # pretrain_network_g: ./experiments/pretrained_models/QuanTexSR/pretrain_semantic_vqgan_net_g_latest.pth # pretrain_network_d: ~ - pretrain_network_g: ./experiments/003_FeMaSR_HQ_stage/models/net_g_best_.pth - pretrain_network_d: ./experiments/003_FeMaSR_HQ_stage/models/net_d_best_.pth + pretrain_network_g: ./experiments/004_FeMaSR_HQ_stage/models/net_g_best_.pth + # pretrain_network_d: ./experiments/004_FeMaSR_HQ_stage/models/net_d_best_.pth strict_load: false # resume_state: ~ @@ -117,7 +118,7 @@ val: key_metric: lpips metrics: - psnr: # metric name, can be arbitrary + psnr: # 
metric name, not used in this codebase type: psnr crop_border: 4 test_y_channel: true diff --git a/options/train_FeMaSR_LQ_stage.yml b/options/train_FeMaSR_LQ_stage.yml index 22b278f..45e1c5a 100755 --- a/options/train_FeMaSR_LQ_stage.yml +++ b/options/train_FeMaSR_LQ_stage.yml @@ -1,5 +1,5 @@ # general settings -name: 003_FeMaSR_LQ_stage +name: 014_FeMaSR_LQ_stage # name: debug_FeMaSR model_type: FeMaSRModel scale: &upscale 4 @@ -12,6 +12,9 @@ datasets: name: General_Image_Train type: BSRGANTrainDataset dataroot_gt: ../datasets/HQ_sub + # type: PairedImageDataset + # dataroot_gt: ../datasets/HQ_sub + # dataroot_lq: ../datasets/LQ_sub_X4 io_backend: type: disk @@ -57,9 +60,9 @@ network_d: # path path: - pretrain_network_hq: ./experiments/pretrained_models/QuanTexSR/pretrain_semantic_vqgan_net_g_latest.pth - pretrain_network_g: ./experiments/0230_QSR_Semantic_RRDB_LQ_stage_X4/models/net_g_best_.pth - pretrain_network_d: ./experiments/0230_QSR_Semantic_RRDB_LQ_stage_X4/models/net_d_best_.pth + pretrain_network_hq: ./experiments/008_FeMaSR_HQ_stage/models/net_g_best_.pth + pretrain_network_g: ~ + pretrain_network_d: ./experiments/008_FeMaSR_HQ_stage/models/net_d_best_.pth strict_load: false # resume_state: ~ @@ -72,7 +75,7 @@ train: betas: [0.9, 0.99] optim_d: type: Adam - lr: !!float 1e-4 + lr: !!float 4e-4 weight_decay: 0 betas: [0.9, 0.99] diff --git a/vis_codebook.py b/vis_codebook.py new file mode 100644 index 0000000..9b439a5 --- /dev/null +++ b/vis_codebook.py @@ -0,0 +1,98 @@ +from itertools import count +from tokenize import PlainToken +import torch +import torchvision.transforms as tf +from torchvision.utils import save_image + +import numpy as np +import os +import random +from tqdm import tqdm +import cv2 +from matplotlib import pyplot as plt +import seaborn as sns + +from basicsr.utils.misc import set_random_seed +from basicsr.utils import img2tensor, tensor2img, imwrite +from basicsr.archs.femasr_arch import FeMaSRNet + + +def reconstruct_ost(model, data_dir, save_dir, maxnum=100): + + texture_classes = list(os.listdir(data_dir)) + texture_classes.remove('manga109') + code_idx_dict = {} + for tc in texture_classes: + img_name_list = os.listdir(os.path.join(data_dir, tc)) + random.shuffle(img_name_list) + tmp_code_idx_list = [] + for img_name in tqdm(img_name_list[:maxnum]): + img_path = os.path.join(data_dir, tc, img_name) + + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + img_tensor = img2tensor(img).to(device) / 255. 
+ img_tensor = img_tensor.unsqueeze(0) + + rec, _, _, indices = model(img_tensor) + indices = indices[0] + + save_path = os.path.join(save_dir, tc, img_name) + if not os.path.exists(os.path.join(save_dir, tc)): + os.makedirs(os.path.join(save_dir, tc), exist_ok=True) + imwrite(tensor2img(rec), save_path) + + save_org_dir = save_dir.replace('rec', 'org') + save_org_path = os.path.join(save_org_dir, tc, img_name) + if not os.path.exists(os.path.join(save_org_dir, tc)): + os.makedirs(os.path.join(save_org_dir, tc), exist_ok=True) + imwrite(tensor2img(img_tensor), save_org_path) + + tmp_code_idx_list.append(indices) + code_idx_dict[tc] = tmp_code_idx_list + + torch.save(code_idx_dict, './tmp_code_vis/code_idx_dict.pth') + + +def vis_hrp(model, code_list_path, save_dir, samples_each_class=16): + code_idx_dict = torch.load(code_list_path) + classes = list(code_idx_dict.keys()) + + latent_size = 8 + color_palette = sns.color_palette() + for idx, (key, value) in enumerate(code_idx_dict.items()): + all_idx = torch.cat([x.flatten() for x in value]) + + plt.figure(figsize=(16, 8)) + sns.histplot(all_idx.cpu().numpy(), color=color_palette[idx]) + plt.xlabel(key, fontsize=30) + plt.ylabel('Count', fontsize=30) + plt.savefig(f'./tmp_code_vis/code_stat/code_index_bincount_{key}.pdf') + + counts = all_idx.bincount() + dist = counts / sum(counts) + dist = dist.cpu().numpy() + + vis_tex_samples = [] + for sid in range(32): + vis_tex_map = np.random.choice(np.arange(dist.shape[0]), latent_size ** 2, p=dist) + vis_tex_map = torch.from_numpy(vis_tex_map).to(all_idx) + vis_tex_map = vis_tex_map.reshape(1, 1, latent_size, latent_size) + vis_tex_img = model.decode_indices(vis_tex_map) + vis_tex_samples.append(vis_tex_img) + vis_tex_samples = torch.cat(vis_tex_samples, dim=0) + save_image(vis_tex_samples, f'./tmp_code_vis/tmp_tex_vis/{key}.jpg', normalize=True, nrow=16) + +if __name__ == '__main__': + # set random seeds to ensure reproducibility + set_random_seed(123) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # set up the model + weight_path = './experiments/pretrained_models/QuanTexSR/pretrain_semantic_vqgan_net_g_latest.pth' + vqgan = FeMaSRNet(codebook_params=[[32, 1024, 512]], LQ_stage=False).to(device) + vqgan.load_state_dict(torch.load(weight_path)['params'], strict=False) + vqgan.eval() + + reconstruct_ost(vqgan, '../datasets/SR_OST_datasets/OutdoorSceneTrain_v2/', './tmp_code_vis/ost_rec', maxnum=1000) + vis_hrp(vqgan, './tmp_code_vis/code_idx_dict.pth', './tmp_code_vis/') diff --git a/visual_codebook.py b/visual_codebook.py new file mode 100755 index 0000000..2dd4556 --- /dev/null +++ b/visual_codebook.py @@ -0,0 +1,315 @@ +from PIL import Image + +import torch +import torchvision.transforms as tf +from torchvision.utils import save_image +from vqgan_vis_arch import MultiScaleVQVAESemanticHQ + +import numpy as np +import os +import random +from tqdm import tqdm +from kmeans_pytorch import kmeans +import cv2 + +# set random seeds to ensure reproducibility +seed = 123 +np.random.seed(seed) +random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) + + +color_table = np.array([ + [153, 153, 153], # 0, background + [0, 255, 255], # 1, sky + [109, 158, 235], # 2, water + [183, 225, 205], # 3, grass + [153, 0, 255], # 4, mountain + [17, 85, 204], # 5, building + [106, 168, 79], # 6, plant + [224, 102, 102], # 7, animal + [255, 255, 255], # 8/255, void + [153, 153, 153], # 0, background + [0, 255, 255], # 1, sky + # [109, 158, 235], # 2, water + # [183, 225, 
205], # 3, grass + # [153, 0, 255], # 4, mountain + # [17, 85, 204], # 5, building + # [106, 168, 79], # 6, plant + # [224, 102, 102], # 7, animal + # [153, 153, 153], # 0, background + # [0, 255, 255], # 1, sky + # [109, 158, 235], # 2, water + # [183, 225, 205], # 3, grass + # [153, 0, 255], # 4, mountain + # [17, 85, 204], # 5, building + # [106, 168, 79], # 6, plant + # [224, 102, 102], # 7, animal + ]) + + +def index_to_color(index_map): + """shape: (H, W)""" + color_map = np.zeros((index_map.shape[0], index_map.shape[1], 3)) + for i in range(color_table.shape[0]): + color_map[index_map == i] = color_table[i] + return color_map + + +def cluster_codebook(model, num_clusters, codebook=None): + if codebook is None: + codebook = model.quantize_group[0].embedding.weight + else: + codebook = codebook.squeeze() + cluster_ids_x, cluster_centers = kmeans(X=codebook, num_clusters=num_clusters, distance='euclidean', iter_limit=500, device=torch.device('cuda')) + # cluster_ids_x, cluster_centers = kmeans(X=codebook, num_clusters=num_clusters, distance='cosine', iter_limit=500, device=torch.device('cuda')) + return cluster_ids_x, cluster_centers + +def vis_single_code(model, code_idx, up_factor=1): + input_tensor = torch.randn(code_idx.shape[0], 3, 16*up_factor, 16*up_factor).cuda() + code_idx = code_idx.repeat_interleave(up_factor**2) + outputs, _, _, _ = model(input_tensor, gt_indices=[code_idx.cuda()]) + output_img = outputs[-1] + return output_img.clamp(0, 1) + +def vis_tiled_single_code(model, code_idx): + input_tensor = torch.randn(code_idx.shape[0], 3, 256, 256).cuda() + outputs, _, _, _ = model(input_tensor, gt_indices=[code_idx.repeat(16*16).cuda()]) + output_img = outputs[-1] + return output_img + +def vis_cluster_code(cluster_ids_x, model): + num_clusters = torch.max(cluster_ids_x) + new_order = [] + for i in range(num_clusters + 1): + tmp_idx = torch.nonzero(cluster_ids_x == i) + new_order.append(tmp_idx) + print(f'cluster {i}, shape', tmp_idx.shape) + new_order = torch.cat(new_order) + vis_img = vis_single_code(model, new_order) + return new_order, vis_img + +def get_useful_code(cluster_ids_x, model): + num_clusters = torch.max(cluster_ids_x) + new_order = [] + for i in range(num_clusters + 1): + # if not i in [24]: + if i in [6]: + tmp_idx = torch.nonzero(cluster_ids_x == i) + new_order.append(tmp_idx) + print(f'cluster {i}, shape', tmp_idx.shape) + new_order = torch.cat(new_order) + codebook = model.quantize_group[0].embedding.weight + new_codebook = codebook[new_order] + torch.save(new_codebook, 'useful_code_semantic.pth') + print(new_codebook.shape) + +def vis_random_codes(code_num, sample_num, model): + output_imgs = [] + all_idx = torch.arange(16) + for i in range(sample_num): + selected_idx = all_idx[torch.randint(all_idx.shape[0], (code_num,))] + # tmp_idx = torch.randperm(16) + # selected_idx = all_idx[tmp_idx[:code_num]] + print(code_num, i, selected_idx.squeeze().numpy()) + sampled_code_idx = selected_idx[torch.randint(selected_idx.shape[0], (16*16,))] + input_tensor = torch.randn(1, 3, 256, 256).cuda() + outputs, _, _, _ = model(input_tensor, gt_indices=[sampled_code_idx.cuda()]) + output_img = outputs[-1] + output_imgs.append(output_img) + output_img = torch.cat(output_imgs, dim=0) + return output_img + + +def vis_seq_codes(code_num, sample_num, model): + output_imgs = [] + all_idx = torch.arange(16) + for i in range(sample_num): + selected_idx = torch.arange(code_num) + sampled_code_idx = selected_idx.repeat(256//(i+1) + 1)[:256] + input_tensor = torch.randn(1, 3, 256, 
256).cuda() + outputs, _, _, _ = model(input_tensor, gt_indices=[sampled_code_idx.cuda()]) + output_img = outputs[-1] + output_imgs.append(output_img) + output_img = torch.cat(output_imgs, dim=0) + return output_img + + +def vis_cluster_samples(cluster_ids_x, cluster_id, model, sample_num=1): + selected_idx = torch.nonzero(cluster_ids_x == cluster_id) + print(cluster_id, selected_idx.squeeze().cpu().numpy()) + output_imgs = [] + for i in range(sample_num): + sampled_code_idx = selected_idx[torch.randint(selected_idx.shape[0], (16*16,))] + input_tensor = torch.randn(1, 3, 256, 256).cuda() + outputs, _, _, _ = model(input_tensor, gt_indices=[sampled_code_idx.cuda()]) + output_img = outputs[-1] + output_imgs.append(output_img) + output_img = torch.cat(output_imgs, dim=0) + return output_img + + +def vis_given_code_nums(code_num, model, sample_num=8): + output_imgs = [] + all_idx = torch.arange(19) + for i in range(sample_num): + if len(code_num) > 0: + selected_idx = code_num + else: + selected_idx = all_idx[torch.randint(all_idx.shape[0], (code_num,))] + print(i, selected_idx) + sampled_code_idx = selected_idx[torch.randint(selected_idx.shape[0], (16*16,))] + input_tensor = torch.randn(1, 3, 256, 256).cuda() + outputs, _, _, _ = model(input_tensor, gt_indices=[sampled_code_idx.cuda()]) + output_img = outputs[-1] + output_imgs.append(output_img) + output_img = torch.cat(output_imgs, dim=0) + return output_img + +def vis_given_code_list(code_list, model, sample_num=8): + output_imgs = [] + for i in range(sample_num): + selected_idx = torch.tensor(code_list) + sampled_code_idx = selected_idx[torch.randint(selected_idx.shape[0], (16*16,))] + input_tensor = torch.randn(1, 3, 256, 256).cuda() + outputs, _, _, _ = model(input_tensor, gt_indices=[sampled_code_idx.cuda()]) + output_img = outputs[-1] + output_imgs.append(output_img) + output_img = torch.cat(output_imgs, dim=0) + return output_img + +def read_img_tensor(img_path): + img = Image.open(img_path) + img_tensor = tf.functional.to_tensor(img) + return img_tensor.unsqueeze(0).cuda() + +viscode = False +semantic = False +# semantic = False +save_suffix = 'semantic' if semantic else 'nosemantic' + +if semantic: + weight_path = './experiments/0001_VQGAN_SemanticGuide_RRDBFuse_HQ_stage/models/net_g_290000.pth' + model = MultiScaleVQVAESemanticHQ(with_semantic=True).cuda() +else: + # weight_path = './experiments/101_1_scale_HQ_stage_1024_codebook_largeDataset/models/net_g_110000.pth' + weight_path = './experiments/0005_VQGAN_MultiscaleNoSemantic_NoAttn_RRDBFuse_HQ_stage/models/net_g_latest.pth' + model = MultiScaleVQVAESemanticHQ(act_type='gelu', codebook_params=[[16, 1024, 512]]).cuda() + +model.load_state_dict(torch.load(weight_path)['params'], strict=True) + +# img = vis_given_code_nums(torch.tensor([0, 8, 15, 5, 6, 7]), model, 8) +# save_image(img, f'../tmp_visdir/vis_test_given_id_{save_suffix}.png', nrow=8) +# exit() +# num_clusters = 10 +# cluster_ids_x, cluster_centers = cluster_codebook(model, num_clusters) +# new_order, vis_img = vis_cluster_code(cluster_ids_x, model) +# save_image(vis_img, f'vis_clustered_codes_{save_suffix}.png', nrow=32) +# for i in range(num_clusters): +# for i in range(1): + # tmp_img = vis_cluster_samples(cluster_ids_x, i, model) + # save_image(tmp_img, f'../tmp_visdir/vis_sample_cluster{i}_{save_suffix}.png') + +save_root = './tmp_vis/' + +# up_factor = 1 +# code_vis = vis_single_code(model, torch.arange(1024), up_factor) +# save_image(code_vis, f'{save_root}/vis_useful_codes_{up_factor}_{save_suffix}.png', 
nrow=32) +# exit() + +# for up_factor in range(1, 9): + # code_vis = vis_single_code(model, torch.arange(19), up_factor) + # code_vis = torch.nn.functional.interpolate(code_vis, (256, 256), mode='bicubic') + # save_image(code_vis, f'../tmp_visdir/vis_useful_codes_{up_factor}_{save_suffix}.png', nrow=8) + +# vis_imgs = [] +# sample_num = 16 +# for i in range(4, 8): + # code_num = i + 1 + # tmp_img = vis_random_codes(code_num, sample_num, model) + # vis_imgs.append(tmp_img) +# save_image(torch.cat(vis_imgs, dim=0), f'{save_root}/vis_random_sample_codes_{save_suffix}.png', nrow=sample_num) + +# semantic_indexes = [ + # [1, 2, 5, 7, 13, 14], + # [2, 8, 7, 10, 12, 13, 14], + # [3, 5, 7, 14], + # [1, 4, 6, 9, 10, 12], + # [1, 3, 4, 6, 13, 15] + # ] +# for sl in semantic_indexes: + # print(' '.join([chr(x+ord('a')) for x in sl])) +# exit() +# sample_num = 8 +# for sidx, idx_list in enumerate(semantic_indexes): + # tmp_img = vis_given_code_list(idx_list, model, sample_num=sample_num) + # save_image(tmp_img, f'{save_root}/vis_semantic_codes_sample_{sidx}.png', nrow=sample_num) + +# vis_imgs = [] +# sample_num = 1 +# for i in range(8): + # tmp_img = vis_seq_codes(i+1, sample_num, model) + # vis_imgs.append(tmp_img) +# save_image(torch.cat(vis_imgs, dim=0), f'../tmp_visdir/vis_seq_tiled_codes_{save_suffix}.png', nrow=sample_num) + +# get_useful_code(cluster_ids_x, model) + +## ================= visualize semantic code center ============== +# num_clusters = 8 +# cluster_ids_x, cluster_centers = cluster_codebook(model, num_clusters, codebook=torch.load('./useful_code_semantic.pth')) +# vis_imgs = [] +# for i in range(num_clusters): + # out_img = vis_cluster_samples(cluster_ids_x, i, model, sample_num=8) + # vis_imgs.append(out_img) +# save_image(torch.cat(vis_imgs, dim=0), f'../tmp_visdir/vis_random_sample_center_{save_suffix}.png', nrow=8) +## ================= visualize semantic code center ============== + +## ================= visualize code index ======================= +# img_root = '../datasets/test_real_datasets_mod16/OutdoorSceneTest300/lrx4/' +# for img_name in tqdm(os.listdir(img_root)): + # img_path = os.path.join(img_root, img_name) + # img_tensor = read_img_tensor(img_path) + # outputs, _, _, _ = model(img_tensor) + # output_img = outputs[-1] + # save_image(output_img, f'../tmp_visdir/rec/vis_vqgan_rec_{img_name}') + +# img_path = '../tmp_visdir/OST_001.png' +# gt_label = cv2.imread(img_path)[:, :, 0] +# color_gt_label = index_to_color(gt_label) +# cv2.imwrite('../tmp_visdir/OST_001_color.png', color_gt_label.astype(np.uint8)) + +# img_path = '../datasets/test_real_datasets_mod16/OutdoorSceneTest300/lrx4/OST_001.png' +# num_clusters = 4 +# cluster_ids_x, cluster_centers = cluster_codebook(model, num_clusters, codebook=torch.load('./useful_code_semantic.pth')) + +# img_root = '../datasets/test_real_datasets_mod16/OutdoorSceneTest300/lrx4/' +# for img_name in tqdm(os.listdir(img_root)): + # img_path = os.path.join(img_root, img_name) + # img_tensor = read_img_tensor(img_path) + # outputs, _, _, match_indices = model(img_tensor) + # match_indice = match_indices[0].squeeze() + + # clustered_index = cluster_ids_x[match_indice.view(-1)].reshape_as(match_indice) + # clustered_index = clustered_index.cpu().numpy().astype(np.uint8) + # clustered_index = cv2.resize(clustered_index, None, fx=16, fy=16, interpolation=cv2.INTER_CUBIC) + # colored_clustered_index = index_to_color(clustered_index.astype(np.int)) + # cv2.imwrite(f'../tmp_visdir/rec_index_maps/{img_name}', 
colored_clustered_index.astype(np.uint8)) +## ================= visualize code index ======================= +# exit() + +img_path = '../datasets/test_datasets/Set5/gt_mod16/baby.png' +img_tensor = read_img_tensor(img_path) +if viscode: + img_tensor = torch.nn.functional.interpolate(img_tensor, (16, 16)) +else: + img_tensor = torch.nn.functional.interpolate(img_tensor, (256, 256)) + +with torch.no_grad(): + # run the model first so that output_img is defined for both branches below + outputs, _, _, _ = model(img_tensor) + output_img = outputs[-1] + if viscode: + save_image(output_img, f'tmp_visdir/vis_codebook_{save_suffix}.png', nrow=32) + else: + # save_image(output_img, f'vis_vqgan_rec_fullcode_{save_suffix}.png') + save_image(output_img, f'./tmp_vis/vis_vqgan_rec_usefulcode_{save_suffix}.png')
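One note on the new `generate_dataset.py` above: `pool.apply_async(degrade_img(hr_path, save_path))` calls `degrade_img` immediately in the parent process and hands its `None` return value to the pool, so the degradation never actually runs in the worker processes. A minimal sketch of the usual `apply_async` pattern, with the worker body and paths purely illustrative:

```python
import os
from multiprocessing import Pool


def degrade_img(hr_path, save_path):
    # stand-in for the BSRGAN degradation performed in generate_dataset.py
    print(f'degrade {hr_path} -> {save_path}')


if __name__ == '__main__':
    scale = 4
    # illustrative list; generate_dataset.py builds this with make_dataset('../datasets/HQ_sub')
    hr_img_list = ['../datasets/HQ_sub/example_HR.png']
    with Pool(processes=4) as pool:
        jobs = []
        for hr_path in hr_img_list:
            save_path = hr_path.replace('HQ_sub', f'LQ_sub_X{scale}').replace('HR', 'LR')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            # pass the callable and its arguments; do not call it here
            jobs.append(pool.apply_async(degrade_img, args=(hr_path, save_path)))
        pool.close()
        pool.join()
```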