Skip to content

Commit

Permalink
Raise error for string types in nsmallest and nlargest (#13946)
Browse files Browse the repository at this point in the history
closes #13945 

This PR makes `nsmallest` and `nlargest` raise a `TypeError` for string columns, with an error message that exactly matches pandas.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #13946
  • Loading branch information
galipremsagar authored Aug 24, 2023
1 parent 6ed42d7 commit 83f9cbf
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,20 @@ def _n_largest_or_smallest(self, largest, n, columns, keep):
if isinstance(columns, str):
columns = [columns]

method = "nlargest" if largest else "nsmallest"
for col in columns:
if isinstance(self._data[col], cudf.core.column.StringColumn):
if isinstance(self, cudf.DataFrame):
error_msg = (
f"Column '{col}' has dtype {self._data[col].dtype}, "
f"cannot use method '{method}' with this dtype"
)
else:
error_msg = (
f"Cannot use method '{method}' with "
f"dtype {self._data[col].dtype}"
)
raise TypeError(error_msg)
if len(self) == 0:
return self

Expand Down
13 changes: 13 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10316,3 +10316,16 @@ def test_dataframe_reindex_with_index_names(index_data, name):
expected = pdf.reindex(index_data)

assert_eq(actual, expected)


@pytest.mark.parametrize("attr", ["nlargest", "nsmallest"])
def test_dataframe_nlargest_nsmallest_str_error(attr):
    """Compare cudf and pandas exceptions for ``attr`` on a string column.

    The frame has one integer column ("a") and one string column ("b");
    both are requested so the string column takes part in the selection.
    """
    gdf = cudf.DataFrame({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "d"]})
    pdf = gdf.to_pandas()

    # Each side gets its own (args, kwargs) pair so no object is shared
    # between the cudf call and the pandas call.
    cudf_call = ([], {"n": 1, "columns": ["a", "b"]})
    pandas_call = ([], {"n": 1, "columns": ["a", "b"]})

    cudf_method = getattr(gdf, attr)
    pandas_method = getattr(pdf, attr)

    assert_exceptions_equal(
        cudf_method,
        pandas_method,
        cudf_call,
        pandas_call,
    )
10 changes: 10 additions & 0 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,3 +2244,13 @@ def test_series_typecast_to_object():
assert new_series[0] == "1970-01-01 00:00:00.000000001"
new_series = actual.astype(np.dtype("object"))
assert new_series[0] == "1970-01-01 00:00:00.000000001"


@pytest.mark.parametrize("attr", ["nlargest", "nsmallest"])
def test_series_nlargest_nsmallest_str_error(attr):
    """Compare cudf and pandas exceptions for ``attr`` on a string Series."""
    gs = cudf.Series(["a", "b", "c", "d", "e"])
    ps = gs.to_pandas()

    # Separate (args, kwargs) pairs for the cudf and pandas invocations.
    cudf_call = ([], {"n": 1})
    pandas_call = ([], {"n": 1})

    cudf_method = getattr(gs, attr)
    pandas_method = getattr(ps, attr)

    assert_exceptions_equal(cudf_method, pandas_method, cudf_call, pandas_call)

0 comments on commit 83f9cbf

Please sign in to comment.