Skip to content

Commit

Permalink
dw conv2d properly supported for wgrad
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 6, 2022
1 parent adc4e22 commit 446a95b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 13 deletions.
14 changes: 12 additions & 2 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,10 +826,20 @@ def conv_backward_filter(
x.shape[0], tvm.tir.expr.IntImm
), "Dynamic batch is not supported for cudnn conv2d backwad filter yet."

ic_ind = 1 if tensor_format == 0 else 3

if groups > 1:
assert (
x_shape[ic_ind] == dy.shape[ic_ind] and x_shape[ic_ind] == groups
), "Only depthwise wgrad supported for groups > 1."
ic = 1
else:
ic = x_shape[ic_ind]

if tensor_format == 0:
dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w]
dw_shape = [dy.shape[1], ic, filter_h, filter_w]
else:
dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]]
dw_shape = [dy.shape[3], filter_h, filter_w, ic]

algo = conv_backward_filter_find_algo(
tensor_format,
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,6 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
dilation=attrs.strides,
groups=in_channel * batch,
out_dtype=attrs.out_dtype,
channels=attrs.channels,
)

# infer shape of backward_weight
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ def conv2d_backward_weight_cudnn(
):
"""Compute conv2d wgrad using CuDNN library"""
assert layout in ["NCHW", "NHWC"]
# cuDNN does not seem to support other combination.
assert output_dtype == "float16", "Only supports fp16 output for cuDNN wgrad."
conv_dtype = "float32"

if dy.dtype == "float16":
# cuDNN does not seem to support other combination.
assert output_dtype == "float16", "Only supports fp16 output for cuDNN fp16 wgrad."

conv_dtype = "float32" # Accumulation is always fp32
return cudnn.conv_backward_filter(
dy,
x,
Expand Down
8 changes: 5 additions & 3 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att

auto in_channels = dshape_nchw[1];
auto out_channels = grad_shape_nchw[1];

auto in_channels_intimm = in_channels.as<IntImmNode>();
auto out_channels_intimm = out_channels.as<IntImmNode>();
ICHECK(in_channels_intimm);
Expand All @@ -653,10 +654,11 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
weight_dim_i = indexdiv(in_channels, param->groups);
}

Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0],
param->kernel_size[1]};
Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0], param->kernel_size[1]};
auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
reporter->Assign(types[2], TensorType(wshape, param->out_dtype));

const auto dw_dtype = param->out_dtype == DataType() ? grad->dtype : param->out_dtype;
reporter->Assign(types[2], TensorType(wshape, dw_dtype));
return true;
}

Expand Down
1 change: 0 additions & 1 deletion tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,4 +932,3 @@ def test_conv2d_bwd():

if __name__ == "__main__":
pytest.main([__file__])
# test_conv2d_backward_weight()
9 changes: 6 additions & 3 deletions tests/python/relay/test_op_grad_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,10 @@ def verify_conv2d_backward_weight(
kernel_size=kernel_size,
groups=groups,
channels=out_channels,
out_dtype=dtype,
),
)

dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())

for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]:
Expand All @@ -266,12 +268,13 @@ def verify_conv2d_backward_weight(


def test_conv2d_backward_weight():
verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1))
verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0))
# verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1))
# verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0))
verify_conv2d_backward_weight(
(1, 16, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=16
)


if __name__ == "__main__":
pytest.main([__file__])
# pytest.main([__file__])
test_conv2d_backward_weight()

0 comments on commit 446a95b

Please sign in to comment.