diff --git a/README.md b/README.md index 2ee57eab..d7ed9147 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar - [RGT](https://github.com/zhengchen1999/RGT) | [RGT Models](https://drive.google.com/drive/folders/1zxrr31Kp2D_N9a-OUAPaJEn_yTaSXTfZ?usp=drive_link), [RGT-S Models](https://drive.google.com/drive/folders/1j46WHs1Gvyif1SsZXKy1Y1IrQH0gfIQ1?usp=drive_link) - [DCTLSA](https://github.com/zengkun301/DCTLSA) | [Models](https://github.com/zengkun301/DCTLSA/tree/main/pretrained) - [ATD](https://github.com/LabShuHangGU/Adaptive-Token-Dictionary) | [Models](https://drive.google.com/drive/folders/1D3BvTS1xBcaU1mp50k3pBzUWb7qjRvmB?usp=sharing) +- [AdaCode](https://github.com/kechunl/AdaCode) | [Models](https://github.com/kechunl/AdaCode/releases/tag/v0-pretrain_models) #### Face Restoration diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py index bd121e52..2192a922 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py @@ -2,6 +2,7 @@ from .architectures import ( MAT, + AdaCode, CodeFormer, DDColor, FeMaSR, @@ -17,6 +18,7 @@ ArchSupport.from_architecture(CodeFormer.CodeFormerArch()), ArchSupport.from_architecture(MAT.MATArch()), ArchSupport.from_architecture(DDColor.DDColorArch()), + ArchSupport.from_architecture(AdaCode.AdaCodeArch()), ArchSupport.from_architecture(FeMaSR.FeMaSRArch()), ArchSupport.from_architecture(M3SNet.M3SNetArch()), ArchSupport.from_architecture(Restormer.RestormerArch()), diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py new file mode 100644 index 00000000..49a4c89c --- /dev/null +++ 
# Patch target (new file, hunk @@ -0,0 +1,150 @@):
# b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/__init__.py
from __future__ import annotations

from typing_extensions import override

from spandrel import (
    Architecture,
    ImageModelDescriptor,
    SizeRequirements,
    StateDict,
)
from spandrel.util import KeyCondition, get_first_seq_index, get_seq_len

from .arch.adacode_contrast_arch import AdaCodeSRNet_Contrast as AdaCode

# Inverse of the network's `channel_query_dict` (resolution -> channel count):
# maps a channel count back to every resolution that uses it.
_inv_channel_query_dict = {
    256: [8, 16, 32, 64],
    128: [128],
    64: [256],
    32: [512],
}

# Prefixes of checkpoint keys that the network built by `AdaCodeArch.load` does
# not define.
_UNUSED_KEY_PREFIXES = (
    "sft_fusion_group.",
    "multiscale_encoder.upsampler.",
    "conv_semantic.",
    "vgg_feat_extractor.",
    "weight_cri.",
    "contrast_cri.",
)


def _clean_state_dict(state_dict: StateDict):
    """Delete, in place, every entry whose key starts with an unused prefix.

    Pretrained checkpoints carry keys (presumably training-time modules such as
    loss criteria and feature extractors -- TODO confirm) that have no
    counterpart in the module built here. Left in place, they make
    `model.load_state_dict(state_dict)` fail, so they are removed up front.
    """
    # Snapshot the keys first: the dict is mutated while we walk over it.
    for key in list(state_dict.keys()):
        if key.startswith(_UNUSED_KEY_PREFIXES):
            del state_dict[key]


class AdaCodeArch(Architecture[AdaCode]):
    """Detects AdaCode checkpoints and rebuilds the network from their tensors."""

    def __init__(self) -> None:
        super().__init__(
            id="AdaCode",
            detect=KeyCondition.has_all(
                "multiscale_encoder.in_conv.weight",
                "multiscale_encoder.blocks.0.0.weight",
                "decoder_group.0.block.1.weight",
                "out_conv.weight",
                "weight_predictor.blocks.0.swin_blks.0.residual_group.blocks.0.norm1.weight",
                "before_quant_group.0.weight",
            ),
        )

    @override
    def load(self, state_dict: StateDict) -> ImageModelDescriptor[AdaCode]:
        """Infer the hyperparameters from `state_dict` and build the model.

        `state_dict` is modified in place: keys the network does not define are
        removed before anything else happens.

        Raises:
            KeyError: if an expected tensor is missing or a channel count is
                not one of the values in `_inv_channel_query_dict`.
        """
        _clean_state_dict(state_dict)

        # Defaults of the network, for reference:
        # in_channel = 3
        # codebook_params: list[list[int]] = [[32, 256, 256]]
        # gt_resolution = 256
        # LQ_stage = False
        # norm_type = "gn"
        act_type = "silu"
        use_quantize = True  # cannot be deduced from state_dict
        # scale_factor = 1
        use_residual = True  # cannot be deduced from state_dict

        in_channel = state_dict["multiscale_encoder.in_conv.weight"].shape[1]
        # The final conv of the network has a fixed number of output channels
        # that does not depend on `in_channel`, so report what the checkpoint
        # actually contains instead of assuming input == output.
        out_channel = state_dict["out_conv.weight"].shape[0]

        # gt_resolution can be derived from the decoders
        # we assume that gt_resolution is a power of 2
        max_depth = get_seq_len(state_dict, "decoder_group")
        # in the last decoder iteration, we essentially have:
        #   out_ch = channel_query_dict[gt_resolution]
        out_ch = state_dict[f"decoder_group.{max_depth - 1}.block.1.weight"].shape[0]
        gt_resolution_candidates = _inv_channel_query_dict[out_ch]
        # just choose the largest one
        gt_resolution = gt_resolution_candidates[-1]

        # the codebook is complex to reconstruct
        cb_height = get_seq_len(state_dict, "quantize_group")
        codebook_params = []
        for i in range(cb_height):
            emb_num, emb_dim = state_dict[f"quantize_group.{i}.embedding.weight"].shape

            scale_in_ch = state_dict[f"after_quant_group.{i}.weight"].shape[0]
            candidates = _inv_channel_query_dict[scale_in_ch]
            # we just need *a* scale, so we can pick the first one
            codebook_params.append([candidates[0], emb_num, emb_dim])

        # max_depth = int(log2(gt_resolution // scale_0))
        # We assume that gt_resolution and scale_0 are powers of 2, so we can calculate
        # them directly
        scale_0 = gt_resolution // (2**max_depth)
        codebook_params[0][0] = scale_0

        # scale factor
        swin_block_index = get_first_seq_index(
            state_dict,
            "multiscale_encoder.blocks.{}.swin_blks.0.residual_group.blocks.0.attn.relative_position_bias_table",
        )
        if swin_block_index >= 0:
            LQ_stage = True  # noqa: N806
            # encode_depth = int(log2(gt_resolution // scale_factor // scale_0))
            encode_depth = swin_block_index
            scale_factor = gt_resolution // (2**encode_depth * scale_0)
        else:
            LQ_stage = False  # noqa: N806
            scale_factor = 1

        if "decoder_group.0.block.2.conv.0.norm.running_mean" in state_dict:
            norm_type = "bn"
        elif "decoder_group.0.block.2.conv.0.norm.weight" in state_dict:
            norm_type = "gn"
        else:
            # we cannot differentiate between "none" and "in"
            norm_type = "in"

        model = AdaCode(
            in_channel=in_channel,
            codebook_params=codebook_params,
            gt_resolution=gt_resolution,
            LQ_stage=LQ_stage,
            norm_type=norm_type,
            act_type=act_type,
            use_quantize=use_quantize,
            scale_factor=scale_factor,
            use_residual=use_residual,
        )

        # Input-size multiple per scale factor; 8 for any scale not listed.
        multiple_of = {2: 32, 4: 16}.get(scale_factor, 8)

        return ImageModelDescriptor(
            model,
            state_dict,
            architecture=self,
            purpose="Restoration" if scale_factor == 1 else "SR",
            tags=[],
            supports_half=True,  # TODO
            supports_bfloat16=True,  # TODO
            scale=scale_factor,
            input_channels=in_channel,
            output_channels=out_channel,
            size_requirements=SizeRequirements(multiple_of=multiple_of),
        )


# ---------------------------------------------------------------------------
# Remainder of the original patch text that shared these lines (the start of
# the bundled arch/LICENSE file), carried over verbatim:
#
# diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/arch/LICENSE b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/arch/LICENSE
# new file mode 100644
# index 00000000..93047ee3
# --- /dev/null
# +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/arch/LICENSE
# @@ -0,0 +1,437 @@
# +Attribution-NonCommercial-ShareAlike 4.0 International
# +
# +=======================================================================
# +
# +Creative Commons Corporation ("Creative Commons") is not a law firm and
# +does not provide legal services or legal advice. Distribution of
# +Creative Commons public licenses does not create a lawyer-client or
# +other relationship. Creative Commons makes its licenses and related
# +information available on an "as-is" basis. Creative Commons gives no
# +warranties regarding its licenses, any material licensed under their
# +terms and conditions, or any related information. Creative Commons
# +disclaims all liability for damages resulting from their use to the
# +fullest extent possible.
# ---------------------------------------------------------------------------
+ +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. 
+ Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. 
Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. 
For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. 
Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. 
Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. 
You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. 
You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. 
However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. 
No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
# Patch target (new file, index 00000000..cddffaef, hunk @@ -0,0 +1,244 @@):
# b/libs/spandrel_extra_arches/spandrel_extra_arches/architectures/AdaCode/arch/adacode_contrast_arch.py
from __future__ import annotations

import numpy as np
import torch
from torch import nn as nn

from spandrel.util import store_hyperparameters

from ...FeMaSR.arch.femasr import DecoderBlock, MultiScaleEncoder, SwinLayers


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    _____________________________________________

    Only the quantized tensor is produced here; no loss term is computed.
    """

    def __init__(self, n_e, e_dim):
        super().__init__()
        self.n_e = int(n_e)
        self.e_dim = int(e_dim)
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def dist(self, x, y):
        """Return the pairwise squared Euclidean distances between rows of x and y.

        Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x . y, which
        yields a (rows of x, rows of y) matrix.
        """
        return (
            torch.sum(x**2, dim=1, keepdim=True)
            + torch.sum(y**2, dim=1)
            - 2 * torch.matmul(x, y.t())
        )

    def forward(self, z: torch.Tensor):
        """
        Args:
            z: input features to be quantized, z (continuous) -> z_q (discrete)
               z.shape = (batch, channel, height, width)

        Returns:
            A tensor shaped like the input in which every per-pixel feature
            vector is replaced by its nearest codebook entry.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        codebook = self.embedding.weight

        d = self.dist(z_flattened, codebook)

        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        # One-hot selection matrix, created directly with z's dtype and device
        # rather than with the defaults followed by a `.to(z)` conversion.
        min_encodings = z.new_zeros(min_encoding_indices.shape[0], codebook.shape[0])
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, codebook)
        z_q = z_q.view(z.shape)

        # preserve gradients: the value is z_q, the gradient flows to z
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class WeightPredictor(nn.Module):
    """Swin layers followed by a 1x1 conv that produces `cls` channels.

    When `weight_softmax` is set, a softmax over the channel dimension is
    appended. `swin_opts` is forwarded unchanged to `SwinLayers`.
    """

    def __init__(
        self,
        in_channel: int,
        cls: int,
        weight_softmax=False,
        **swin_opts,
    ):
        super().__init__()

        self.blocks = nn.ModuleList()
        self.blocks.append(SwinLayers(**swin_opts))
        # weight
        self.blocks.append(nn.Conv2d(in_channel, cls, kernel_size=1))
        if weight_softmax:
            self.blocks.append(nn.Softmax(dim=1))

    def forward(self, input: torch.Tensor):
        """Apply the blocks in order and return the result."""
        x = input
        for m in self.blocks:
            x = m(x)
        return x


@store_hyperparameters()
class AdaCodeSRNet_Contrast(nn.Module):
    """Encoder / multi-codebook quantizer / decoder network of AdaCode.

    Args:
        in_channel: number of channels of the input image.
        codebook_params: one `[scale, emb_num, emb_dim]` triple per codebook.
        gt_resolution: resolution from which the encoder/decoder depths and
            per-level channel counts are derived.
        LQ_stage: forwarded to the encoder; when False, `scale_factor` is
            forced to 1.
        norm_type, act_type: forwarded to the encoder and decoder blocks.
        use_quantize: when False, the vector quantizers are bypassed.
        scale_factor: scale used to size the encoder (only if `LQ_stage`).
        use_residual: in the LQ stage, add the encoder feature to the decoder
            input at depths that are not quantized.
        weight_softmax: append a softmax to the weight predictor.
    """

    hyperparameters = {}

    def __init__(
        self,
        *,
        in_channel=3,
        # NOTE(review): mutable default kept on purpose. It is only read
        # (np.array() copies it), and the default value is presumably recorded
        # by @store_hyperparameters -- confirm before replacing it with None.
        codebook_params: list[list[int]] = [[32, 256, 256]],
        gt_resolution=256,
        LQ_stage=False,
        norm_type="gn",
        act_type="silu",
        use_quantize=True,
        scale_factor=1,
        use_residual=True,
        weight_softmax=False,
    ):
        super().__init__()

        codebook_params_np = np.array(codebook_params)

        self.codebook_scale = codebook_params_np[:, 0]
        codebook_emb_num = codebook_params_np[:, 1].astype(int)
        codebook_emb_dim = codebook_params_np[:, 2].astype(int)

        self.use_quantize = use_quantize
        self.in_channel = in_channel
        self.gt_res = gt_resolution
        self.LQ_stage = LQ_stage
        self.scale_factor = scale_factor if LQ_stage else 1
        self.use_residual = use_residual
        self.weight_softmax = weight_softmax

        # resolution -> number of feature channels at that resolution
        channel_query_dict = {
            8: 256,
            16: 256,
            32: 256,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
        }

        # build encoder
        self.max_depth = int(np.log2(gt_resolution // self.codebook_scale[0]))
        encode_depth = int(
            np.log2(gt_resolution // self.scale_factor // self.codebook_scale[0])
        )
        self.multiscale_encoder = MultiScaleEncoder(
            in_channel,
            encode_depth,
            self.gt_res // self.scale_factor,
            channel_query_dict,
            norm_type,
            act_type,
            LQ_stage,
        )

        # build decoder
        self.decoder_group = nn.ModuleList()
        for i in range(self.max_depth):
            res = gt_resolution // 2**self.max_depth * 2**i
            in_ch, out_ch = channel_query_dict[res], channel_query_dict[res * 2]
            self.decoder_group.append(DecoderBlock(in_ch, out_ch, norm_type, act_type))

        # The output always has 3 channels, independent of `in_channel`.
        self.out_conv = nn.Conv2d(out_ch, 3, 3, 1, 1)  # type: ignore

        # build weight predictor
        self.weight_predictor = WeightPredictor(
            channel_query_dict[self.codebook_scale[0]],
            self.codebook_scale.shape[0],
            self.weight_softmax,
        )

        # build multi-scale vector quantizers
        self.quantize_group = nn.ModuleList()
        self.before_quant_group = nn.ModuleList()
        self.after_quant_group = nn.ModuleList()

        for i in range(codebook_params_np.shape[0]):
            quantize = VectorQuantizer(
                codebook_emb_num[i],
                codebook_emb_dim[i],
            )
            self.quantize_group.append(quantize)

            quant_in_ch = channel_query_dict[self.codebook_scale[i]]
            self.before_quant_group.append(
                nn.Conv2d(quant_in_ch, codebook_emb_dim[i], 1)
            )
            self.after_quant_group.append(
                nn.Conv2d(codebook_emb_dim[i], quant_in_ch, 3, 1, 1)
            )

    def encode_and_decode(self, input):
        """Encode `input`, merge the per-codebook features and decode an image."""
        enc_feats = self.multiscale_encoder(input.detach())
        if self.LQ_stage:
            enc_feats = enc_feats[-3:]
        else:
            enc_feats = enc_feats[::-1]

        num_codebooks = self.codebook_scale.shape[0]
        x = enc_feats[0]
        for i in range(self.max_depth):
            cur_res = self.gt_res // 2**self.max_depth * 2**i
            if cur_res in self.codebook_scale:  # needs to perform quantize
                before_quant_feat = enc_feats[i]

                # Start a fresh list at every quantized depth so that features
                # from an earlier depth are never stacked with this one.
                after_quant_feat_group = []

                # quantize features with multiple codebooks
                for codebook_idx in range(num_codebooks):
                    feat_to_quant = self.before_quant_group[codebook_idx](
                        before_quant_feat
                    )

                    # Only run the quantizer when its result is actually used.
                    if self.use_quantize:
                        z_quant = self.quantize_group[codebook_idx](feat_to_quant)
                    else:
                        z_quant = feat_to_quant

                    after_quant_feat = self.after_quant_group[codebook_idx](z_quant)
                    after_quant_feat_group.append(after_quant_feat)

                # merge feature tensors
                weight = self.weight_predictor(before_quant_feat).unsqueeze(
                    2
                )  # B x N x 1 x H x W
                x = torch.sum(
                    torch.mul(
                        torch.stack(after_quant_feat_group).transpose(0, 1), weight
                    ),
                    dim=1,
                )
            elif self.LQ_stage and self.use_residual:
                x = x + enc_feats[i]

            x = self.decoder_group[i](x)

        out_img = self.out_conv(x)

        return out_img

    def forward(self, input):
        # in HQ stage, or LQ test stage, no GT indices needed.
        return self.encode_and_decode(input)


# ---------------------------------------------------------------------------
# Remainder of the original patch text that shared these lines (the snapshot
# file added by the same patch), carried over verbatim:
#
# diff --git a/tests/__snapshots__/test_AdaCode.ambr b/tests/__snapshots__/test_AdaCode.ambr
# new file mode 100644
# index 00000000..69dc4b9c
# --- /dev/null
# +++ b/tests/__snapshots__/test_AdaCode.ambr
# @@ -0,0 +1,37 @@
# +# serializer version: 1
# +# name: test_AdaCode_SR_X2_model_g
# +  ImageModelDescriptor(
# +    architecture=AdaCodeArch(
# +      id='AdaCode',
# +      name='AdaCode',
# +    ),
# +    input_channels=3,
# +    output_channels=3,
# +    purpose='SR',
# +    scale=2,
# +    size_requirements=SizeRequirements(minimum=0, multiple_of=32, square=False),
# +    supports_bfloat16=True,
# +    supports_half=True,
# +    tags=list([
# +    ]),
# +    tiling=,
# +  )
# +# ---
# +# name: test_AdaCode_SR_X4_model_g
# +  ImageModelDescriptor(
# +    architecture=AdaCodeArch(
# +      id='AdaCode',
# +      name='AdaCode',
# +    ),
# +    input_channels=3,
# +    output_channels=3,
# +    purpose='SR',
# +    scale=4,
# +    size_requirements=SizeRequirements(minimum=0, multiple_of=16, square=False),
# +    supports_bfloat16=True,
# +    supports_half=True,
# +    tags=list([
# +    ]),
# +    tiling=,
# +  )
# +#
# ---------------------------------------------------------------------------
from spandrel_extra_arches.architectures.AdaCode import AdaCode, AdaCodeArch
from tests.test_GFPGAN import disallowed_props

from .util import (
    ModelFile,
    TestImage,
    assert_image_inference,
    assert_loads_correctly,
    assert_size_requirements,
)

# Pretrained checkpoints, defined once so every test uses the same URLs.
_RELEASE_URL = "https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models"
_SR_X2_URL = f"{_RELEASE_URL}/AdaCode_SR_X2_model_g.pth"
_SR_X4_URL = f"{_RELEASE_URL}/AdaCode_SR_X4_model_g.pth"


def test_load():
    """Check that hyperparameters are recovered for a range of configurations."""
    assert_loads_correctly(
        AdaCodeArch(),
        lambda: AdaCode(),
        lambda: AdaCode(in_channel=1),
        lambda: AdaCode(in_channel=4),
        lambda: AdaCode(gt_resolution=512),
        lambda: AdaCode(gt_resolution=128),
        lambda: AdaCode(LQ_stage=True, scale_factor=2),
        lambda: AdaCode(LQ_stage=True, scale_factor=4),
        lambda: AdaCode(LQ_stage=True, scale_factor=8),
        lambda: AdaCode(norm_type="gn"),
        lambda: AdaCode(norm_type="bn"),
        lambda: AdaCode(norm_type="in"),
        # Disabled case. NOTE(review): weight_softmax only appends a
        # parameter-free nn.Softmax, so it presumably cannot be recovered
        # from a state dict - confirm before re-enabling.
        # lambda: AdaCode(weight_softmax=True),
        lambda: AdaCode(codebook_params=[[32, 1024, 512]]),
        lambda: AdaCode(codebook_params=[[32, 512, 256]]),
        lambda: AdaCode(codebook_params=[[64, 512, 256], [32, 1024, 512]]),
        ignore_parameters={
            # there are multiple equivalent codebook_params for some models
            "codebook_params"
        },
    )


def test_size_requirements():
    """Check the reported size requirements of both pretrained models."""
    for url in (_SR_X2_URL, _SR_X4_URL):
        file = ModelFile.from_url(url)
        assert_size_requirements(file.load_model())


def test_AdaCode_SR_X2_model_g(snapshot):
    """Snapshot and inference test for the x2 super-resolution checkpoint."""
    file = ModelFile.from_url(_SR_X2_URL)
    model = file.load_model()
    assert model == snapshot(exclude=disallowed_props)
    assert isinstance(model.model, AdaCode)
    assert_image_inference(
        file,
        model,
        [TestImage.SR_32, TestImage.SR_64],
    )


def test_AdaCode_SR_X4_model_g(snapshot):
    """Snapshot and inference test for the x4 super-resolution checkpoint."""
    file = ModelFile.from_url(_SR_X4_URL)
    model = file.load_model()
    assert model == snapshot(exclude=disallowed_props)
    assert isinstance(model.model, AdaCode)
    assert_image_inference(
        file,
        model,
        [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
    )