diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index 97704986792d..799cfa7ae992 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -35,94 +35,11 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
                         const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
                         DLTensor* y, const std::string& conv_dtype) {
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-  // Set Mode
-  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
-  // Set Format
-  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
-  // Set Algo
-  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
-  // Set Device
-  entry_ptr->conv_entry.device = x->device;
-  // Set Data Type
-  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
-  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
-  // Dims includes N and C
-  int full_dims = dims + 2;
-
-  std::vector<int> dim(full_dims);
-  std::vector<int> tensor_stride(full_dims);
-
-  // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
-  // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
-
-  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
-  if (dims == 2) {
-    // Set Desc
-    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
-        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
-        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
-    int ni, ci, hi, wi;
-    if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
-      ni = 0;
-      ci = 3;
-      hi = 1;
-      wi = 2;
-    } else {
-      ni = 0;
-      ci = 1;
-      hi = 2;
-      wi = 3;
-    }
-
-    // Set Filter
-    CUDNN_CALL(cudnnSetFilter4dDescriptor(
-        entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
-        static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
-        static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
-    // Set Input
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
-        static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
-        static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
-    // Set Output
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
-        static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
-        static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
-  } else {
-    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
-                                               dilation, entry_ptr->conv_entry.mode,
-                                               entry_ptr->conv_entry.data_type));
-
-    // Set Filter
-    for (int i = 0; i < full_dims; i++) {
-      dim[i] = static_cast<int>(w->shape[i]);
-    }
-    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
-                                          entry_ptr->conv_entry.tensor_format, full_dims,
-                                          dim.data()));
-    // Set Input
-    for (int i = 0; i < full_dims; i++) {
-      dim[i] = static_cast<int>(x->shape[i]);
-    }
-    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
-                                          dim.data(), tensor_stride.data()));
-    // Set Output
-    for (int i = 0; i < full_dims; i++) {
-      dim[i] = static_cast<int>(y->shape[i]);
-    }
-    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
-                                          dim.data(), tensor_stride.data()));
-  }
-
-  if (cudnnGetVersion() > 7000) {
-    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
-  }
+  SetConvDescriptors(entry_ptr, mode, format, algo, dims, groups, pad, stride, dilation, x, w, y,
+                     conv_dtype);
 
-  // Set workspace
-  size_t workspace_size = 0;
+  // Set workspace
+  size_t workspace_size = 0;
   CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
       entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
       entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index 297cd9e7a361..0738371da805 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -23,6 +23,7 @@
 #include "cudnn_utils.h"
 
 #include <dmlc/thread_local.h>
+#include <string>
 #include <tvm/runtime/registry.h>
 
 namespace tvm {
@@ -160,6 +161,98 @@ void ConvEntry::CleanWorkspace() {
   workspace_size = 0;
 }
 
+void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int algo, int dims,
+                        int groups, const int pad[], const int stride[], const int dilation[],
+                        DLTensor* x, DLTensor* w, DLTensor* y, const std::string& conv_dtype) {
+  // Set Mode
+  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
+  // Set Format
+  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
+  // Set Algo
+  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+  // Set Device
+  entry_ptr->conv_entry.device = x->device;
+  // Set Data Type
+  entry_ptr->conv_entry.data_type =
+      CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype));
+
+  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+  // Dims includes N and C
+  int full_dims = dims + 2;
+
+  std::vector<int> dim(full_dims);
+  std::vector<int> tensor_stride(full_dims);
+
+  // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
+  // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
+
+  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
+  if (dims == 2) {
+    // Set Desc
+    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
+        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
+        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
+    int ni, ci, hi, wi;
+    if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
+      ni = 0;
+      ci = 3;
+      hi = 1;
+      wi = 2;
+    } else {
+      ni = 0;
+      ci = 1;
+      hi = 2;
+      wi = 3;
+    }
+
+    // Set Filter
+    CUDNN_CALL(cudnnSetFilter4dDescriptor(
+        entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
+        static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
+        static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
+    // Set Input
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
+        static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
+        static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
+    // Set Output
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
+        static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
+        static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
+  } else {
+    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+                                               dilation, entry_ptr->conv_entry.mode,
+                                               entry_ptr->conv_entry.data_type));
+
+    // Set Filter
+    for (int i = 0; i < full_dims; i++) {
+      dim[i] = static_cast<int>(w->shape[i]);
+    }
+    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
+                                          entry_ptr->conv_entry.tensor_format, full_dims,
+                                          dim.data()));
+    // Set Input
+    for (int i = 0; i < full_dims; i++) {
+      dim[i] = static_cast<int>(x->shape[i]);
+    }
+    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
+                                          dim.data(), tensor_stride.data()));
+    // Set Output
+    for (int i = 0; i < full_dims; i++) {
+      dim[i] = static_cast<int>(y->shape[i]);
+    }
+    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
+                                          dim.data(), tensor_stride.data()));
+  }
+
+  if (cudnnGetVersion() > 7000) {
+    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
+  }
+}
+
 // SoftmaxEntry
 SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); }
 
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 01b92d61e66e..61ac4e51c404 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -103,6 +103,10 @@ struct CuDNNThreadEntry {
   static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
 };  // CuDNNThreadEntry
 
+void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int algo, int dims,
+                        int groups, const int pad[], const int stride[], const int dilation[],
+                        DLTensor* x, DLTensor* w, DLTensor* y, const std::string& conv_dtype);
+
 }  // namespace contrib
 }  // namespace tvm
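For reference, a condensed sketch of what ConvolutionForward looks like after this refactor. It is not part of the diff: SetConvDescriptors, CUDNN_CALL, CuDNNThreadEntry and the workspace query restate the hunks above, while the ConvolutionForwardSketch name, the UpdateWorkspace call, the GetConst alpha/beta constants and the final cudnnConvolutionForward launch are assumptions about the unchanged remainder of conv_forward.cc.

// Sketch only: the post-refactor forward path, assuming the surrounding TVM cuDNN headers.
void ConvolutionForwardSketch(int mode, int format, int algo, int dims, int groups,
                              const int pad[], const int stride[], const int dilation[],
                              DLTensor* x, DLTensor* w, DLTensor* y,
                              const std::string& conv_dtype) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

  // 1. All descriptor, mode, format and dtype plumbing now lives in the shared helper,
  //    so other cuDNN entry points can reuse it instead of duplicating ~90 lines of setup.
  SetConvDescriptors(entry_ptr, mode, format, algo, dims, groups, pad, stride, dilation, x, w, y,
                     conv_dtype);

  // 2. Ask cuDNN how much scratch memory the chosen forward algorithm needs
  //    (this call is visible in the conv_forward.cc hunk above).
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
      entry_ptr->conv_entry.fwd_algo, &workspace_size));

  // 3. Assumed from the unchanged rest of the file: grow the cached workspace and
  //    launch the convolution itself; this part is untouched by the diff.
  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
  CUDNN_CALL(cudnnConvolutionForward(
      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
      entry_ptr->conv_entry.workspace, workspace_size,
      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
      entry_ptr->conv_entry.output_desc, y->data));
}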