[CUDNN] Support gradient kernels #9986

Merged · 21 commits · Jan 22, 2022
460 changes: 416 additions & 44 deletions python/tvm/contrib/cudnn.py

Large diffs are not rendered by default.
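The cudnn.py diff itself is not rendered, but the call sites later in this PR indicate that it adds two contrib wrappers, cudnn.conv_backward_filter (wgrad) and cudnn.conv_backward_data (dgrad). Below is a minimal usage sketch reconstructed from those call sites; the placeholder shapes and names are illustrative assumptions, not part of the diff.

from tvm import te
from tvm.contrib import cudnn

# Illustrative NCHW shapes: 8 input channels, 16 filters, 3x3 kernel, stride 1, pad 1.
dy = te.placeholder((1, 16, 32, 32), name="dy", dtype="float32")  # output gradient
x = te.placeholder((1, 8, 32, 32), name="x", dtype="float32")     # forward input
w = te.placeholder((16, 8, 3, 3), name="w", dtype="float32")      # forward weight (OIHW)

# Weight gradient via cuDNN backward-filter (mirrors conv2d_backward_weight_cudnn below).
dw = cudnn.conv_backward_filter(
    dy, x, (3, 3), (1, 1), (1, 1), (1, 1),
    conv_mode=1, tensor_format=0, conv_dtype="float32", groups=1,
)

# Data gradient via cuDNN backward-data (mirrors conv2d_transpose_cudnn below).
dx = cudnn.conv_backward_data(
    dy, w, (1, 1), (1, 1), (1, 1), 1, 0, "float32",
    groups=1, output_padding=(0, 0),
)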

4 changes: 4 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1062,6 +1062,10 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_injective_schedule("nn.batch_to_space_nd")


reg.register_strategy("nn.conv2d_backward_weight", strategy.conv2d_backward_weight_strategy)
reg.register_pattern("nn.conv2d_backward_weight", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.conv2d_backward_weight")
def legalize_conv2d_backward_weight(attrs, inputs, types):
"""Legalize conv2d_backward_weight op.
28 changes: 28 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -564,6 +564,25 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@conv2d_backward_weight_strategy.register(["cuda"])
def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d_backward_weight cuda strategy"""
strategy = _op.OpStrategy()
if target.kind.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(
wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight_cudnn),
wrap_topi_schedule(topi.generic.schedule_extern),
name="conv2d_backward_weight_strategy.cudnn",
plevel=15,
)
else:
raise RuntimeError(
"conv2d_backward_weight on cuda is currently only supported with cudnn. "
"Please run Legalize pass to decompose this op into supported ops."
)
return strategy
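Note that this implementation is only reachable when the target advertises cuDNN; a minimal sketch of such a target (standard TVM API, not part of this diff):

import tvm

# "cudnn" in target.libs satisfies the guard above, so lowering
# nn.conv2d_backward_weight picks conv2d_backward_weight_strategy.cudnn (plevel=15).
target = tvm.target.Target("cuda -libs=cudnn")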


@conv2d_transpose_strategy.register(["cuda", "gpu"])
def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d_transpose cuda strategy"""
@@ -579,6 +598,15 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.cuda",
)

if target.kind.name == "cuda" and "cudnn" in target.libs and attrs.kernel_layout == "IOHW":
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_cudnn),
wrap_topi_schedule(topi.generic.schedule_extern),
name="conv2d_transpose.cudnn.cuda",
plevel=25,
)
# TODO(masahi): Support conv2d_transpose NHWC.
return strategy
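For reference, a hedged relay-level sketch of a conv2d_transpose that would be offloaded to the cuDNN dgrad implementation registered above; shapes and attribute values are illustrative assumptions.

from tvm import relay

data = relay.var("data", shape=(1, 8, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 16, 3, 3), dtype="float32")  # IOHW: (in, out, kh, kw)
y = relay.nn.conv2d_transpose(
    data, weight, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3),
    channels=16, kernel_layout="IOHW",
)
# Compiled with tvm.target.Target("cuda -libs=cudnn"), this hits
# conv2d_transpose.cudnn.cuda; plevel=25 outranks the native schedule.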


38 changes: 38 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1841,3 +1841,41 @@ def einsum_strategy(attrs, inputs, out_type, target):
name="einsum.generic",
)
return strategy


# conv2d_backward_weight
def wrap_compute_conv2d_backward_weight(topi_compute):
"""wrap conv2d_backward_weight topi compute"""

def _compute_conv2d_backward_weight(attrs, inputs, out_dtype):
kernel_size = get_const_tuple(attrs.kernel_size)
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
out_dtype = attrs.out_dtype
layout = attrs.data_layout
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
out = topi_compute(
inputs[0],
inputs[1],
kernel_size,
padding,
strides,
dilation,
groups,
layout,
out_dtype,
)
return [out]

return _compute_conv2d_backward_weight


@override_native_generic_func("conv2d_backward_weight_strategy")
def conv2d_backward_weight_strategy(attrs, inputs, out_type, target):
"""wgrad generic strategy"""
raise RuntimeError(
"conv2d_backward_weight is currently only supported with cudnn. "
"Please run Legalize pass to decompose this op into supported ops."
)
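For targets without cuDNN, the error above points users at the Legalize pass, which rewrites the op via legalize_conv2d_backward_weight in _nn.py; a minimal sketch, assuming mod is an existing IRModule that contains nn.conv2d_backward_weight (e.g. produced by autodiff):

from tvm import relay

# Decompose nn.conv2d_backward_weight into ordinary ops before strategy selection.
mod = relay.transform.Legalize()(mod)  # `mod` is assumed, see the note above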
19 changes: 19 additions & 0 deletions python/tvm/topi/cuda/conv2d.py
@@ -123,3 +123,22 @@ def conv2d_cudnn(
def schedule_conv2d_cudnn(cfg, outs):
"""Create the schedule for conv2d_cudnn"""
return generic.schedule_extern(outs)


def conv2d_backward_weight_cudnn(
dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype
):
"""Compute conv2d wgrad using CuDNN library"""
assert layout in ["NCHW", "NHWC"]
return cudnn.conv_backward_filter(
dy,
x,
kernel_size,
padding,
stride,
dilation,
conv_mode=1,
tensor_format=0 if layout == "NCHW" else 1,
conv_dtype=output_dtype,
groups=groups,
)
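A short sketch of driving this entry point directly at the TE level, here in NHWC layout (placeholder shapes are illustrative; a cuDNN-enabled TVM build is assumed):

from tvm import te, topi

dy = te.placeholder((1, 32, 32, 16), name="dy", dtype="float32")  # NHWC output gradient
x = te.placeholder((1, 32, 32, 8), name="x", dtype="float32")     # NHWC forward input
dw = topi.cuda.conv2d_backward_weight_cudnn(
    dy, x, (3, 3), (1, 1), (1, 1), (1, 1), 1, "NHWC", "float32"
)
# dw is an extern tensor; schedule it with topi.generic.schedule_extern([dw]),
# as registered in conv2d_backward_weight_strategy_cuda above.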
8 changes: 8 additions & 0 deletions python/tvm/topi/cuda/conv2d_transpose_nchw.py
@@ -19,6 +19,7 @@

import tvm
from tvm import te
from tvm.contrib import cudnn
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn
@@ -286,3 +287,10 @@ def _callback(op):
traverse_inline(s, outs[0].op, _callback)

return s


def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, output_padding=(0, 0)):
"""Compute conv2d_tranpose using cudnn dgrad kernel"""
return cudnn.conv_backward_data(
x, w, padding, stride, (1, 1), 1, 0, out_dtype, groups=1, output_padding=output_padding
)
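Here conv2d_transpose is computed as cuDNN's backward-data (dgrad) kernel: the transpose's input plays the role of the output gradient and the IOHW weight is passed through unchanged. A minimal sketch with illustrative shapes:

from tvm import te, topi

x = te.placeholder((1, 8, 16, 16), name="x", dtype="float32")  # conv2d_transpose input (NCHW)
w = te.placeholder((8, 16, 3, 3), name="w", dtype="float32")   # IOHW weight
y = topi.cuda.conv2d_transpose_cudnn(
    x, w, (2, 2), (1, 1), "float32", output_padding=(1, 1)
)
# Expected output shape: (1, 16, 32, 32), i.e. (16-1)*2 - 2*1 + 3 + 1 = 32 per spatial dim.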
1 change: 0 additions & 1 deletion python/tvm/topi/nn/conv2d_transpose.py
@@ -298,7 +298,6 @@ def conv2d_transpose_legalize(attrs, inputs, types):
result : tvm.relay.Expr
The legalized expr
"""

data, kernel = inputs
kernel_layout = attrs["kernel_layout"]
if attrs["data_layout"] == "NHWC":
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/__init__.py
@@ -75,4 +75,4 @@
from .nll_loss import nll_loss
from .dense import dense
from .searchsorted import searchsorted_ref
from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
from .conv2d_backcward_weight_python import conv2d_backward_weight_python
44 changes: 43 additions & 1 deletion python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -42,7 +42,7 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding

Returns
-------
b_np : np.ndarray
dw_np : np.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]

"""
@@ -74,3 +74,45 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
dw[k, c, r, s] = acc

return dw


def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
"""Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.

Parameters
----------
dy_np : numpy.ndarray
4-D with shape [batch, num_filter, out_height, out_width] for NCHW layout

x_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width] for NCHW layout

kernel_size : tuple of two ints
Height and width of the weight

stride : tuple of two ints
Stride size, or [stride_height, stride_width]

padding : tuple of two ints
Spatial padding, or [pad_h, pad_w]

layout: string
Layout of dy_np and x_np

Returns
-------
dw_np : np.ndarray
Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
[num_filter, filter_height, filter_width, in_channel] for NHWC layout.
"""
if layout == "NCHW":
return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)

dw_np_oihw = conv2d_backward_weight_nchw_python(
np.transpose(dy_np, [0, 3, 1, 2]),
np.transpose(x_np, [0, 3, 1, 2]),
kernel_size,
stride,
padding,
)
return np.transpose(dw_np_oihw, [0, 2, 3, 1])
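A quick sketch of exercising the new NHWC path of this reference, e.g. to check the cuDNN wgrad result against it in a test (random data, illustrative shapes):

import numpy as np
from tvm.topi.testing import conv2d_backward_weight_python

dy_np = np.random.randn(1, 32, 32, 16).astype("float32")  # NHWC output gradient
x_np = np.random.randn(1, 32, 32, 8).astype("float32")    # NHWC forward input
dw_np = conv2d_backward_weight_python(
    dy_np, x_np, (3, 3), (1, 1), (1, 1), layout="NHWC"
)
assert dw_np.shape == (16, 3, 3, 8)  # [num_filter, filter_h, filter_w, in_channel]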
4 changes: 2 additions & 2 deletions python/tvm/topi/testing/conv2d_transpose_python.py
@@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
dilated_a_np.shape[2] + bpad_top + bpad_bottom,
dilated_a_np.shape[3] + bpad_left + bpad_right,
)
)
).astype(a_np.dtype)
padded_a_np[
:,
:,
@@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
# convolution stage
out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
b_np = np.zeros((batch, out_c, out_h, out_w))
b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype)
for n in range(batch):
for f in range(out_c):
for c in range(in_c):
1 change: 0 additions & 1 deletion src/relay/op/nn/convolution.cc
@@ -665,7 +665,6 @@ given the original input data and the output gradient.
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

} // namespace relay