Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cudf::thread_index_type in replace.cu. #13905

Merged
merged 6 commits into from
Sep 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 53 additions & 38 deletions cpp/src/replace/replace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,40 +127,42 @@ __global__ void replace_strings_first_pass(cudf::column_device_view input,
cudf::size_type* __restrict__ output_valid_count)
{
cudf::size_type nrows = input.size();
cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x;
auto tid = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();
uint32_t active_mask = 0xffff'ffffu;
active_mask = __ballot_sync(active_mask, i < nrows);
active_mask = __ballot_sync(active_mask, tid < nrows);
auto const lane_id{threadIdx.x % cudf::detail::warp_size};
uint32_t valid_sum{0};

while (i < nrows) {
while (tid < nrows) {
auto const idx = static_cast<cudf::size_type>(tid);
bool input_is_valid = true;

if (input_has_nulls) input_is_valid = input.is_valid_nocheck(i);
if (input_has_nulls) input_is_valid = input.is_valid_nocheck(idx);
bool output_is_valid = input_is_valid;

if (input_is_valid) {
int result = get_new_string_value(i, input, values_to_replace, replacement);
cudf::string_view output = (result == -1) ? input.element<cudf::string_view>(i)
int result = get_new_string_value(idx, input, values_to_replace, replacement);
cudf::string_view output = (result == -1) ? input.element<cudf::string_view>(idx)
: replacement.element<cudf::string_view>(result);
offsets.data<cudf::size_type>()[i] = output.size_bytes();
indices.data<cudf::size_type>()[i] = result;
offsets.data<cudf::size_type>()[idx] = output.size_bytes();
indices.data<cudf::size_type>()[idx] = result;
if (replacement_has_nulls && result != -1) {
output_is_valid = replacement.is_valid_nocheck(result);
}
} else {
offsets.data<cudf::size_type>()[i] = 0;
indices.data<cudf::size_type>()[i] = -1;
offsets.data<cudf::size_type>()[idx] = 0;
indices.data<cudf::size_type>()[idx] = -1;
}

uint32_t bitmask = __ballot_sync(active_mask, output_is_valid);
if (0 == lane_id) {
output_valid[cudf::word_index(i)] = bitmask;
output_valid[cudf::word_index(idx)] = bitmask;
valid_sum += __popc(bitmask);
}

i += blockDim.x * gridDim.x;
active_mask = __ballot_sync(active_mask, i < nrows);
tid += stride;
active_mask = __ballot_sync(active_mask, tid < nrows);
}

// Compute total valid count for this block and add it to global count
Expand Down Expand Up @@ -189,27 +191,32 @@ __global__ void replace_strings_second_pass(cudf::column_device_view input,
cudf::mutable_column_device_view indices)
{
cudf::size_type nrows = input.size();
cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x;
auto tid = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();

while (i < nrows) {
bool output_is_valid = true;
bool input_is_valid = true;
cudf::size_type idx = indices.element<cudf::size_type>(i);
while (tid < nrows) {
auto const idx = static_cast<cudf::size_type>(tid);
auto const replace_idx = indices.element<cudf::size_type>(idx);
bool output_is_valid = true;
bool input_is_valid = true;

if (input_has_nulls) {
input_is_valid = input.is_valid_nocheck(i);
input_is_valid = input.is_valid_nocheck(idx);
output_is_valid = input_is_valid;
}
if (replacement_has_nulls && idx != -1) { output_is_valid = replacement.is_valid_nocheck(idx); }
if (replacement_has_nulls && replace_idx != -1) {
output_is_valid = replacement.is_valid_nocheck(replace_idx);
}
if (output_is_valid) {
cudf::string_view output = (idx == -1) ? input.element<cudf::string_view>(i)
: replacement.element<cudf::string_view>(idx);
std::memcpy(strings.data<char>() + offsets.data<cudf::size_type>()[i],
cudf::string_view output = (replace_idx == -1)
? input.element<cudf::string_view>(idx)
: replacement.element<cudf::string_view>(replace_idx);
std::memcpy(strings.data<char>() + offsets.data<cudf::size_type>()[idx],
output.data(),
output.size_bytes());
}

i += blockDim.x * gridDim.x;
tid += stride;
}
}

Expand Down Expand Up @@ -247,23 +254,25 @@ __global__ void replace_kernel(cudf::column_device_view input,
{
T* __restrict__ output_data = output.data<T>();

cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x;
auto tid = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();

uint32_t active_mask = 0xffff'ffffu;
active_mask = __ballot_sync(active_mask, i < nrows);
active_mask = __ballot_sync(active_mask, tid < nrows);
auto const lane_id{threadIdx.x % cudf::detail::warp_size};
uint32_t valid_sum{0};

while (i < nrows) {
while (tid < nrows) {
auto const idx = static_cast<cudf::size_type>(tid);
bool output_is_valid{true};
bool input_is_valid{true};
if (input_has_nulls) {
input_is_valid = input.is_valid_nocheck(i);
input_is_valid = input.is_valid_nocheck(idx);
output_is_valid = input_is_valid;
}
if (input_is_valid)
thrust::tie(output_data[i], output_is_valid) = get_new_value<T, replacement_has_nulls>(
i,
thrust::tie(output_data[idx], output_is_valid) = get_new_value<T, replacement_has_nulls>(
idx,
input.data<T>(),
values_to_replace.data<T>(),
values_to_replace.data<T>() + values_to_replace.size(),
Expand All @@ -274,13 +283,13 @@ __global__ void replace_kernel(cudf::column_device_view input,
if (input_has_nulls or replacement_has_nulls) {
uint32_t bitmask = __ballot_sync(active_mask, output_is_valid);
if (0 == lane_id) {
output.set_mask_word(cudf::word_index(i), bitmask);
output.set_mask_word(cudf::word_index(idx), bitmask);
valid_sum += __popc(bitmask);
}
}

i += blockDim.x * gridDim.x;
active_mask = __ballot_sync(active_mask, i < nrows);
tid += stride;
active_mask = __ballot_sync(active_mask, tid < nrows);
}
if (input_has_nulls or replacement_has_nulls) {
// Compute total valid count for this block and add it to global count
Expand Down Expand Up @@ -384,10 +393,16 @@ std::unique_ptr<cudf::column> replace_kernel_forwarder::operator()<cudf::string_
}

// Create new offsets column to use in kernel
std::unique_ptr<cudf::column> sizes = cudf::make_numeric_column(
cudf::data_type(cudf::type_id::INT32), input_col.size(), cudf::mask_state::UNALLOCATED, stream);
std::unique_ptr<cudf::column> indices = cudf::make_numeric_column(
cudf::data_type(cudf::type_id::INT32), input_col.size(), cudf::mask_state::UNALLOCATED, stream);
std::unique_ptr<cudf::column> sizes =
cudf::make_numeric_column(cudf::data_type{cudf::type_to_id<cudf::size_type>()},
input_col.size(),
cudf::mask_state::UNALLOCATED,
stream);
std::unique_ptr<cudf::column> indices =
cudf::make_numeric_column(cudf::data_type{cudf::type_to_id<cudf::size_type>()},
input_col.size(),
cudf::mask_state::UNALLOCATED,
stream);

auto sizes_view = sizes->mutable_view();
auto indices_view = indices->mutable_view();
Expand All @@ -413,7 +428,7 @@ std::unique_ptr<cudf::column> replace_kernel_forwarder::operator()<cudf::string_
valid_count);

auto [offsets, bytes] = cudf::detail::make_offsets_child_column(
sizes_view.begin<int32_t>(), sizes_view.end<int32_t>(), stream, mr);
sizes_view.begin<cudf::size_type>(), sizes_view.end<cudf::size_type>(), stream, mr);
auto offsets_view = offsets->mutable_view();
auto device_offsets = cudf::mutable_column_device_view::create(offsets_view, stream);

Expand Down