remove cudnn get output
masahi committed Jan 17, 2022
1 parent dcbd9c9 commit 834f54a
Showing 4 changed files with 3 additions and 147 deletions.
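
Context for the change: the helper removed below asked cuDNN itself for a convolution's output shape, which needs a GPU at test time, while the pure-Python `conv_output_shape` that remains computes the same answer from the arguments alone. As a hedged illustration (the function name below is ours, not TVM's internal code), each spatial dimension follows the usual cross-correlation formula:

    # Minimal sketch of the per-axis output-shape formula a pure-Python helper
    # such as conv_output_shape can rely on; names here are illustrative.
    def conv_out_dim(in_size, kernel_size, pad, stride, dilation):
        """Output extent of one spatial axis of a cross-correlation."""
        effective_kernel = dilation * (kernel_size - 1) + 1
        return (in_size + 2 * pad - effective_kernel) // stride + 1

    # Example: a 3x3 kernel with pad 1, stride 1, dilation 1 preserves a 224 input.
    assert conv_out_dim(224, 3, pad=1, stride=1, dilation=1) == 224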
62 changes: 0 additions & 62 deletions python/tvm/contrib/cudnn.py
@@ -285,68 +285,6 @@ def conv_output_shape(
    return output


def _conv_output_shape_from_cudnn(
    tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1
):
    """Get output shape of 2D or 3D convolution. The output of this
    function should be identical to that of conv_output_shape, but
    requires a GPU with CuDNN to be present. This is maintained for
    testing purposes to validate the output of conv_output_shape.
    Parameters
    ----------
    tensor_format: int
        0: CUDNN_TENSOR_NCHW
        1: CUDNN_TENSOR_NHWC
        2: CUDNN_TENSOR_NCHW_VECT_C
    pad: int or list
        padding
    stride: int or list
        stride
    dilation: int or list
        dilation
    x_shape: list
        input shape
    w_shape: list
        weight shape
    data_dtype: str
        data type
    conv_dtype: str
        convolution type
    groups: int
        number of groups
    Returns
    -------
    oshape: list
        output shape
    """
    dims = len(x_shape)
    assert dims in (4, 5)

    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
        dims - 2, pad, stride, dilation, x_shape, w_shape
    )
    oshape = np.zeros((dims), dtype=np.int32)

    func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
    func(
        tensor_format,
        dims - 2,
        _get_np_int32_array_handle(pad),
        _get_np_int32_array_handle(stride),
        _get_np_int32_array_handle(dilation),
        _get_np_int32_array_handle(xshape),
        _get_np_int32_array_handle(wshape),
        _get_np_int32_array_handle(oshape),
        data_dtype,
        conv_dtype,
        groups,
    )
    return list(oshape)


def conv_find_algo(
    tensor_format,
    pad,
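With the cuDNN-backed helper gone, callers keep using `conv_output_shape` alone; the removed test exercised both with the same keyword arguments, so the interfaces match. A hedged usage sketch (argument values are illustrative and not taken from the commit; `tensor_format=0` is CUDNN_TENSOR_NCHW per the docstring above):

    from tvm.contrib import cudnn

    # Illustrative call with the same keywords the removed helper documented.
    oshape = cudnn.conv_output_shape(
        tensor_format=0,
        pad=[1, 1],
        stride=[1, 1],
        dilation=[1, 1],
        x_shape=[1, 3, 224, 224],
        w_shape=[64, 3, 3, 3],
        data_dtype="float32",
        conv_dtype="float32",
        groups=1,
    )
    # Expected NCHW result: [1, 64, 224, 224]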
72 changes: 0 additions & 72 deletions src/runtime/contrib/cudnn/conv_forward.cc
@@ -60,60 +60,6 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
                                     entry_ptr->conv_entry.output_desc, y->data));
}

void OutputShape(int format, int dims, int groups, const int pad[], const int stride[],
                 const int dilation[], const int x_dim[], const int w_dim[], void* out_shape,
                 const std::string& data_dtype, const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

  // Set Data Type
  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
  // Set Format
  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // Dims includes N and C
  int full_dims = dims + 2;

  // conv desc
  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
                                             dilation, CUDNN_CROSS_CORRELATION,
                                             entry_ptr->conv_entry.data_type));

  if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
    ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";

    // Set Input
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
                                          entry_ptr->conv_entry.tensor_format, data_type, x_dim[0],
                                          x_dim[3], x_dim[1], x_dim[2]));

    // filter desc
    CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
                                          entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3],
                                          w_dim[1], w_dim[2]));

    CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(
        entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
        entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
        static_cast<int*>(out_shape) + 3, static_cast<int*>(out_shape) + 1,
        static_cast<int*>(out_shape) + 2));
  } else {
    // Set Input
    std::vector<int> tensor_stride(full_dims);
    GetCudnnStride(full_dims, x_dim, tensor_stride.data());

    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
                                          x_dim, tensor_stride.data()));
    // filter desc
    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
                                          entry_ptr->conv_entry.tensor_format, full_dims, w_dim));

    CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(
        entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
        entry_ptr->conv_entry.filter_desc, full_dims, static_cast<int*>(out_shape)));
  }
}

void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
              const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
              const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -201,24 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
                         conv_dtype);
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int format = args[0];
int dims = args[1];
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
void* out_shape = args[7];
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype,
conv_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int format = args[0];
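The test updates below stop probing for the now-removed packed function and instead check for one that survives this commit. A sketch of that capability-check pattern, taken from the usage in the diff (passing True as the second argument makes the lookup return None instead of raising when the function is missing):

    import tvm

    # Probe for a cuDNN packed function that still exists after this commit.
    have_cudnn = tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True) is not None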
14 changes: 2 additions & 12 deletions tests/python/contrib/test_cudnn.py
@@ -29,7 +29,7 @@


requires_cudnn = pytest.mark.skipif(
-    tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None,
+    tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True) is None,
    reason="CuDNN is not enabled",
)

@@ -307,15 +307,5 @@ def conv_output_shape_kwargs(request):
    return request.param


-@tvm.testing.requires_gpu
-@requires_cudnn
-def test_conv_output_shape(conv_output_shape_kwargs):
-    shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs)
-    shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs)
-    assert shape_from_cudnn == shape_from_python
-
-
if __name__ == "__main__":
-    # sys.exit(pytest.main(sys.argv))
-    test_conv2d()
-    test_conv3d()
+    sys.exit(pytest.main(sys.argv))
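
With the GPU cross-check removed, a pure-Python sanity test of `conv_output_shape` against a hand-computed case would still run without cuDNN. An illustrative sketch, not part of this commit:

    def test_conv_output_shape_reference():
        # Hand-computed NCHW case: (224 + 2*1 - 7) // 2 + 1 = 110.
        oshape = cudnn.conv_output_shape(
            tensor_format=0,
            pad=[1, 1],
            stride=[2, 2],
            dilation=[1, 1],
            x_shape=[1, 3, 224, 224],
            w_shape=[64, 3, 7, 7],
            data_dtype="float32",
            conv_dtype="float32",
            groups=1,
        )
        assert list(oshape) == [1, 64, 110, 110]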
2 changes: 1 addition & 1 deletion tests/python/relay/test_any.py
@@ -541,7 +541,7 @@ def verify_any_conv2d(
    kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)

    targets = None
-    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True):
+    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True):
        targets = [("cuda -libs=cudnn", tvm.cuda(0))]

    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)
