Skip to content

Commit

Permalink
adding thrust scan support
Browse files Browse the repository at this point in the history
  • Loading branch information
masa authored and masahi committed Dec 24, 2020
1 parent 65634e8 commit ac13b40
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/scan.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
Expand Down Expand Up @@ -245,5 +246,86 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
}
});

template<typename InType, typename OutType>
void thrust_scan(DLTensor* data,
DLTensor* output,
bool exclusive) {
thrust::device_ptr<InType> data_ptr(static_cast<InType *>(data->data));
thrust::device_ptr<OutType> output_ptr(static_cast<OutType *>(output->data));
const auto scan_size = data->shape[data->ndim - 1];

if (data->ndim == 1) {
if (exclusive) {
thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
} else {
thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
}
} else {
auto counting_iter = thrust::counting_iterator<int>(0);
auto key_iter = thrust::make_transform_iterator(counting_iter, [scan_size] __device__(int i) {
return i / scan_size;
});
int64_t size = 0;
for (int i = 0; i < data->ndim; ++i) size += data->shape[i];

if (exclusive) {
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
} else {
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
}
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_EQ(args.num_args, 3);
DLTensor* data = args[0];
DLTensor* output = args[1];
bool exclusive = args[2];

auto in_dtype = DLDataType2String(data->dtype);
auto out_dtype = DLDataType2String(output->dtype);

if (in_dtype == "int32") {
if (out_dtype == "int32") {
thrust_scan<int, int>(data, output, exclusive);
} else if (out_dtype == "int64") {
thrust_scan<int, int64_t>(data, output, exclusive);
} else if (out_dtype == "float32") {
thrust_scan<int, float>(data, output, exclusive);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
// } else if (in_dtype == "int64") {
// if (out_dtype == "int32") {
// thrust_scan<int64_t, int>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "int64") {
// thrust_scan<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "float32") {
// thrust_scan<int64_t, float>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else {
// LOG(FATAL) << "Unsupported value dtype: " << out_dtype;
// }
// } else if (in_dtype == "float32") {
// if (out_dtype == "int32") {
// thrust_scan<float, int>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "int64") {
// thrust_scan<float, int64_t>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else if (out_dtype == "float32") {
// thrust_scan<float, float>(keys_in, values_in, keys_out, values_out,
// for_scatter);
// } else {
// LOG(FATAL) << "Unsupported value dtype: " << out_dtype;
// }
} else {
LOG(FATAL) << "Unsupported input dtype: " << in_dtype;
}
});

} // namespace contrib
} // namespace tvm

0 comments on commit ac13b40

Please sign in to comment.