From ac13b407e21a83ca57240cad205c32a5d000f999 Mon Sep 17 00:00:00 2001 From: masa Date: Fri, 18 Dec 2020 19:49:25 +0900 Subject: [PATCH] adding thrust scan support --- src/runtime/contrib/thrust/thrust.cu | 82 ++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index dddbb043fddc..c114527bfc7d 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -245,5 +246,86 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") } }); +template +void thrust_scan(DLTensor* data, + DLTensor* output, + bool exclusive) { + thrust::device_ptr data_ptr(static_cast(data->data)); + thrust::device_ptr output_ptr(static_cast(output->data)); + const auto scan_size = data->shape[data->ndim - 1]; + + if (data->ndim == 1) { + if (exclusive) { + thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } else { + thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } + } else { + auto counting_iter = thrust::counting_iterator(0); + auto key_iter = thrust::make_transform_iterator(counting_iter, [scan_size] __device__(int i) { + return i / scan_size; + }); + int64_t size = 0; + for (int i = 0; i < data->ndim; ++i) size += data->shape[i]; + + if (exclusive) { + thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } else { + thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } + } +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.num_args, 3); + DLTensor* data = args[0]; + DLTensor* output = args[1]; + bool exclusive = args[2]; + + auto in_dtype = DLDataType2String(data->dtype); + auto out_dtype = DLDataType2String(output->dtype); + + if (in_dtype == "int32") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + // } else if (in_dtype == "int64") { + // if (out_dtype == "int32") { + // thrust_scan(keys_in, values_in, keys_out, values_out, + // for_scatter); + // } else if (out_dtype == "int64") { + // thrust_scan(keys_in, values_in, keys_out, values_out, + // for_scatter); + // } else if (out_dtype == "float32") { + // thrust_scan(keys_in, values_in, keys_out, values_out, + // for_scatter); + // } else { + // LOG(FATAL) << "Unsupported value dtype: " << out_dtype; + // } + // } else if (in_dtype == "float32") { + // if (out_dtype == "int32") { + // thrust_scan(keys_in, values_in, keys_out, values_out, + // for_scatter); + // } else if (out_dtype == "int64") { + // thrust_scan(keys_in, values_in, keys_out, values_out, + // for_scatter); + // } else if (out_dtype == "float32") { + // thrust_scan(keys_in, values_in, keys_out, values_out, + // for_scatter); + // } else { + // LOG(FATAL) << "Unsupported value dtype: " << out_dtype; + // } + } else { + LOG(FATAL) << "Unsupported input dtype: " << in_dtype; + } +}); + } // namespace contrib } // namespace tvm