From 7b765244cd5a2d6360a1f4f12b72f47e7fccd859 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 23 Jan 2022 06:58:31 +0900 Subject: [PATCH] [CUDNN] Support gradient kernels (#9986) * Dgrad nchw, nhwc, fp16 working commit 426e5dca446a27da49270f45171b58f1bfa21fa9 Author: Masahiro Masuda Date: Tue Jan 18 11:48:53 2022 +0900 black commit 211a58b80f4d0f0b5b0230720e41f35e50cb1eaf Author: Masahiro Masuda Date: Tue Jan 18 11:43:52 2022 +0900 fp16 also works commit c2a34d473b063873628bff00e51a44cd8e4d0e4f Author: Masahiro Masuda Date: Tue Jan 18 11:36:36 2022 +0900 nhwc test also worked commit c0609ab147fef30c230a94d16b6c1ba35f7dd9c0 Author: Masahiro Masuda Date: Tue Jan 18 11:21:23 2022 +0900 nchw test worked commit 2bf68c72763708151e9f49f09916a210b2547be8 Author: Masahiro Masuda Date: Tue Jan 18 10:41:35 2022 +0900 add test stub commit c86b1288d5e371f12cba4e1b1866966cb9264401 Author: Masahiro Masuda Date: Tue Jan 18 10:32:09 2022 +0900 add python definition stub commit 3166952f9673376801bf4b5b39eeb6f89452f30a Author: Masahiro Masuda Date: Tue Jan 18 06:57:18 2022 +0900 bwd filter compiled commit e311ba3d05c5f9424ecb952cb5a520ce81a0828a Author: Masahiro Masuda Date: Tue Jan 18 06:27:55 2022 +0900 dgrad compiled commit 47f35beb5eeeb7cbf9f6ec7cf8f5c80c65e8da46 Author: Masahiro Masuda Date: Tue Jan 18 06:16:43 2022 +0900 add dgrad stub commit ebed032d15b1c3895f541c46ce5d80b6dd769034 Author: Masahiro Masuda Date: Mon Jan 17 17:01:56 2022 +0900 cpplint commit 834f54a8c13512130e7d91ca0f54268dc06c5481 Author: Masahiro Masuda Date: Mon Jan 17 16:55:58 2022 +0900 remove cudnn get output commit dcbd9c95fdb8ffef9db9c2350430b270461a31c3 Author: Masahiro Masuda Date: Mon Jan 17 16:28:07 2022 +0900 more refactor commit 146464e8496fff972bdb1687c4e9d432fe3278d5 Author: Masahiro Masuda Date: Mon Jan 17 15:57:35 2022 +0900 Introduce SetConvdescriptors to refactor cudnn/conv_forward.cc * add python function for cudnn wgrad * adding wgrad test * black * wgrad nchw and nhwc worked * 
remove bwd algo name stuff * compute output shape properly * swap arg order in wgrad * add kernel size arg in test * black * cleanup * more fix * fix dgrad test * support running relay conv2d_backward_weight directly with cudnn * black * refactor reference function to support nhwc * removed unused function * lint * enable offloading conv2d_transpose to cudnn dgrad * relax tol * name fix, remove print --- python/tvm/contrib/cudnn.py | 460 ++++++++++++++++-- python/tvm/relay/op/nn/_nn.py | 4 + python/tvm/relay/op/strategy/cuda.py | 28 ++ python/tvm/relay/op/strategy/generic.py | 38 ++ python/tvm/topi/cuda/conv2d.py | 19 + python/tvm/topi/cuda/conv2d_transpose_nchw.py | 8 + python/tvm/topi/nn/conv2d_transpose.py | 1 - python/tvm/topi/testing/__init__.py | 2 +- .../testing/conv2d_backcward_weight_python.py | 44 +- .../topi/testing/conv2d_transpose_python.py | 4 +- src/relay/op/nn/convolution.cc | 1 - src/runtime/contrib/cudnn/conv_backward.cc | 265 ++++++++++ src/runtime/contrib/cudnn/conv_forward.cc | 4 +- src/runtime/contrib/cudnn/cudnn_utils.h | 4 +- tests/python/contrib/test_cudnn.py | 138 ++++++ tests/python/relay/test_op_grad_level2.py | 33 +- tests/python/relay/test_op_level2.py | 17 +- 17 files changed, 996 insertions(+), 74 deletions(-) create mode 100644 src/runtime/contrib/cudnn/conv_backward.cc diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 9b92c7cc2773..c897de74b250 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -36,33 +36,6 @@ "CUDNN_CONVOLUTION_FWD_ALGO_COUNT", ] -_BWD_FILTER_ALGOS = [ - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0", - # non-deterministic - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3", - # non-deterministic, algo0 with workspaceS - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD", - # not implemented - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING", - 
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT", -] - -_BWD_DATA_ALGOS = [ - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", - # non-deterministic - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT", -] - -_ALGO_TYPE = ["fwd", "bwd_filter", "bwd_data"] - def exists(): """ @@ -285,7 +258,74 @@ def conv_output_shape( return output -def conv_find_algo( +def conv_dgrad_shape( + tensor_format, pad, stride, dilation, dy_shape, w_shape, output_padding=(0, 0) +): + """Get output shape of conv2d gradient with respect to data + + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + dy_shape: list + output gradient shape + w_shape: list + weight shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + oshape: list + output shape + """ + + assert len(dy_shape) == len(w_shape) + assert len(dy_shape) == 4 + + if tensor_format == 0: + N = dy_shape[0] + C = w_shape[1] + dy_shape = dy_shape[2:] + w_shape = w_shape[2:] + elif tensor_format == 1: + N = dy_shape[0] + C = w_shape[-1] + dy_shape = dy_shape[1:-1] + w_shape = w_shape[1:-1] + else: + raise ValueError("Unsupported CuDNN tensor format: '{}'".format(tensor_format)) + + input_dims = [] + for dy_shape_i, w_shape_i, pad_i, stride_i, dilation_i, out_pad in zip( + dy_shape, w_shape, pad, stride, dilation, output_padding + ): + input_dim = ( + (dy_shape_i - 1) * stride_i - 2 * pad_i + (((w_shape_i - 1) * dilation_i) + 1) + out_pad + ) + input_dims.append(input_dim) + + if tensor_format == 0: + output = [N, C, *input_dims] + else: + output = [N, *input_dims, C] + + return output + + +def _conv_find_algo( + func_name, tensor_format, 
pad, stride, @@ -297,7 +337,46 @@ def conv_find_algo( conv_dtype, groups=1, ): - """Choose the best algo for the given input. + """ + Common function to choose the best cudnn convolution algorithm for the given input + and the convolution type. + """ + dims = len(x_shape) + assert dims in (4, 5) + + pad, stride, dilation, xshape, wshape = _prepare_global_func_params( + dims - 2, pad, stride, dilation, x_shape, w_shape + ) + yshape = np.array(y_shape, dtype=np.int32) + func = tvm._ffi.get_global_func(func_name) + return func( + tensor_format, + dims - 2, + _get_np_int32_array_handle(pad), + _get_np_int32_array_handle(stride), + _get_np_int32_array_handle(dilation), + _get_np_int32_array_handle(xshape), + _get_np_int32_array_handle(wshape), + _get_np_int32_array_handle(yshape), + data_dtype, + conv_dtype, + groups, + ) + + +def conv_forward_find_algo( + tensor_format, + pad, + stride, + dilation, + x_shape, + w_shape, + y_shape, + data_dtype, + conv_dtype, + groups=1, +): + """Choose the best forward algorithm for the given input. Paramters --------- @@ -329,23 +408,133 @@ def conv_find_algo( algo: int algo chosen by CUDNN """ - dims = len(x_shape) - assert dims in (4, 5) + return _conv_find_algo( + "tvm.contrib.cudnn.conv.forward_find_algo", + tensor_format, + pad, + stride, + dilation, + x_shape, + w_shape, + y_shape, + data_dtype, + conv_dtype, + groups, + ) - pad, stride, dilation, xshape, wshape = _prepare_global_func_params( - dims - 2, pad, stride, dilation, x_shape, w_shape + +def conv_backward_data_find_algo( + tensor_format, + pad, + stride, + dilation, + dy_shape, + w_shape, + dx_shape, + data_dtype, + conv_dtype, + groups=1, +): + """Choose the best backward data algorithm for the given input. 
+ + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + dy_shape: list + output gradient shape + w_shape: list + weight shape + dx_shape: list + dgrad shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + algo: int + algo chosen by CUDNN + """ + return _conv_find_algo( + "tvm.contrib.cudnn.conv.backward_data_find_algo", + tensor_format, + pad, + stride, + dilation, + dy_shape, + w_shape, + dx_shape, + data_dtype, + conv_dtype, + groups, ) - yshape = np.array(y_shape, dtype=np.int32) - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.find_algo") - return func( + + +def conv_backward_filter_find_algo( + tensor_format, + pad, + stride, + dilation, + dy_shape, + x_shape, + dw_shape, + data_dtype, + conv_dtype, + groups=1, +): + """Choose the best backward filter algorithm for the given input. 
+ + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + dy_shape: list + output gradient shape + x_shape: list + weight shape + dw_shape: list + wgrad shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + algo: int + algo chosen by CUDNN + """ + return _conv_find_algo( + "tvm.contrib.cudnn.conv.backward_filter_find_algo", tensor_format, - dims - 2, - _get_np_int32_array_handle(pad), - _get_np_int32_array_handle(stride), - _get_np_int32_array_handle(dilation), - _get_np_int32_array_handle(xshape), - _get_np_int32_array_handle(wshape), - _get_np_int32_array_handle(yshape), + pad, + stride, + dilation, + dy_shape, + x_shape, + dw_shape, data_dtype, conv_dtype, groups, @@ -414,7 +603,7 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co if tensor_format == 1 and conv_dtype == "int32": algo = 1 else: - algo = conv_find_algo( + algo = conv_forward_find_algo( tensor_format, pad, stride, @@ -496,6 +685,189 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co ) +def conv_backward_data( + dy, + w, + pad, + stride, + dilation, + conv_mode, + tensor_format, + conv_dtype, + groups=1, + output_padding=(0, 0), +): + """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to data. 
+ + Parameters + ---------- + dy: Tensor + output gradient + w: Tensor + convolution weight + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + conv_mode: int + 0: CUDNN_CONVOLUTION + 1: CUDNN_CROSS_CORRELATION + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + conv_dtype: str + convolution type + groups: int + the number of groups + + Returns + ------- + dx: Tensor + dgrad tensor + """ + dims = len(dy.shape) + assert dims == 4 + + conv_dtype = dy.dtype if conv_dtype is None else conv_dtype + pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) + + assert isinstance( + dy.shape[0], tvm.tir.expr.IntImm + ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." + + dx_shape = conv_dgrad_shape( + tensor_format, pad, stride, dilation, dy.shape, w.shape, output_padding + ) + + algo = conv_backward_data_find_algo( + tensor_format, + pad, + stride, + dilation, + list(dy.shape), + list(w.shape), + dx_shape, + dy.dtype, + conv_dtype, + groups, + ) + + return te.extern( + dx_shape, + [dy, w], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.conv2d.backward_data", + conv_mode, + tensor_format, + algo, + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + ins[0], + ins[1], + outs[0], + conv_dtype, + groups, + ), + name="dx", + ) + + +def conv_backward_filter( + dy, x, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 +): + """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to weight. 
+ + Parameters + ---------- + dy: Tensor + output gradient + x: Tensor + input tensor + kernel_size: a pair of int + The spatial size of the corresponding forward convolution kernel + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + conv_mode: int + 0: CUDNN_CONVOLUTION + 1: CUDNN_CROSS_CORRELATION + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + conv_dtype: str + convolution type + groups: int + the number of groups + + Returns + ------- + dw: Tensor + wgrad tensor + """ + dims = len(x.shape) + assert dims == 4 + + conv_dtype = x.dtype if conv_dtype is None else conv_dtype + pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) + filter_h, filter_w = kernel_size + + x_shape = list(x.shape) + + assert isinstance( + x.shape[0], tvm.tir.expr.IntImm + ), "Dynamic batch is not supported for cudnn conv2d backwad filter yet." + + if tensor_format == 0: + dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w] + else: + dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]] + + algo = conv_backward_filter_find_algo( + tensor_format, + pad, + stride, + dilation, + list(dy.shape), + list(x.shape), + dw_shape, + x.dtype, + conv_dtype, + groups, + ) + + return te.extern( + dw_shape, + [dy, x], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.conv2d.backward_filter", + conv_mode, + tensor_format, + algo, + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + ins[0], + ins[1], + outs[0], + conv_dtype, + groups, + ), + name="dw", + ) + + def softmax(x, axis=-1): """Compute softmax using CuDNN diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 2a941cc8c28a..1fa909e748a0 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1062,6 +1062,10 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_injective_schedule("nn.batch_to_space_nd") 
+reg.register_strategy("nn.conv2d_backward_weight", strategy.conv2d_backward_weight_strategy) +reg.register_pattern("nn.conv2d_backward_weight", OpPattern.OUT_ELEMWISE_FUSABLE) + + @reg.register_legalize("nn.conv2d_backward_weight") def legalize_conv2d_backward_weight(attrs, inputs, types): """Legalize conv2d_backward_weight op. diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 69579f690c96..af7451408d27 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -564,6 +564,25 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target): return strategy +@conv2d_backward_weight_strategy.register(["cuda"]) +def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target): + """conv2d_backward_weight cuda strategy""" + strategy = _op.OpStrategy() + if target.kind.name == "cuda" and "cudnn" in target.libs: + strategy.add_implementation( + wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight_cudnn), + wrap_topi_schedule(topi.generic.schedule_extern), + name="conv2d_backward_weight_strategy.cudnn", + plevel=15, + ) + else: + raise RuntimeError( + "conv2d_backward_weight on cuda is currently only supported with cudnn. " + "Please run Legalize pass to decompose this op into supported ops." 
+ ) + return strategy + + @conv2d_transpose_strategy.register(["cuda", "gpu"]) def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): """conv2d_transpose cuda strategy""" @@ -579,6 +598,15 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw), name="conv2d_transpose_nchw.cuda", ) + + if target.kind.name == "cuda" and "cudnn" in target.libs and attrs.kernel_layout == "IOHW": + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_cudnn), + wrap_topi_schedule(topi.generic.schedule_extern), + name="conv2d_transpose.cudnn.cuda", + plevel=25, + ) + # TODO(masahi): Support conv2d_transpose NHWC. return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index cc12fa127006..abd3e28bc3eb 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1841,3 +1841,41 @@ def einsum_strategy(attrs, inputs, out_type, target): name="einsum.generic", ) return strategy + + +# conv2d_backward_weight +def wrap_compute_conv2d_backward_weight(topi_compute): + """wrap conv2d_backward_weight topi compute""" + + def _compute_conv2d_backward_weight(attrs, inputs, out_dtype): + kernel_size = get_const_tuple(attrs.kernel_size) + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + out_dtype = attrs.out_dtype + layout = attrs.data_layout + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype + out = topi_compute( + inputs[0], + inputs[1], + kernel_size, + padding, + strides, + dilation, + groups, + layout, + out_dtype, + ) + return [out] + + return _compute_conv2d_backward_weight + + +@override_native_generic_func("conv2d_backward_weight_strategy") +def conv2d_backward_weight_strategy(attrs, inputs, out_type, target): + """wgrad generic strategy""" + raise 
RuntimeError( + "conv2d_backward_weight is currently only supported with cudnn. " + "Please run Legalize pass to decompose this op into supported ops." + ) diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index bd8d7ec19bb3..15fcaaa02134 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -123,3 +123,22 @@ def conv2d_cudnn( def schedule_conv2d_cudnn(cfg, outs): """Create the schedule for conv2d_cudnn""" return generic.schedule_extern(outs) + + +def conv2d_backward_weight_cudnn( + dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype +): + """Compute conv2d wgrad using CuDNN library""" + assert layout in ["NCHW", "NHWC"] + return cudnn.conv_backward_filter( + dy, + x, + kernel_size, + padding, + stride, + dilation, + conv_mode=1, + tensor_format=0 if layout == "NCHW" else 1, + conv_dtype=output_dtype, + groups=groups, + ) diff --git a/python/tvm/topi/cuda/conv2d_transpose_nchw.py b/python/tvm/topi/cuda/conv2d_transpose_nchw.py index 3b704170a2e9..36ce3a3d2454 100644 --- a/python/tvm/topi/cuda/conv2d_transpose_nchw.py +++ b/python/tvm/topi/cuda/conv2d_transpose_nchw.py @@ -19,6 +19,7 @@ import tvm from tvm import te +from tvm.contrib import cudnn from tvm import autotvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from .. 
import nn @@ -286,3 +287,10 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, output_padding=(0, 0)): + """Compute conv2d_tranpose using cudnn dgrad kernel""" + return cudnn.conv_backward_data( + x, w, padding, stride, (1, 1), 1, 0, out_dtype, groups=1, output_padding=output_padding + ) diff --git a/python/tvm/topi/nn/conv2d_transpose.py b/python/tvm/topi/nn/conv2d_transpose.py index 2871699350ed..c408095eb7ab 100644 --- a/python/tvm/topi/nn/conv2d_transpose.py +++ b/python/tvm/topi/nn/conv2d_transpose.py @@ -298,7 +298,6 @@ def conv2d_transpose_legalize(attrs, inputs, types): result : tvm.relay.Expr The legalized expr """ - data, kernel = inputs kernel_layout = attrs["kernel_layout"] if attrs["data_layout"] == "NHWC": diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 75eabffc957a..c3d222cfd120 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -75,4 +75,4 @@ from .nll_loss import nll_loss from .dense import dense from .searchsorted import searchsorted_ref -from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python +from .conv2d_backcward_weight_python import conv2d_backward_weight_python diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py b/python/tvm/topi/testing/conv2d_backcward_weight_python.py index 587cd45b49c1..36a6b0616053 100644 --- a/python/tvm/topi/testing/conv2d_backcward_weight_python.py +++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py @@ -42,7 +42,7 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding Returns ------- - b_np : np.ndarray + dw_np : np.ndarray 4-D with shape [num_filter, in_channel, filter_height, filter_width] """ @@ -74,3 +74,45 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding dw[k, c, r, s] = acc return dw + + +def 
conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"): + """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout. + + Parameters + ---------- + dy_np : numpy.ndarray + 4-D with shape [batch, in_channel, out_height, out_width] for NCHW layout + + x_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] for NCHW layout + + kernel_size : tuple of two ints + Height and width of the weight + + stride : tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : tuple of two ints + Spatial padding, or [pad_h, pad_w] + + layout: string + Layout of dy_np and x_np + + Returns + ------- + dw_np : np.ndarray + Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout, + [num_filter, filter_height, filter_width, in_channel] for NHWC layout. + """ + if layout == "NCHW": + return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding) + + dw_np_oihw = conv2d_backward_weight_nchw_python( + np.transpose(dy_np, [0, 3, 1, 2]), + np.transpose(x_np, [0, 3, 1, 2]), + kernel_size, + stride, + padding, + ) + return np.transpose(dw_np_oihw, [0, 2, 3, 1]) diff --git a/python/tvm/topi/testing/conv2d_transpose_python.py b/python/tvm/topi/testing/conv2d_transpose_python.py index a38d8bc9f031..678b5fe5d003 100644 --- a/python/tvm/topi/testing/conv2d_transpose_python.py +++ b/python/tvm/topi/testing/conv2d_transpose_python.py @@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding): dilated_a_np.shape[2] + bpad_top + bpad_bottom, dilated_a_np.shape[3] + bpad_left + bpad_right, ) - ) + ).astype(a_np.dtype) padded_a_np[ :, :, @@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding): # convolution stage out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w - b_np = np.zeros((batch, 
out_c, out_h, out_w)) + b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype) for n in range(batch): for f in range(out_c): for c in range(in_c): diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index f1d4eb3d87ea..30386bbf4415 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -665,7 +665,6 @@ given the original input data and the output gradient. .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel) - .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); } // namespace relay diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc new file mode 100644 index 000000000000..af190d7c8c90 --- /dev/null +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuDNN kernel calls for backward algorithms. 
+ */ +#include +#include +#include + +#include "cudnn_utils.h" + +namespace tvm { +namespace contrib { + +using namespace runtime; + +void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[], + const int stride[], const int dilation[], DLTensor* dy, DLTensor* w, + DLTensor* dx, const std::string& conv_dtype) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape, + dy->shape, dy->dtype, conv_dtype); + // Set Device + entry_ptr->conv_entry.device = dy->device; + // Set Algo + entry_ptr->conv_entry.bwd_data_algo = static_cast(algo); + + // Set workspace + size_t workspace_size = 0; + CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.bwd_data_algo, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + CUDNN_CALL(cudnnConvolutionBackwardData( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc, + dx->data)); +} + +void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int dy_dim[], const int w_dim[], + const int dx_dim[], const std::string& data_dtype, + const std::string& conv_dtype, TVMRetValue* ret) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + const int full_dims = dims + 2; + std::vector dy_dim_int64(full_dims); + 
std::vector w_dim_int64(full_dims); + std::vector dx_dim_int64(full_dims); + for (int i = 0; i < full_dims; ++i) { + dy_dim_int64[i] = dy_dim[i]; + w_dim_int64[i] = w_dim[i]; + dx_dim_int64[i] = dx_dim[i]; + } + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(), + w_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype), + conv_dtype); + + int returned_algo_count = 0; + + cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT]; + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector bwd_data_algo_names{ + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", // non-deterministic + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"}; + + auto best_algo = perf_results[0].algo; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing " + << bwd_data_algo_names[best_algo]; + for (int i = 0; i < returned_algo_count; ++i) { + LOG(INFO) << "\t\t" << i << ") " << bwd_data_algo_names[perf_results[i].algo] + << " - time: " << perf_results[i].time << " ms" + << ", Memory: " << perf_results[i].memory; + } + + ret[0] = best_algo; +} + +void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups, + const int pad[], const int stride[], const int dilation[], + DLTensor* dy, DLTensor* x, DLTensor* dw, + const std::string& conv_dtype) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, 
dilation, x->shape, dw->shape, + dy->shape, x->dtype, conv_dtype); + // Set Device + entry_ptr->conv_entry.device = x->device; + // Set Algo + entry_ptr->conv_entry.bwd_filter_algo = static_cast(algo); + + // Set workspace + size_t workspace_size = 0; + CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.bwd_filter_algo, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + CUDNN_CALL(cudnnConvolutionBackwardFilter( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.output_desc, dy->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_filter_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.filter_desc, dw->data)); +} + +void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int dy_dim[], const int x_dim[], + const int dw_dim[], const std::string& data_dtype, + const std::string& conv_dtype, TVMRetValue* ret) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + const int full_dims = dims + 2; + std::vector x_dim_int64(full_dims); + std::vector dy_dim_int64(full_dims); + std::vector dw_dim_int64(full_dims); + for (int i = 0; i < full_dims; ++i) { + x_dim_int64[i] = x_dim[i]; + dy_dim_int64[i] = dy_dim[i]; + dw_dim_int64[i] = dw_dim[i]; + } + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), + dw_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype), + conv_dtype); + + int returned_algo_count = 0; + + cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT]; + 
CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector bwd_filter_algo_names{ + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0", // non-deterministic + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING", + }; + + auto best_algo = perf_results[0].algo; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing " + << bwd_filter_algo_names[best_algo]; + for (int i = 0; i < returned_algo_count; ++i) { + LOG(INFO) << "\t\t" << i << ") " << bwd_filter_algo_names[perf_results[i].algo] + << " - time: " << perf_results[i].time << " ms" + << ", Memory: " << perf_results[i].memory; + } + + ret[0] = best_algo; +} + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* dy = args[9]; + DLTensor* w = args[10]; + DLTensor* dx = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx, + conv_dtype); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = 
static_cast<int*>(static_cast<void*>(args[4])); + int* dy_dim = static_cast<int*>(static_cast<void*>(args[5])); + int* w_dim = static_cast<int*>(static_cast<void*>(args[6])); + int* dx_dim = static_cast<int*>(static_cast<void*>(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim, + data_dtype, conv_dtype, ret); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* dy = args[9]; + DLTensor* x = args[10]; + DLTensor* dw = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, x, + dw, conv_dtype); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast<int*>(static_cast<void*>(args[2])); + int* stride = static_cast<int*>(static_cast<void*>(args[3])); + int* dilation = static_cast<int*>(static_cast<void*>(args[4])); + int* dy_dim = static_cast<int*>(static_cast<void*>(args[5])); + int* x_dim = static_cast<int*>(static_cast<void*>(args[6])); + int* dw_dim = static_cast<int*>(static_cast<void*>(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim, + data_dtype, conv_dtype, ret); + }); + +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index b7476e5106fa..f5e5ee889c55 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++
b/src/runtime/contrib/cudnn/conv_forward.cc @@ -18,7 +18,7 @@ */ /*! - * \file Use external cudnn utils function + * \file cuDNN kernel calls for the forward algorithm. */ #include #include @@ -147,7 +147,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") .set_body([](TVMArgs args, TVMRetValue* ret) { int format = args[0]; int dims = args[1]; diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 89de0e90df90..426ccfdf37af 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -67,12 +67,14 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) { struct ConvEntry { cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION}; - cudnnFilterDescriptor_t filter_desc; cudnnDataType_t data_type; cudnnTensorFormat_t tensor_format; cudnnTensorDescriptor_t input_desc; + cudnnFilterDescriptor_t filter_desc; cudnnTensorDescriptor_t output_desc; cudnnConvolutionFwdAlgo_t fwd_algo; + cudnnConvolutionBwdDataAlgo_t bwd_data_algo; + cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; // cudnnMathType_t math_type; Device device; runtime::DeviceAPI* cuda_api; diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index bc2cc80f362d..0c39a1a2428d 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -236,6 +236,144 @@ def test_softmax(): verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True) +def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-5): + batch = 3 + in_channel = 4 + out_channel = 16 + filter_h, filter_w = 3, 3 + pad_h, pad_w = 1, 1 + stride_h, stride_w = 1, 1 + height, width = 32, 32 + + if tensor_format == 0: + xshape = [batch, in_channel, height, width] + wshape = [out_channel, 
in_channel, filter_h, filter_w] + oshape = list(xshape) + oshape[1] = out_channel + ref_func = tvm.topi.testing.conv2d_transpose_nchw_python + else: + xshape = [batch, height, width, in_channel] + wshape = [out_channel, filter_h, filter_w, in_channel] + oshape = list(xshape) + oshape[3] = out_channel + ref_func = lambda dy_np, w_np, strides, padding, out_pad: tvm.topi.testing.conv2d_transpose_nhwc_python( + dy_np, np.transpose(w_np, [1, 2, 3, 0]), "HWOI", strides, padding, out_pad + ) + + dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + + if data_dtype == "float16": + dx_np = ref_func( + dy_np.astype("float32"), + w_np.astype("float32"), + (stride_h, stride_w), + (pad_h, pad_w), + (0, 0), + ) + dx_np = dx_np.astype("float16") + else: + dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0)) + + dy = te.placeholder(oshape, name="dy", dtype=data_dtype) + w = te.placeholder(wshape, name="w", dtype=data_dtype) + dx = cudnn.conv_backward_data( + dy, + w, + [pad_h, pad_w], + [stride_h, stride_w], + [1, 1], + conv_mode=1, + tensor_format=tensor_format, + conv_dtype=conv_dtype, + groups=1, + ) + + s = te.create_schedule(dx.op) + + dev = tvm.cuda(0) + f = tvm.build(s, [dy, w, dx], "cuda --host=llvm", name="conv2d_backward_data") + + dy = tvm.nd.array(dy_np, dev) + w = tvm.nd.array(w_np, dev) + dx = tvm.nd.array(np.zeros_like(dx_np), dev) + + f(dy, w, dx) + tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol) + + +@tvm.testing.requires_gpu +@requires_cudnn +def test_conv2d_backward_data(): + verify_conv2d_backward_data("float32", "float32", tensor_format=0, tol=1e-5) + verify_conv2d_backward_data("float32", "float32", tensor_format=1, tol=1e-2) + # The scipy convolve function does not support fp16, so the reference will be computed with + # fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok).
+ verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1) + + +def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1e-5): + batch = 3 + in_channel = 4 + out_channel = 16 + filter_h, filter_w = 3, 3 + pad_h, pad_w = 1, 1 + stride_h, stride_w = 1, 1 + height, width = 32, 32 + + if tensor_format == 0: + x_shape = [batch, in_channel, height, width] + dy_shape = [batch, out_channel, height, width] + else: + x_shape = [batch, height, width, in_channel] + dy_shape = [batch, height, width, out_channel] + + x_np = np.random.uniform(-1, 1, x_shape).astype(data_dtype) + dy_np = np.random.uniform(-1, 1, dy_shape).astype(data_dtype) + + dw_np = tvm.topi.testing.conv2d_backward_weight_python( + dy_np, + x_np, + (filter_h, filter_w), + (stride_h, stride_w), + (pad_h, pad_w), + "NCHW" if tensor_format == 0 else "NHWC", + ) + + x = te.placeholder(x_shape, name="x", dtype=data_dtype) + dy = te.placeholder(dy_shape, name="dy", dtype=data_dtype) + dw = cudnn.conv_backward_filter( + dy, + x, + (filter_h, filter_w), + [pad_h, pad_w], + [stride_h, stride_w], + [1, 1], + conv_mode=1, + tensor_format=tensor_format, + conv_dtype=conv_dtype, + ) + + s = te.create_schedule(dw.op) + + dev = tvm.cuda(0) + f = tvm.build(s, [dy, x, dw], "cuda --host=llvm", name="conv2d_backward_filter") + + x = tvm.nd.array(x_np, dev) + dy = tvm.nd.array(dy_np, dev) + dw = tvm.nd.array(np.zeros_like(dw_np), dev) + + f(dy, x, dw) + tvm.testing.assert_allclose(dw.numpy(), dw_np, atol=tol, rtol=tol) + + +@tvm.testing.requires_gpu +@requires_cudnn +def test_conv2d_backward_filter(): + verify_conv2d_backward_filter("float32", "float32", tensor_format=0, tol=1e-4) + verify_conv2d_backward_filter("float32", "float32", tensor_format=1, tol=1e-4) + + test_kwargs_default_2d = { "tensor_format": 0, "pad": [1, 1], diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 1efdb262245f..a5fc630f61dc 100644 ---
a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -233,27 +233,28 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin dtype = "float32" dy = relay.var("dy", shape=dy_shape, dtype=dtype) x = relay.var("x", shape=x_shape, dtype=dtype) - dw = relay.nn.conv2d_backward_weight( - dy, x, strides=stride, padding=padding, kernel_size=kernel_size + dw_func = relay.Function( + [dy, x], + relay.nn.conv2d_backward_weight( + dy, x, strides=stride, padding=padding, kernel_size=kernel_size + ), ) - dw_func = relay.Function([dy, x], dw) dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize()) - target = "llvm" - dev = tvm.device(target, 0) - dy_np = np.random.randn(*dy_shape).astype(dtype) - x_np = np.random.randn(*x_shape).astype(dtype) + for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]: + if "cudnn" in target and not tvm.contrib.cudnn.exists(): + continue - dw_np = ( - relay.create_executor(device=dev, target=target) - .evaluate(dw_func_legalized)(dy_np, x_np) - .numpy() - ) - ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python( - dy_np, x_np, kernel_size, stride, padding - ) + dev = tvm.device(target, 0) + dy_np = np.random.randn(*dy_shape).astype(dtype) + x_np = np.random.randn(*x_shape).astype(dtype) + + dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy() + ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python( + dy_np, x_np, kernel_size, stride, padding + ) - np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) def test_conv2d_backward_weight(): diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index db712be4262e..6d428bfde21b 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -24,7 +24,7 @@ import tvm.testing import tvm.topi.testing from 
tvm import autotvm, relay, te -from tvm.contrib import utils +from tvm.contrib import utils, cudnn from tvm.ir.module import IRModule from tvm.relay import transform from tvm.relay.testing import run_infer_type @@ -838,10 +838,10 @@ def test_conv2d_transpose_infer_type(): @tvm.testing.uses_gpu def test_conv2d_transpose_nchw_run(): k_layouts = {"OIHW": (10, 3, 3, 3), "IOHW": (3, 10, 3, 3)} + output_padding = (1, 1) for k_layout, kshape in k_layouts.items(): dshape = (1, 3, 18, 18) - oshape = (1, 10, 36, 36) x = relay.var("x", shape=dshape) w = relay.var("w") y = relay.nn.conv2d_transpose( @@ -851,7 +851,7 @@ def test_conv2d_transpose_nchw_run(): kernel_size=(3, 3), strides=(2, 2), padding=(1, 1), - output_padding=(1, 1), + output_padding=output_padding, kernel_layout=k_layout, data_layout="NCHW", ) @@ -866,9 +866,16 @@ def test_conv2d_transpose_nchw_run(): else: kernel_iohw = kernel - ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, kernel_iohw, 2, 1, (1, 1)) + ref_res = tvm.topi.testing.conv2d_transpose_nchw_python( + data, kernel_iohw, 2, 1, output_padding + ) - for target, dev in tvm.testing.enabled_targets(): + enabled_targets = tvm.testing.enabled_targets() + + if cudnn.exists() and k_layout == "IOHW": + enabled_targets.append(("cuda -libs=cudnn", tvm.cuda(0))) + + for target, dev in enabled_targets: op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( data, kernel )