Skip to content

Commit

Permalink
update following feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
lithomas1 committed Jun 13, 2024
1 parent b1951d0 commit 1228569
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 54 deletions.
13 changes: 5 additions & 8 deletions python/cudf/cudf/_lib/json.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,10 @@ cdef data_type _get_cudf_data_type_from_dtype(object dtype) except *:


def _dtype_to_names_list(col):
    """
    Build the nested (name, children) listing for a column.

    Parameters
    ----------
    col : column-like
        Object exposing ``dtype`` and ``children``.

    Returns
    -------
    list
        For a struct column, one ``(field_name, child_listing)`` tuple per
        child, pairing ``col.dtype.fields`` with ``col.children`` in order.
        For a list column, one ``("", child_listing)`` tuple per child
        (list children carry no name). For any other dtype, an empty list.
    """
    if isinstance(col.dtype, cudf.StructDtype):
        # Iterating ``fields`` is presumed to yield the field names in the
        # same order as ``col.children`` -- confirm against cudf.StructDtype.
        return [(name, _dtype_to_names_list(child))
                for name, child in zip(col.dtype.fields, col.children)]
    elif isinstance(col.dtype, cudf.ListDtype):
        return [("", _dtype_to_names_list(child))
                for child in col.children]
    # Leaf column: nothing nested to describe.
    return []
8 changes: 6 additions & 2 deletions python/cudf/cudf/_lib/pylibcudf/io/types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ cdef class TableWithMetadata:
Parameters
----------
tbl: Table
tbl : Table
The input table.
column_names: list
column_names : list
A list of tuples each containing the name of each column
and the names of its child columns (in the same format).
e.g.
Expand Down Expand Up @@ -193,6 +193,10 @@ cdef class SinkInfo:
cdef unique_ptr[data_sink] sink

cdef vector[string] paths

if not sinks:
raise ValueError("Need to pass at least one sink")

if isinstance(sinks[0], io.StringIO):
data_sinks.reserve(len(sinks))
for s in sinks:
Expand Down
1 change: 1 addition & 0 deletions python/cudf/cudf/_lib/pylibcudf/libcudf/io/types.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ cdef extern from "cudf/io/types.hpp" \
vector[column_name_info] children

cdef cppclass table_metadata:
table_metadata() except +

vector[string] column_names
map[string, string] user_data
Expand Down
31 changes: 24 additions & 7 deletions python/cudf/cudf/pylibcudf_tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,32 @@ def is_string(plc_dtype: plc.DataType):
def is_fixed_width(plc_dtype: plc.DataType):
    """
    Report whether ``plc_dtype`` passes any of the integer, unsigned-integer,
    floating or boolean dtype checks.

    The checks are evaluated in that order and short-circuit on the first
    truthy result.
    """
    integral = is_integer(plc_dtype) or is_unsigned_integer(plc_dtype)
    return integral or is_floating(plc_dtype) or is_boolean(plc_dtype)


def is_nested_struct(pa_type: pa.DataType):
    """
    Return True when ``pa_type`` is a struct type that has at least one
    field whose own type is a struct; otherwise return False.
    """
    if not isinstance(pa_type, pa.StructType):
        return False
    return any(
        isinstance(pa_type[idx].type, pa.StructType)
        for idx in range(pa_type.num_fields)
    )


def is_nested_list(pa_type: pa.DataType):
    """
    Return True when ``pa_type`` is a list type whose value type is itself
    a list type; otherwise return False.
    """
    return isinstance(pa_type, pa.ListType) and isinstance(
        pa_type.value_type, pa.ListType
    )


def sink_to_str(sink):
"""
Takes a sink (e.g. StringIO/BytesIO, filepath, etc.)
and reads in the contents into a string (str not bytes)
for comparison
"""
if isinstance(sink, (str, os.PathLike)):
with open(sink, "r") as f:
str_result = f.read()
Expand All @@ -151,8 +171,7 @@ def sink_to_str(sink):
return str_result


# TODO: enable uint64, some failing tests
NUMERIC_PA_TYPES = [pa.int64(), pa.float64()] # pa.uint64()]
NUMERIC_PA_TYPES = [pa.int64(), pa.float64(), pa.uint64()]
STRING_PA_TYPES = [pa.string()]
BOOL_PA_TYPES = [pa.bool_()]
LIST_PA_TYPES = [
Expand Down Expand Up @@ -187,10 +206,8 @@ def sink_to_str(sink):
+ BOOL_PA_TYPES
# exclude nested list/struct cases
# since not all tests work with them yet
+ LIST_PA_TYPES[:1]
+ DEFAULT_PA_STRUCT_TESTING_TYPES[:1]
+ LIST_PA_TYPES # [:1]
+ DEFAULT_PA_STRUCT_TESTING_TYPES # [:1]
)

ALL_PA_TYPES = (
DEFAULT_PA_TYPES + LIST_PA_TYPES[1:] + DEFAULT_PA_STRUCT_TESTING_TYPES[1:]
)
ALL_PA_TYPES = DEFAULT_PA_TYPES # + LIST_PA_TYPES[1:] + DEFAULT_PA_STRUCT_TESTING_TYPES[1:]
2 changes: 2 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def table_data(request):
# plc.io.TableWithMetadata
colnames = []

np.random.seed(42)

for typ in ALL_PA_TYPES:
rand_vals = np.random.randint(0, nrows, nrows)
child_colnames = []
Expand Down
Loading

0 comments on commit 1228569

Please sign in to comment.