Skip to content

Commit

Permalink
Merge branch 'branch-24.12' into improve/xdist-worksteal-cudf-pandas
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 authored Oct 2, 2024
2 parents 7fec305 + bac81cb commit 2673f04
Show file tree
Hide file tree
Showing 43 changed files with 1,338 additions and 521 deletions.
78 changes: 64 additions & 14 deletions cpp/include/cudf/detail/utilities/cuda_memcpy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cudf/utilities/export.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand All @@ -25,33 +26,82 @@ namespace detail {

enum class host_memory_kind : uint8_t { PINNED, PAGEABLE };

void cuda_memcpy_async_impl(
void* dst, void const* src, size_t size, host_memory_kind kind, rmm::cuda_stream_view stream);

/**
* @brief Asynchronously copies data between the host and device.
* @brief Asynchronously copies data from host to device memory.
*
* Implementation may use different strategies depending on the size and type of host data.
*
* @param dst Destination memory address
* @param src Source memory address
* @param size Number of bytes to copy
* @param kind Type of host memory
* @param dst Destination device memory
* @param src Source host memory
* @param stream CUDA stream used for the copy
*/
void cuda_memcpy_async(
void* dst, void const* src, size_t size, host_memory_kind kind, rmm::cuda_stream_view stream);
// Host-to-device async copy over typed spans. Sizes must match exactly; the
// copy strategy is selected from whether the source host span is
// device-accessible (treated as pinned) or plain pageable memory.
template <typename T>
void cuda_memcpy_async(device_span<T> dst, host_span<T const> src, rmm::cuda_stream_view stream)
{
  CUDF_EXPECTS(dst.size() == src.size(), "Mismatched sizes in cuda_memcpy_async");
  // Device-accessible host memory is handled as PINNED; everything else as PAGEABLE.
  auto const kind =
    src.is_device_accessible() ? host_memory_kind::PINNED : host_memory_kind::PAGEABLE;
  cuda_memcpy_async_impl(dst.data(), src.data(), src.size_bytes(), kind, stream);
}

/**
* @brief Synchronously copies data between the host and device.
* @brief Asynchronously copies data from device to host memory.
*
* Implementation may use different strategies depending on the size and type of host data.
*
* @param dst Destination memory address
* @param src Source memory address
* @param size Number of bytes to copy
* @param kind Type of host memory
* @param dst Destination host memory
* @param src Source device memory
* @param stream CUDA stream used for the copy
*/
void cuda_memcpy(
void* dst, void const* src, size_t size, host_memory_kind kind, rmm::cuda_stream_view stream);
// Device-to-host async copy over typed spans. Sizes must match exactly; the
// copy strategy is selected from whether the destination host span is
// device-accessible (treated as pinned) or plain pageable memory.
template <typename T>
void cuda_memcpy_async(host_span<T> dst, device_span<T const> src, rmm::cuda_stream_view stream)
{
  CUDF_EXPECTS(dst.size() == src.size(), "Mismatched sizes in cuda_memcpy_async");
  // Device-accessible host memory is handled as PINNED; everything else as PAGEABLE.
  auto const kind =
    dst.is_device_accessible() ? host_memory_kind::PINNED : host_memory_kind::PAGEABLE;
  cuda_memcpy_async_impl(dst.data(), src.data(), src.size_bytes(), kind, stream);
}

/**
* @brief Synchronously copies data from host to device memory.
*
* Implementation may use different strategies depending on the size and type of host data.
*
* @param dst Destination device memory
* @param src Source host memory
* @param stream CUDA stream used for the copy
*/
// Synchronous host-to-device copy: issues the async span-based copy and then
// blocks the caller until `stream` has finished it.
template <typename T>
void cuda_memcpy(device_span<T> dst, host_span<T const> src, rmm::cuda_stream_view stream)
{
// Size validation and PINNED/PAGEABLE selection happen in the async overload.
cuda_memcpy_async(dst, src, stream);
// Block until the enqueued copy completes so the data is usable on return.
stream.synchronize();
}

/**
* @brief Synchronously copies data from device to host memory.
*
* Implementation may use different strategies depending on the size and type of host data.
*
* @param dst Destination host memory
* @param src Source device memory
* @param stream CUDA stream used for the copy
*/
// Synchronous device-to-host copy: issues the async span-based copy and then
// blocks the caller until `stream` has finished it.
template <typename T>
void cuda_memcpy(host_span<T> dst, device_span<T const> src, rmm::cuda_stream_view stream)
{
// Size validation and PINNED/PAGEABLE selection happen in the async overload.
cuda_memcpy_async(dst, src, stream);
// Block until the enqueued copy completes so host data is valid on return.
stream.synchronize();
}

} // namespace detail
} // namespace CUDF_EXPORT cudf
16 changes: 3 additions & 13 deletions cpp/include/cudf/detail/utilities/vector_factories.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,7 @@ rmm::device_uvector<T> make_device_uvector_async(host_span<T const> source_data,
rmm::device_async_resource_ref mr)
{
rmm::device_uvector<T> ret(source_data.size(), stream, mr);
auto const is_pinned = source_data.is_device_accessible();
cuda_memcpy_async(ret.data(),
source_data.data(),
source_data.size() * sizeof(T),
is_pinned ? host_memory_kind::PINNED : host_memory_kind::PAGEABLE,
stream);
cuda_memcpy_async<T>(ret, source_data, stream);
return ret;
}

Expand Down Expand Up @@ -405,13 +400,8 @@ host_vector<T> make_empty_host_vector(size_t capacity, rmm::cuda_stream_view str
template <typename T>
host_vector<T> make_host_vector_async(device_span<T const> v, rmm::cuda_stream_view stream)
{
auto result = make_host_vector<T>(v.size(), stream);
auto const is_pinned = result.get_allocator().is_device_accessible();
cuda_memcpy_async(result.data(),
v.data(),
v.size() * sizeof(T),
is_pinned ? host_memory_kind::PINNED : host_memory_kind::PAGEABLE,
stream);
auto result = make_host_vector<T>(v.size(), stream);
cuda_memcpy_async<T>(result, v, stream);
return result;
}

Expand Down
5 changes: 3 additions & 2 deletions cpp/include/cudf/strings/char_types/char_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace strings {
*/

/**
* @brief Returns a boolean column identifying strings entries in which all
* @brief Returns a boolean column identifying string entries where all
* characters are of the type specified.
*
* The output row entry will be set to false if the corresponding string element
Expand Down Expand Up @@ -105,7 +105,8 @@ std::unique_ptr<column> all_characters_of_type(
* `types_to_remove` will be filtered.
* @param mr Device memory resource used to allocate the returned column's device memory
* @param stream CUDA stream used for device memory operations and kernel launches
* @return New column of boolean results for each string
* @return New strings column with the characters of specified types filtered out and replaced by
* the specified replacement string
*/
std::unique_ptr<column> filter_characters_of_type(
strings_column_view const& input,
Expand Down
13 changes: 3 additions & 10 deletions cpp/src/io/json/host_tree_algorithms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -634,11 +634,8 @@ std::pair<cudf::detail::host_vector<bool>, hashmap_of_device_columns> build_tree
is_mixed_type_column[this_col_id] == 1)
column_categories[this_col_id] = NC_STR;
}
cudf::detail::cuda_memcpy_async(d_column_tree.node_categories.begin(),
column_categories.data(),
column_categories.size() * sizeof(column_categories[0]),
cudf::detail::host_memory_kind::PAGEABLE,
stream);
cudf::detail::cuda_memcpy_async<NodeT>(
d_column_tree.node_categories, column_categories, stream);
}

// ignore all children of columns forced as string
Expand All @@ -653,11 +650,7 @@ std::pair<cudf::detail::host_vector<bool>, hashmap_of_device_columns> build_tree
forced_as_string_column[this_col_id])
column_categories[this_col_id] = NC_STR;
}
cudf::detail::cuda_memcpy_async(d_column_tree.node_categories.begin(),
column_categories.data(),
column_categories.size() * sizeof(column_categories[0]),
cudf::detail::host_memory_kind::PAGEABLE,
stream);
cudf::detail::cuda_memcpy_async<NodeT>(d_column_tree.node_categories, column_categories, stream);

// restore unique_col_ids order
std::sort(h_range_col_id_it, h_range_col_id_it + num_columns, [](auto const& a, auto const& b) {
Expand Down
14 changes: 4 additions & 10 deletions cpp/src/io/utilities/hostdevice_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,23 +125,17 @@ class hostdevice_vector {

void host_to_device_async(rmm::cuda_stream_view stream)
{
cuda_memcpy_async(device_ptr(), host_ptr(), size_bytes(), host_memory_kind::PINNED, stream);
cuda_memcpy_async<T>(d_data, h_data, stream);
}

void host_to_device_sync(rmm::cuda_stream_view stream)
{
cuda_memcpy(device_ptr(), host_ptr(), size_bytes(), host_memory_kind::PINNED, stream);
}
void host_to_device_sync(rmm::cuda_stream_view stream) { cuda_memcpy<T>(d_data, h_data, stream); }

void device_to_host_async(rmm::cuda_stream_view stream)
{
cuda_memcpy_async(host_ptr(), device_ptr(), size_bytes(), host_memory_kind::PINNED, stream);
cuda_memcpy_async<T>(h_data, d_data, stream);
}

void device_to_host_sync(rmm::cuda_stream_view stream)
{
cuda_memcpy(host_ptr(), device_ptr(), size_bytes(), host_memory_kind::PINNED, stream);
}
void device_to_host_sync(rmm::cuda_stream_view stream) { cuda_memcpy<T>(h_data, d_data, stream); }

/**
* @brief Converts a hostdevice_vector into a hostdevice_span.
Expand Down
11 changes: 2 additions & 9 deletions cpp/src/utilities/cuda_memcpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace cudf::detail {
namespace {

// Simple kernel to copy between device buffers
CUDF_KERNEL void copy_kernel(char const* src, char* dst, size_t n)
CUDF_KERNEL void copy_kernel(char const* __restrict__ src, char* __restrict__ dst, size_t n)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();
if (idx < n) { dst[idx] = src[idx]; }
Expand Down Expand Up @@ -61,7 +61,7 @@ void copy_pageable(void* dst, void const* src, std::size_t size, rmm::cuda_strea

}; // namespace

void cuda_memcpy_async(
void cuda_memcpy_async_impl(
void* dst, void const* src, size_t size, host_memory_kind kind, rmm::cuda_stream_view stream)
{
if (kind == host_memory_kind::PINNED) {
Expand All @@ -73,11 +73,4 @@ void cuda_memcpy_async(
}
}

void cuda_memcpy(
void* dst, void const* src, size_t size, host_memory_kind kind, rmm::cuda_stream_view stream)
{
cuda_memcpy_async(dst, src, size, kind, stream);
stream.synchronize();
}

} // namespace cudf::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
=============
find_multiple
=============

.. automodule:: pylibcudf.strings.find_multiple
:members:
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ strings
contains
extract
find
find_multiple
findall
regex_flags
regex_program
repeat
replace
slice
split
strip
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
=====
split
=====

.. automodule:: pylibcudf.strings.split
:members:
Loading

0 comments on commit 2673f04

Please sign in to comment.