
Support conv2d transpose operator with groups > 1 #8182

Closed
YuhengHuang42 opened this issue Jun 3, 2021 · 3 comments
Labels
frontend:pytorch

Comments

@YuhengHuang42
Contributor

YuhengHuang42 commented Jun 3, 2021

Currently TVM doesn't support the conv2d transpose operator with groups > 1. Converting a PyTorch model that uses one to Relay raises an error.

A detailed discussion can be found at: https://discuss.tvm.apache.org/t/pytorch-error-when-turning-a-simple-pytorch-model-to-relay/10158

P.S. It might be better if TVM emitted a warning saying that a specific parameter of an op is unsupported, rather than just raising a generic error.
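
For example, a model like the one below triggers the error during conversion (a minimal sketch; the channel counts, shapes, and the input name "input0" are just illustrative):

    import torch
    from tvm import relay

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # groups > 1 is the unsupported case
            self.deconv = torch.nn.ConvTranspose2d(4, 4, kernel_size=3, groups=2)

        def forward(self, x):
            return self.deconv(x)

    x = torch.randn(1, 4, 8, 8)
    scripted = torch.jit.trace(Net().eval(), x)

    # Fails because the PyTorch frontend / TOPI only handle
    # conv2d_transpose with groups == 1.
    mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(x.shape))])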

@Lyken17
Contributor

Lyken17 commented Jul 10, 2021

I happened to land here and was surprised that this issue has gone unresolved for so long. ConvTranspose2d is an important operator in GAN-related applications. For anyone who runs into the same problem, the following code can be used to support the groups parameter in transposed convolution:

    # Compute stage of conv2d_transpose_nchw, extended with a `groups` parameter.
    # `data_pad` (the stride-dilated, padded input), `kernel_transform` (the
    # spatially flipped kernel), `batch`, `in_c`, `in_h`, `in_w`, `filter_h`,
    # `filter_w`, `out_c` (output channels per group), `groups`, and `out_dtype`
    # are defined earlier in the surrounding TOPI function.
    out_c = simplify(out_c)
    out_channels = simplify(out_c * groups)

    out_h = simplify(in_h - filter_h + 1)
    out_w = simplify(in_w - filter_w + 1)
    # Reduce over one group's input channels and the kernel window.
    dc = te.reduce_axis((0, in_c // groups), name="dc")
    dh = te.reduce_axis((0, filter_h), name="dh")
    dw = te.reduce_axis((0, filter_w), name="dw")

    # data_pad:         (batch, in_channels, padded_h, padded_w)
    # kernel_transform: (out_channels // groups, in_channels, filter_h, filter_w)
    return te.compute(
        (batch, out_channels, out_h, out_w),
        lambda b, c, h, w: te.sum(
            data_pad[
                b,
                # c // (out_channels // groups) is the group index; dc walks
                # through that group's slice of input channels.
                c // (out_channels // groups) * (in_c // groups) + dc,
                h + dh,
                w + dw,
            ].astype(out_dtype)
            * kernel_transform[
                c % (out_channels // groups),
                c // (out_channels // groups) * (in_c // groups) + dc,
                dh,
                dw,
            ].astype(out_dtype),
            axis=[dc, dh, dw],
        ),
        tag="conv2d_transpose_nchw",
    )
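
For intuition, the channel indexing above follows the standard definition of grouped transposed convolution: each group runs an independent transposed convolution over its own slice of input channels, and the per-group outputs are concatenated along the channel axis. A quick PyTorch-only check of that equivalence (a sketch, separate from the TVM code above; the shapes are arbitrary):

    import torch
    import torch.nn.functional as F

    groups, in_c, out_c = 2, 4, 6
    x = torch.randn(1, in_c, 8, 8)
    # PyTorch transposed-conv weight layout: (in_channels, out_channels // groups, kh, kw)
    w = torch.randn(in_c, out_c // groups, 3, 3)

    ref = F.conv_transpose2d(x, w, groups=groups)

    # One transposed convolution per group on the matching channel slices.
    step = in_c // groups
    per_group = [
        F.conv_transpose2d(x[:, g * step:(g + 1) * step], w[g * step:(g + 1) * step])
        for g in range(groups)
    ]
    assert torch.allclose(ref, torch.cat(per_group, dim=1), atol=1e-5)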

@tqchen
Member

tqchen commented Jul 20, 2021

Thanks @Lyken17, would you mind sending a PR to resolve this case?

@Lyken17
Contributor

Lyken17 commented Jul 21, 2021

@tqchen Sure, I can craft a PR.

@areusch added the needs-triage label on Oct 19, 2022
@hpanda-naut added the frontend:pytorch label and removed the needs-triage label on Nov 21, 2022