-
Notifications
You must be signed in to change notification settings - Fork 887
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add string.translate APIs to pylibcudf (#16934)
Contributes to #15162 Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: #16934
- Loading branch information
Showing
10 changed files
with
232 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
from pylibcudf.column cimport Column | ||
from pylibcudf.libcudf.strings.translate cimport filter_type | ||
from pylibcudf.scalar cimport Scalar | ||
|
||
|
||
cpdef Column translate(Column input, dict chars_table) | ||
|
||
cpdef Column filter_characters( | ||
Column input, | ||
dict characters_to_filter, | ||
filter_type keep_characters, | ||
Scalar replacement | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
from libcpp.memory cimport unique_ptr | ||
from libcpp.pair cimport pair | ||
from libcpp.utility cimport move | ||
from libcpp.vector cimport vector | ||
from pylibcudf.column cimport Column | ||
from pylibcudf.libcudf.column.column cimport column | ||
from pylibcudf.libcudf.scalar.scalar cimport string_scalar | ||
from pylibcudf.libcudf.strings cimport translate as cpp_translate | ||
from pylibcudf.libcudf.types cimport char_utf8 | ||
from pylibcudf.scalar cimport Scalar | ||
|
||
from cython.operator import dereference | ||
from pylibcudf.libcudf.strings.translate import \ | ||
filter_type as FilterType # no-cython-lint | ||
|
||
|
||
cdef vector[pair[char_utf8, char_utf8]] _table_to_c_table(dict table): | ||
""" | ||
Convert str.maketrans table to cudf compatible table. | ||
""" | ||
cdef int table_size = len(table) | ||
cdef vector[pair[char_utf8, char_utf8]] c_table | ||
|
||
c_table.reserve(table_size) | ||
for key, value in table.items(): | ||
if isinstance(value, int): | ||
value = chr(value) | ||
if isinstance(value, str): | ||
value = int.from_bytes(value.encode(), byteorder='big') | ||
if isinstance(key, int): | ||
key = chr(key) | ||
if isinstance(key, str): | ||
key = int.from_bytes(key.encode(), byteorder='big') | ||
c_table.push_back((key, value)) | ||
|
||
return c_table | ||
|
||
|
||
cpdef Column translate(Column input, dict chars_table): | ||
""" | ||
Translates individual characters within each string. | ||
For details, see :cpp:func:`cudf::strings::translate`. | ||
Parameters | ||
---------- | ||
input : Column | ||
Strings instance for this operation | ||
chars_table : dict | ||
Table of UTF-8 character mappings | ||
Returns | ||
------- | ||
Column | ||
New column with padded strings. | ||
""" | ||
cdef unique_ptr[column] c_result | ||
cdef vector[pair[char_utf8, char_utf8]] c_chars_table = _table_to_c_table( | ||
chars_table | ||
) | ||
|
||
with nogil: | ||
c_result = move( | ||
cpp_translate.translate( | ||
input.view(), | ||
c_chars_table | ||
) | ||
) | ||
return Column.from_libcudf(move(c_result)) | ||
|
||
|
||
cpdef Column filter_characters( | ||
Column input, | ||
dict characters_to_filter, | ||
filter_type keep_characters, | ||
Scalar replacement | ||
): | ||
""" | ||
Removes ranges of characters from each string in a strings column. | ||
For details, see :cpp:func:`cudf::strings::filter_characters`. | ||
Parameters | ||
---------- | ||
input : Column | ||
Strings instance for this operation | ||
characters_to_filter : dict | ||
Table of character ranges to filter on | ||
keep_characters : FilterType | ||
If true, the `characters_to_filter` are retained | ||
and all other characters are removed. | ||
replacement : Scalar | ||
Replacement string for each character removed. | ||
Returns | ||
------- | ||
Column | ||
New column with filtered strings. | ||
""" | ||
cdef unique_ptr[column] c_result | ||
cdef vector[pair[char_utf8, char_utf8]] c_characters_to_filter = _table_to_c_table( | ||
characters_to_filter | ||
) | ||
cdef const string_scalar* c_replacement = <const string_scalar*>( | ||
replacement.c_obj.get() | ||
) | ||
|
||
with nogil: | ||
c_result = move( | ||
cpp_translate.filter_characters( | ||
input.view(), | ||
c_characters_to_filter, | ||
keep_characters, | ||
dereference(c_replacement), | ||
) | ||
) | ||
return Column.from_libcudf(move(c_result)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
import pyarrow as pa | ||
import pylibcudf as plc | ||
import pytest | ||
from utils import assert_column_eq | ||
|
||
|
||
@pytest.fixture | ||
def data_col(): | ||
pa_data_col = pa.array( | ||
["aa", "bbb", "cccc", "abcd", None], | ||
type=pa.string(), | ||
) | ||
return pa_data_col, plc.interop.from_arrow(pa_data_col) | ||
|
||
|
||
@pytest.fixture | ||
def trans_table(): | ||
return str.maketrans("abd", "A Q") | ||
|
||
|
||
def test_translate(data_col, trans_table): | ||
pa_array, plc_col = data_col | ||
result = plc.strings.translate.translate(plc_col, trans_table) | ||
expected = pa.array( | ||
[ | ||
val.translate(trans_table) if isinstance(val, str) else None | ||
for val in pa_array.to_pylist() | ||
] | ||
) | ||
assert_column_eq(expected, result) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"keep", | ||
[ | ||
plc.strings.translate.FilterType.KEEP, | ||
plc.strings.translate.FilterType.REMOVE, | ||
], | ||
) | ||
def test_filter_characters(data_col, trans_table, keep): | ||
pa_array, plc_col = data_col | ||
result = plc.strings.translate.filter_characters( | ||
plc_col, trans_table, keep, plc.interop.from_arrow(pa.scalar("*")) | ||
) | ||
exp_data = [] | ||
flat_trans = set(trans_table.keys()).union(trans_table.values()) | ||
for val in pa_array.to_pylist(): | ||
if not isinstance(val, str): | ||
exp_data.append(val) | ||
else: | ||
new_val = "" | ||
for ch in val: | ||
if ( | ||
ch in flat_trans | ||
and keep == plc.strings.translate.FilterType.KEEP | ||
): | ||
new_val += ch | ||
elif ( | ||
ch not in flat_trans | ||
and keep == plc.strings.translate.FilterType.REMOVE | ||
): | ||
new_val += ch | ||
else: | ||
new_val += "*" | ||
exp_data.append(new_val) | ||
expected = pa.array(exp_data) | ||
assert_column_eq(expected, result) |