Skip to content

Commit

Permalink
Use cudf::thread_index_type in replace.cu. (#13905)
Browse files: browse the repository at this point in the history.
This PR uses `cudf::thread_index_type` in `replace.cu` to avoid the risk of integer overflow when computing global thread indices and grid strides in the replace kernels.

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Vukasin Milovanovic (https://github.com/vuule)
  - Nghia Truong (https://github.com/ttnghia)

URL: #13905
  • Loading branch information
bdice authored Sep 2, 2023
1 parent bbbb143 commit 0c829cc
Showing 1 changed file with 53 additions and 38 deletions.
91 changes: 53 additions & 38 deletions cpp/src/replace/replace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,40 +127,42 @@ __global__ void replace_strings_first_pass(cudf::column_device_view input,
cudf::size_type* __restrict__ output_valid_count)
{
cudf::size_type nrows = input.size();
cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x;
auto tid = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();
uint32_t active_mask = 0xffff'ffffu;
active_mask = __ballot_sync(active_mask, i < nrows);
active_mask = __ballot_sync(active_mask, tid < nrows);
auto const lane_id{threadIdx.x % cudf::detail::warp_size};
uint32_t valid_sum{0};

while (i < nrows) {
while (tid < nrows) {
auto const idx = static_cast<cudf::size_type>(tid);
bool input_is_valid = true;

if (input_has_nulls) input_is_valid = input.is_valid_nocheck(i);
if (input_has_nulls) input_is_valid = input.is_valid_nocheck(idx);
bool output_is_valid = input_is_valid;

if (input_is_valid) {
int result = get_new_string_value(i, input, values_to_replace, replacement);
cudf::string_view output = (result == -1) ? input.element<cudf::string_view>(i)
int result = get_new_string_value(idx, input, values_to_replace, replacement);
cudf::string_view output = (result == -1) ? input.element<cudf::string_view>(idx)
: replacement.element<cudf::string_view>(result);
offsets.data<cudf::size_type>()[i] = output.size_bytes();
indices.data<cudf::size_type>()[i] = result;
offsets.data<cudf::size_type>()[idx] = output.size_bytes();
indices.data<cudf::size_type>()[idx] = result;
if (replacement_has_nulls && result != -1) {
output_is_valid = replacement.is_valid_nocheck(result);
}
} else {
offsets.data<cudf::size_type>()[i] = 0;
indices.data<cudf::size_type>()[i] = -1;
offsets.data<cudf::size_type>()[idx] = 0;
indices.data<cudf::size_type>()[idx] = -1;
}

uint32_t bitmask = __ballot_sync(active_mask, output_is_valid);
if (0 == lane_id) {
output_valid[cudf::word_index(i)] = bitmask;
output_valid[cudf::word_index(idx)] = bitmask;
valid_sum += __popc(bitmask);
}

i += blockDim.x * gridDim.x;
active_mask = __ballot_sync(active_mask, i < nrows);
tid += stride;
active_mask = __ballot_sync(active_mask, tid < nrows);
}

// Compute total valid count for this block and add it to global count
Expand Down Expand Up @@ -189,27 +191,32 @@ __global__ void replace_strings_second_pass(cudf::column_device_view input,
cudf::mutable_column_device_view indices)
{
cudf::size_type nrows = input.size();
cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x;
auto tid = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();

while (i < nrows) {
bool output_is_valid = true;
bool input_is_valid = true;
cudf::size_type idx = indices.element<cudf::size_type>(i);
while (tid < nrows) {
auto const idx = static_cast<cudf::size_type>(tid);
auto const replace_idx = indices.element<cudf::size_type>(idx);
bool output_is_valid = true;
bool input_is_valid = true;

if (input_has_nulls) {
input_is_valid = input.is_valid_nocheck(i);
input_is_valid = input.is_valid_nocheck(idx);
output_is_valid = input_is_valid;
}
if (replacement_has_nulls && idx != -1) { output_is_valid = replacement.is_valid_nocheck(idx); }
if (replacement_has_nulls && replace_idx != -1) {
output_is_valid = replacement.is_valid_nocheck(replace_idx);
}
if (output_is_valid) {
cudf::string_view output = (idx == -1) ? input.element<cudf::string_view>(i)
: replacement.element<cudf::string_view>(idx);
std::memcpy(strings.data<char>() + offsets.data<cudf::size_type>()[i],
cudf::string_view output = (replace_idx == -1)
? input.element<cudf::string_view>(idx)
: replacement.element<cudf::string_view>(replace_idx);
std::memcpy(strings.data<char>() + offsets.data<cudf::size_type>()[idx],
output.data(),
output.size_bytes());
}

i += blockDim.x * gridDim.x;
tid += stride;
}
}

Expand Down Expand Up @@ -247,23 +254,25 @@ __global__ void replace_kernel(cudf::column_device_view input,
{
T* __restrict__ output_data = output.data<T>();

cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x;
auto tid = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();

uint32_t active_mask = 0xffff'ffffu;
active_mask = __ballot_sync(active_mask, i < nrows);
active_mask = __ballot_sync(active_mask, tid < nrows);
auto const lane_id{threadIdx.x % cudf::detail::warp_size};
uint32_t valid_sum{0};

while (i < nrows) {
while (tid < nrows) {
auto const idx = static_cast<cudf::size_type>(tid);
bool output_is_valid{true};
bool input_is_valid{true};
if (input_has_nulls) {
input_is_valid = input.is_valid_nocheck(i);
input_is_valid = input.is_valid_nocheck(idx);
output_is_valid = input_is_valid;
}
if (input_is_valid)
thrust::tie(output_data[i], output_is_valid) = get_new_value<T, replacement_has_nulls>(
i,
thrust::tie(output_data[idx], output_is_valid) = get_new_value<T, replacement_has_nulls>(
idx,
input.data<T>(),
values_to_replace.data<T>(),
values_to_replace.data<T>() + values_to_replace.size(),
Expand All @@ -274,13 +283,13 @@ __global__ void replace_kernel(cudf::column_device_view input,
if (input_has_nulls or replacement_has_nulls) {
uint32_t bitmask = __ballot_sync(active_mask, output_is_valid);
if (0 == lane_id) {
output.set_mask_word(cudf::word_index(i), bitmask);
output.set_mask_word(cudf::word_index(idx), bitmask);
valid_sum += __popc(bitmask);
}
}

i += blockDim.x * gridDim.x;
active_mask = __ballot_sync(active_mask, i < nrows);
tid += stride;
active_mask = __ballot_sync(active_mask, tid < nrows);
}
if (input_has_nulls or replacement_has_nulls) {
// Compute total valid count for this block and add it to global count
Expand Down Expand Up @@ -384,10 +393,16 @@ std::unique_ptr<cudf::column> replace_kernel_forwarder::operator()<cudf::string_
}

// Create new offsets column to use in kernel
std::unique_ptr<cudf::column> sizes = cudf::make_numeric_column(
cudf::data_type(cudf::type_id::INT32), input_col.size(), cudf::mask_state::UNALLOCATED, stream);
std::unique_ptr<cudf::column> indices = cudf::make_numeric_column(
cudf::data_type(cudf::type_id::INT32), input_col.size(), cudf::mask_state::UNALLOCATED, stream);
std::unique_ptr<cudf::column> sizes =
cudf::make_numeric_column(cudf::data_type{cudf::type_to_id<cudf::size_type>()},
input_col.size(),
cudf::mask_state::UNALLOCATED,
stream);
std::unique_ptr<cudf::column> indices =
cudf::make_numeric_column(cudf::data_type{cudf::type_to_id<cudf::size_type>()},
input_col.size(),
cudf::mask_state::UNALLOCATED,
stream);

auto sizes_view = sizes->mutable_view();
auto indices_view = indices->mutable_view();
Expand All @@ -413,7 +428,7 @@ std::unique_ptr<cudf::column> replace_kernel_forwarder::operator()<cudf::string_
valid_count);

auto [offsets, bytes] = cudf::detail::make_offsets_child_column(
sizes_view.begin<int32_t>(), sizes_view.end<int32_t>(), stream, mr);
sizes_view.begin<cudf::size_type>(), sizes_view.end<cudf::size_type>(), stream, mr);
auto offsets_view = offsets->mutable_view();
auto device_offsets = cudf::mutable_column_device_view::create(offsets_view, stream);

Expand Down

0 comments on commit 0c829cc

Please sign in to comment.