[CUDNN] Support gradient kernels (apache#9986)
* Dgrad nchw, nhwc, fp16 working

commit 426e5dc
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:48:53 2022 +0900

    black

commit 211a58b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:43:52 2022 +0900

    fp16 also works

commit c2a34d4
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:36:36 2022 +0900

    nhwc test also worked

commit c0609ab
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:21:23 2022 +0900

    nchw test worked

commit 2bf68c7
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 10:41:35 2022 +0900

    add test stub

commit c86b128
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 10:32:09 2022 +0900

    add python definition stub

commit 3166952
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 06:57:18 2022 +0900

    bwd filter compiled

commit e311ba3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 06:27:55 2022 +0900

    dgrad compiled

commit 47f35be
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 06:16:43 2022 +0900

    add dgrad stub

commit ebed032
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 17:01:56 2022 +0900

    cpplint

commit 834f54a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 16:55:58 2022 +0900

    remove cudnn get output

commit dcbd9c9
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 16:28:07 2022 +0900

    more refactor

commit 146464e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 15:57:35 2022 +0900

    Introduce SetConvdescriptors to refactor cudnn/conv_forward.cc

* add python function for cudnn wgrad

* adding wgrad test

* black

* wgrad nchw and nhwc worked

* remove bwd algo name stuff

* compute output shape properly

* swap arg order in wgrad

* add kernel size arg in test

* black

* cleanup

* more fix

* fix dgrad test

* support running relay conv2d_backward_weight directly with cudnn

* black

* refactor reference function to support nhwc

* removed unused function

* lint

* enable offloading conv2d_transpose to cudnn dgrad

* relax tol

* name fix, remove print
masahi authored and yuanfz98 committed Jan 24, 2022
1 parent 02571b9 commit 7b76524
Showing 17 changed files with 996 additions and 74 deletions.
460 changes: 416 additions & 44 deletions python/tvm/contrib/cudnn.py

Large diffs are not rendered by default.
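
A hedged sketch of calling the new tvm.contrib.cudnn wrappers introduced in this file (the diff itself is collapsed above). The argument order follows the TOPI call sites further down; the keyword names, shapes, and dtypes here are illustrative assumptions, not the rendered diff:

# Assumes a TVM build with CUDA and cuDNN enabled.
import tvm
from tvm import te
from tvm.contrib import cudnn

# wgrad: gradient w.r.t. the filter, NCHW layout, float32
dy = te.placeholder((1, 32, 14, 14), name="dy")  # gradient of the conv2d output
x = te.placeholder((1, 16, 28, 28), name="x")    # original conv2d input
dw = cudnn.conv_backward_filter(
    dy, x, (3, 3), (1, 1), (2, 2), (1, 1),
    conv_mode=1, tensor_format=0, conv_dtype="float32", groups=1,
)

# dgrad: gradient w.r.t. the data; this is also what backs conv2d_transpose below
w = te.placeholder((32, 16, 3, 3), name="w")     # OIHW filter of the forward conv
dx = cudnn.conv_backward_data(
    dy, w, (1, 1), (2, 2), (1, 1), 1, 0, "float32",
    groups=1, output_padding=(1, 1),
)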

4 changes: 4 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1062,6 +1062,10 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_injective_schedule("nn.batch_to_space_nd")


reg.register_strategy("nn.conv2d_backward_weight", strategy.conv2d_backward_weight_strategy)
reg.register_pattern("nn.conv2d_backward_weight", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.conv2d_backward_weight")
def legalize_conv2d_backward_weight(attrs, inputs, types):
"""Legalize conv2d_backward_weight op.
28 changes: 28 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -564,6 +564,25 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@conv2d_backward_weight_strategy.register(["cuda"])
def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d_backward_weight cuda strategy"""
strategy = _op.OpStrategy()
if target.kind.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(
wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight_cudnn),
wrap_topi_schedule(topi.generic.schedule_extern),
name="conv2d_backward_weight_strategy.cudnn",
plevel=15,
)
else:
raise RuntimeError(
"conv2d_backward_weight on cuda is currently only supported with cudnn. "
"Please run Legalize pass to decompose this op into supported ops."
)
return strategy


@conv2d_transpose_strategy.register(["cuda", "gpu"])
def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d_transpose cuda strategy"""
@@ -579,6 +598,15 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.cuda",
)

if target.kind.name == "cuda" and "cudnn" in target.libs and attrs.kernel_layout == "IOHW":
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_cudnn),
wrap_topi_schedule(topi.generic.schedule_extern),
name="conv2d_transpose.cudnn.cuda",
plevel=25,
)
# TODO(masahi): Support conv2d_transpose NHWC.
return strategy


38 changes: 38 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1841,3 +1841,41 @@ def einsum_strategy(attrs, inputs, out_type, target):
name="einsum.generic",
)
return strategy


# conv2d_backward_weight
def wrap_compute_conv2d_backward_weight(topi_compute):
"""wrap conv2d_backward_weight topi compute"""

def _compute_conv2d_backward_weight(attrs, inputs, out_dtype):
kernel_size = get_const_tuple(attrs.kernel_size)
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
out_dtype = attrs.out_dtype
layout = attrs.data_layout
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
out = topi_compute(
inputs[0],
inputs[1],
kernel_size,
padding,
strides,
dilation,
groups,
layout,
out_dtype,
)
return [out]

return _compute_conv2d_backward_weight


@override_native_generic_func("conv2d_backward_weight_strategy")
def conv2d_backward_weight_strategy(attrs, inputs, out_type, target):
"""wgrad generic strategy"""
raise RuntimeError(
"conv2d_backward_weight is currently only supported with cudnn. "
"Please run Legalize pass to decompose this op into supported ops."
)
19 changes: 19 additions & 0 deletions python/tvm/topi/cuda/conv2d.py
@@ -123,3 +123,22 @@ def conv2d_cudnn(
def schedule_conv2d_cudnn(cfg, outs):
"""Create the schedule for conv2d_cudnn"""
return generic.schedule_extern(outs)


def conv2d_backward_weight_cudnn(
dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype
):
"""Compute conv2d wgrad using CuDNN library"""
assert layout in ["NCHW", "NHWC"]
return cudnn.conv_backward_filter(
dy,
x,
kernel_size,
padding,
stride,
dilation,
conv_mode=1,
tensor_format=0 if layout == "NCHW" else 1,
conv_dtype=output_dtype,
groups=groups,
)
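
A TE-level sketch of driving this compute directly, assuming a cuDNN-enabled TVM build; the shapes and dtypes are illustrative only:

import tvm
from tvm import te, topi

dy = te.placeholder((1, 32, 14, 14), name="dy", dtype="float32")  # NCHW output gradient
x = te.placeholder((1, 16, 28, 28), name="x", dtype="float32")    # NCHW input
# 3x3 kernel, padding 1, stride 2, dilation 1, groups 1
dw = topi.cuda.conv2d_backward_weight_cudnn(
    dy, x, (3, 3), (1, 1), (2, 2), (1, 1), 1, "NCHW", "float32"
)
s = topi.generic.schedule_extern([dw])
mod = tvm.build(s, [dy, x, dw], target="cuda")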
8 changes: 8 additions & 0 deletions python/tvm/topi/cuda/conv2d_transpose_nchw.py
@@ -19,6 +19,7 @@

import tvm
from tvm import te
from tvm.contrib import cudnn
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn
@@ -286,3 +287,10 @@ def _callback(op):
traverse_inline(s, outs[0].op, _callback)

return s


def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, output_padding=(0, 0)):
"""Compute conv2d_tranpose using cudnn dgrad kernel"""
return cudnn.conv_backward_data(
x, w, padding, stride, (1, 1), 1, 0, out_dtype, groups=1, output_padding=output_padding
)
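
With the CUDA strategy change above, a relay conv2d_transpose with IOHW kernel layout can be offloaded to this cuDNN dgrad path when the target enables the cudnn library. A minimal sketch, assuming a CUDA machine and a cuDNN-enabled build (shapes are arbitrary):

import tvm
from tvm import relay

dy = relay.var("dy", shape=(1, 16, 32, 32), dtype="float32")  # NCHW data
w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")     # IOHW kernel
out = relay.nn.conv2d_transpose(
    dy, w, strides=(2, 2), padding=(1, 1), output_padding=(1, 1),
    channels=32, kernel_size=(3, 3), kernel_layout="IOHW",
)
mod = tvm.IRModule.from_expr(relay.Function([dy, w], out))
# "-libs=cudnn" is what makes conv2d_transpose_strategy_cuda pick the
# cudnn implementation (plevel=25) over the native NCHW schedule.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda -libs=cudnn")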
1 change: 0 additions & 1 deletion python/tvm/topi/nn/conv2d_transpose.py
@@ -298,7 +298,6 @@ def conv2d_transpose_legalize(attrs, inputs, types):
result : tvm.relay.Expr
The legalized expr
"""

data, kernel = inputs
kernel_layout = attrs["kernel_layout"]
if attrs["data_layout"] == "NHWC":
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/__init__.py
@@ -75,4 +75,4 @@
from .nll_loss import nll_loss
from .dense import dense
from .searchsorted import searchsorted_ref
from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
from .conv2d_backcward_weight_python import conv2d_backward_weight_python
44 changes: 43 additions & 1 deletion python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -42,7 +42,7 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
Returns
-------
b_np : np.ndarray
dw_np : np.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]
"""
@@ -74,3 +74,45 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
dw[k, c, r, s] = acc

return dw


def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
"""Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
Parameters
----------
dy_np : numpy.ndarray
4-D with shape [batch, in_channel, out_height, out_width] for NCHW layout
x_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width] for NCHW layout
kernel_size : tuple of two ints
Height and width of the weight
stride : tuple of two ints
Stride size, or [stride_height, stride_width]
padding : tuple of two ints
Spatial padding, or [pad_h, pad_w]
layout: string
Layout of dy_np and x_np
Returns
-------
dw_np : np.ndarray
Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
[num_filter, filter_height, filter_width, in_channel] for NHWC layout.
"""
if layout == "NCHW":
return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)

dw_np_oihw = conv2d_backward_weight_nchw_python(
np.transpose(dy_np, [0, 3, 1, 2]),
np.transpose(x_np, [0, 3, 1, 2]),
kernel_size,
stride,
padding,
)
return np.transpose(dw_np_oihw, [0, 2, 3, 1])
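
For a quick sanity check, the reference helper above runs on plain NumPy; a small sketch with arbitrary shapes (dy must match the forward conv2d output shape implied by x, kernel_size, stride, and padding):

import numpy as np
from tvm.topi.testing import conv2d_backward_weight_python

x = np.random.randn(1, 4, 16, 16).astype("float32")   # NCHW input
dy = np.random.randn(1, 8, 14, 14).astype("float32")  # NCHW output gradient
# 3x3 kernel, stride 1, no padding -> 14x14 forward output, so the shapes agree
dw = conv2d_backward_weight_python(dy, x, (3, 3), (1, 1), (0, 0), layout="NCHW")
assert dw.shape == (8, 4, 3, 3)  # [num_filter, in_channel, filter_height, filter_width]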
4 changes: 2 additions & 2 deletions python/tvm/topi/testing/conv2d_transpose_python.py
@@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
dilated_a_np.shape[2] + bpad_top + bpad_bottom,
dilated_a_np.shape[3] + bpad_left + bpad_right,
)
)
).astype(a_np.dtype)
padded_a_np[
:,
:,
@@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
# convolution stage
out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
b_np = np.zeros((batch, out_c, out_h, out_w))
b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype)
for n in range(batch):
for f in range(out_c):
for c in range(in_c):
1 change: 0 additions & 1 deletion src/relay/op/nn/convolution.cc
@@ -665,7 +665,6 @@ given the original input data and the output gradient.
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

} // namespace relay