From 0ed9af6f8b0ce28e8e45c0100c6435d62cf9b94d Mon Sep 17 00:00:00 2001 From: Thomas Li Date: Tue, 25 Jun 2024 19:27:14 +0000 Subject: [PATCH] Fix error in testing utils Co-authored-by: Lawrence Mitchell --- .../cudf/cudf/pylibcudf_tests/common/utils.py | 14 +++- .../cudf/cudf/pylibcudf_tests/test_copying.py | 68 +++++++------------ 2 files changed, 37 insertions(+), 45 deletions(-) diff --git a/python/cudf/cudf/pylibcudf_tests/common/utils.py b/python/cudf/cudf/pylibcudf_tests/common/utils.py index d172d48992f..4bc2911c4fe 100644 --- a/python/cudf/cudf/pylibcudf_tests/common/utils.py +++ b/python/cudf/cudf/pylibcudf_tests/common/utils.py @@ -15,7 +15,15 @@ def metadata_from_arrow_type( name: str = "", ) -> plc.interop.ColumnMetadata | None: metadata = plc.interop.ColumnMetadata(name) # None - if pa.types.is_list(pa_type) or pa.types.is_struct(pa_type): + if pa.types.is_list(pa_type): + child_meta = [plc.interop.ColumnMetadata("offsets")] + for i in range(pa_type.num_fields): + field_meta = metadata_from_arrow_type( + pa_type.field(i).type, pa_type.field(i).name + ) + child_meta.append(field_meta) + metadata = plc.interop.ColumnMetadata(name, child_meta) + elif pa.types.is_struct(pa_type): child_meta = [] for i in range(pa_type.num_fields): field_meta = metadata_from_arrow_type( @@ -57,8 +65,8 @@ def assert_column_eq( if isinstance(rhs, pa.ChunkedArray): rhs = rhs.combine_chunks() - # print(lhs) - # print(rhs) + print(lhs) + print(rhs) assert lhs.equals(rhs) diff --git a/python/cudf/cudf/pylibcudf_tests/test_copying.py b/python/cudf/cudf/pylibcudf_tests/test_copying.py index e527ad9f967..0a6df198d46 100644 --- a/python/cudf/cudf/pylibcudf_tests/test_copying.py +++ b/python/cudf/cudf/pylibcudf_tests/test_copying.py @@ -21,30 +21,6 @@ from cudf._lib import pylibcudf as plc -@pytest.fixture -def nested_list_skip(request): - """ - Fixture that xfails a test if we encounter a nested list. - (as of right now, we are encountering some segfaults/memoryerrors - in interop) - """ - if "target_table" in request.fixturenames: - pa_table, _ = request.getfixturevalue("target_table") - if any(is_nested_list(col.type) for col in pa_table.columns): - pytest.skip(reason="pylibcudf interop fails with nested list") - elif "target_column" or "input_column" in request.fixturenames: - if "target_column" in request.fixturenames: - pa_col, _ = request.getfixturevalue("target_column") - else: - pa_col, _ = request.getfixturevalue("input_column") - if is_nested_list(pa_col.type): - pytest.skip(reason="pylibcudf interop fails with nested list") - - -xfail_nested_struct = pytest.mark.usefixtures("nested_struct_xfail") -skip_nested_list = pytest.mark.usefixtures("nested_list_skip") - - # TODO: consider moving this to conftest and "pairing" # it with pa_type, so that they don't get out of sync # TODO: Test nullable data @@ -194,7 +170,6 @@ def mask(target_column): return pa_mask, plc.interop.from_arrow(pa_mask) -@skip_nested_list def test_gather(target_table, index_column): pa_target_table, plc_target_table = target_table pa_index_column, plc_index_column = index_column @@ -250,7 +225,6 @@ def _pyarrow_boolean_mask_scatter_table(source, mask, target_table): ) -@skip_nested_list def test_scatter_table( source_table, index_column, @@ -280,9 +254,17 @@ def test_scatter_table( ) if pa.types.is_list(dtype := pa_target_table[0].type): - expected = pa.table( - [pa.array([[4], [1], [2, 3], [3], [9], [10]])] * 3, [""] * 3 - ) + if is_nested_list(dtype): + expected = pa.table( + [pa.array([[[4]], [[1]], [[2, 3]], [[3]], [[9]], [[10]]])] + * 3, + [""] * 3, + ) + else: + expected = pa.table( + [pa.array([[4], [1], [2, 3], [3], [9], [10]])] * 3, + [""] * 3, + ) elif pa.types.is_struct(dtype): if is_nested_struct(dtype): expected = pa.table( @@ -392,7 +374,6 @@ def test_scatter_table_type_mismatch(source_table, index_column, target_table): ) -@skip_nested_list def test_scatter_scalars( source_scalar, index_column, @@ -670,7 +651,6 @@ def test_shift_type_mismatch(target_column): plc.copying.shift(plc_target_column, 2, fill_value) -@skip_nested_list def test_slice_column(target_column): pa_target_column, plc_target_column = target_column bounds = list(range(6)) @@ -699,7 +679,6 @@ def test_slice_column_out_of_bounds(target_column): plc.copying.slice(plc_target_column, list(range(2, 8))) -@skip_nested_list def test_slice_table(target_table): pa_target_table, plc_target_table = target_table bounds = list(range(6)) @@ -710,7 +689,6 @@ def test_slice_table(target_table): assert_table_eq(pa_target_table[lb:ub], slice_) -@skip_nested_list def test_split_column(target_column): upper_bounds = [1, 3, 5] lower_bounds = [0] + upper_bounds[:-1] @@ -732,7 +710,6 @@ def test_split_column_out_of_bounds(target_column): plc.copying.split(plc_target_column, list(range(5, 8))) -@skip_nested_list def test_split_table(target_table): pa_target_table, plc_target_table = target_table @@ -743,7 +720,6 @@ def test_split_table(target_table): assert_table_eq(pa_target_table[lb:ub], split) -@skip_nested_list def test_copy_if_else_column_column(target_column, mask, source_scalar): pa_target_column, plc_target_column = target_column pa_source_scalar, _ = source_scalar @@ -818,7 +794,6 @@ def test_copy_if_else_wrong_size_mask(target_column): ) -@skip_nested_list @pytest.mark.parametrize("array_left", [True, False]) def test_copy_if_else_column_scalar( target_column, @@ -852,7 +827,6 @@ def test_copy_if_else_column_scalar( assert_column_eq(expected, result) -@skip_nested_list def test_boolean_mask_scatter_from_table( source_table, target_table, @@ -879,9 +853,21 @@ def test_boolean_mask_scatter_from_table( ) if pa.types.is_list(dtype := pa_target_table[0].type): - expected = pa.table( - [pa.array([[1], [5, 6], [2, 3], [8], [3], [10]])] * 3, [""] * 3 - ) + if is_nested_list(dtype): + expected = pa.table( + [ + pa.array( + [[[1]], [[5, 6]], [[2, 3]], [[8]], [[3]], [[10]]] + ) + ] + * 3, + [""] * 3, + ) + else: + expected = pa.table( + [pa.array([[1], [5, 6], [2, 3], [8], [3], [10]])] * 3, + [""] * 3, + ) elif pa.types.is_struct(dtype): if is_nested_struct(dtype): expected = pa.table( @@ -989,7 +975,6 @@ def test_boolean_mask_scatter_from_wrong_mask_type(source_table, target_table): ) -@skip_nested_list def test_boolean_mask_scatter_from_scalars( source_scalar, target_table, @@ -1013,7 +998,6 @@ def test_boolean_mask_scatter_from_scalars( assert_table_eq(expected, result) -@skip_nested_list def test_get_element(input_column): index = 1 pa_input_column, plc_input_column = input_column