From 0c829cc0b868c288c3591771d555617d4d978ce3 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Fri, 1 Sep 2023 21:38:11 -0500 Subject: [PATCH] Use cudf::thread_index_type in replace.cu. (#13905) This PR uses `cudf::thread_index_type` in `replace.cu` to avoid risk of overflow. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Yunsong Wang (https://github.com/PointKernel) - Vukasin Milovanovic (https://github.com/vuule) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/13905 --- cpp/src/replace/replace.cu | 91 ++++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/cpp/src/replace/replace.cu b/cpp/src/replace/replace.cu index 07eefdc27c6..9341929de44 100644 --- a/cpp/src/replace/replace.cu +++ b/cpp/src/replace/replace.cu @@ -127,40 +127,42 @@ __global__ void replace_strings_first_pass(cudf::column_device_view input, cudf::size_type* __restrict__ output_valid_count) { cudf::size_type nrows = input.size(); - cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = cudf::detail::grid_1d::global_thread_id(); + auto const stride = cudf::detail::grid_1d::grid_stride(); uint32_t active_mask = 0xffff'ffffu; - active_mask = __ballot_sync(active_mask, i < nrows); + active_mask = __ballot_sync(active_mask, tid < nrows); auto const lane_id{threadIdx.x % cudf::detail::warp_size}; uint32_t valid_sum{0}; - while (i < nrows) { + while (tid < nrows) { + auto const idx = static_cast(tid); bool input_is_valid = true; - if (input_has_nulls) input_is_valid = input.is_valid_nocheck(i); + if (input_has_nulls) input_is_valid = input.is_valid_nocheck(idx); bool output_is_valid = input_is_valid; if (input_is_valid) { - int result = get_new_string_value(i, input, values_to_replace, replacement); - cudf::string_view output = (result == -1) ? input.element(i) + int result = get_new_string_value(idx, input, values_to_replace, replacement); + cudf::string_view output = (result == -1) ? input.element(idx) : replacement.element(result); - offsets.data()[i] = output.size_bytes(); - indices.data()[i] = result; + offsets.data()[idx] = output.size_bytes(); + indices.data()[idx] = result; if (replacement_has_nulls && result != -1) { output_is_valid = replacement.is_valid_nocheck(result); } } else { - offsets.data()[i] = 0; - indices.data()[i] = -1; + offsets.data()[idx] = 0; + indices.data()[idx] = -1; } uint32_t bitmask = __ballot_sync(active_mask, output_is_valid); if (0 == lane_id) { - output_valid[cudf::word_index(i)] = bitmask; + output_valid[cudf::word_index(idx)] = bitmask; valid_sum += __popc(bitmask); } - i += blockDim.x * gridDim.x; - active_mask = __ballot_sync(active_mask, i < nrows); + tid += stride; + active_mask = __ballot_sync(active_mask, tid < nrows); } // Compute total valid count for this block and add it to global count @@ -189,27 +191,32 @@ __global__ void replace_strings_second_pass(cudf::column_device_view input, cudf::mutable_column_device_view indices) { cudf::size_type nrows = input.size(); - cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = cudf::detail::grid_1d::global_thread_id(); + auto const stride = cudf::detail::grid_1d::grid_stride(); - while (i < nrows) { - bool output_is_valid = true; - bool input_is_valid = true; - cudf::size_type idx = indices.element(i); + while (tid < nrows) { + auto const idx = static_cast(tid); + auto const replace_idx = indices.element(idx); + bool output_is_valid = true; + bool input_is_valid = true; if (input_has_nulls) { - input_is_valid = input.is_valid_nocheck(i); + input_is_valid = input.is_valid_nocheck(idx); output_is_valid = input_is_valid; } - if (replacement_has_nulls && idx != -1) { output_is_valid = replacement.is_valid_nocheck(idx); } + if (replacement_has_nulls && replace_idx != -1) { + output_is_valid = replacement.is_valid_nocheck(replace_idx); + } if (output_is_valid) { - cudf::string_view output = (idx == -1) ? input.element(i) - : replacement.element(idx); - std::memcpy(strings.data() + offsets.data()[i], + cudf::string_view output = (replace_idx == -1) + ? input.element(idx) + : replacement.element(replace_idx); + std::memcpy(strings.data() + offsets.data()[idx], output.data(), output.size_bytes()); } - i += blockDim.x * gridDim.x; + tid += stride; } } @@ -247,23 +254,25 @@ __global__ void replace_kernel(cudf::column_device_view input, { T* __restrict__ output_data = output.data(); - cudf::size_type i = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = cudf::detail::grid_1d::global_thread_id(); + auto const stride = cudf::detail::grid_1d::grid_stride(); uint32_t active_mask = 0xffff'ffffu; - active_mask = __ballot_sync(active_mask, i < nrows); + active_mask = __ballot_sync(active_mask, tid < nrows); auto const lane_id{threadIdx.x % cudf::detail::warp_size}; uint32_t valid_sum{0}; - while (i < nrows) { + while (tid < nrows) { + auto const idx = static_cast(tid); bool output_is_valid{true}; bool input_is_valid{true}; if (input_has_nulls) { - input_is_valid = input.is_valid_nocheck(i); + input_is_valid = input.is_valid_nocheck(idx); output_is_valid = input_is_valid; } if (input_is_valid) - thrust::tie(output_data[i], output_is_valid) = get_new_value( - i, + thrust::tie(output_data[idx], output_is_valid) = get_new_value( + idx, input.data(), values_to_replace.data(), values_to_replace.data() + values_to_replace.size(), @@ -274,13 +283,13 @@ __global__ void replace_kernel(cudf::column_device_view input, if (input_has_nulls or replacement_has_nulls) { uint32_t bitmask = __ballot_sync(active_mask, output_is_valid); if (0 == lane_id) { - output.set_mask_word(cudf::word_index(i), bitmask); + output.set_mask_word(cudf::word_index(idx), bitmask); valid_sum += __popc(bitmask); } } - i += blockDim.x * gridDim.x; - active_mask = __ballot_sync(active_mask, i < nrows); + tid += stride; + active_mask = __ballot_sync(active_mask, tid < nrows); } if (input_has_nulls or replacement_has_nulls) { // Compute total valid count for this block and add it to global count @@ -384,10 +393,16 @@ std::unique_ptr replace_kernel_forwarder::operator() sizes = cudf::make_numeric_column( - cudf::data_type(cudf::type_id::INT32), input_col.size(), cudf::mask_state::UNALLOCATED, stream); - std::unique_ptr indices = cudf::make_numeric_column( - cudf::data_type(cudf::type_id::INT32), input_col.size(), cudf::mask_state::UNALLOCATED, stream); + std::unique_ptr sizes = + cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, + input_col.size(), + cudf::mask_state::UNALLOCATED, + stream); + std::unique_ptr indices = + cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, + input_col.size(), + cudf::mask_state::UNALLOCATED, + stream); auto sizes_view = sizes->mutable_view(); auto indices_view = indices->mutable_view(); @@ -413,7 +428,7 @@ std::unique_ptr replace_kernel_forwarder::operator()(), sizes_view.end(), stream, mr); + sizes_view.begin(), sizes_view.end(), stream, mr); auto offsets_view = offsets->mutable_view(); auto device_offsets = cudf::mutable_column_device_view::create(offsets_view, stream);