Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure cudf objects can astype to any type when empty #16106

Merged
merged 2 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,15 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
raise NotImplementedError()

def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
if len(self) == 0:
dtype = cudf.dtype(dtype)
if self.dtype == dtype:
if copy:
return self.copy()
else:
return self
else:
return column_empty(0, dtype=dtype, masked=self.nullable)
if copy:
col = self.copy()
else:
Expand Down
36 changes: 19 additions & 17 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def __contains__(self, item: ScalarLike) -> bool:
return False
elif ts.tzinfo is not None:
ts = ts.tz_convert(None)
return ts.to_numpy().astype("int64") in self.as_numerical_column(
"int64"
return ts.to_numpy().astype("int64") in cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
)

@functools.cached_property
Expand Down Expand Up @@ -503,9 +503,9 @@ def mean(
self, skipna=None, min_count: int = 0, dtype=np.float64
) -> ScalarLike:
return pd.Timestamp(
self.as_numerical_column("int64").mean(
skipna=skipna, min_count=min_count, dtype=dtype
),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).mean(skipna=skipna, min_count=min_count, dtype=dtype),
unit=self.time_unit,
).as_unit(self.time_unit)

Expand All @@ -517,15 +517,17 @@ def std(
ddof: int = 1,
) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").std(
cast("cudf.core.column.NumericalColumn", self.astype("int64")).std(
skipna=skipna, min_count=min_count, dtype=dtype, ddof=ddof
)
* _unit_to_nanoseconds_conversion[self.time_unit],
).as_unit(self.time_unit)

def median(self, skipna: bool | None = None) -> pd.Timestamp:
return pd.Timestamp(
self.as_numerical_column("int64").median(skipna=skipna),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).median(skipna=skipna),
unit=self.time_unit,
).as_unit(self.time_unit)

Expand All @@ -534,18 +536,18 @@ def cov(self, other: DatetimeColumn) -> float:
raise TypeError(
f"cannot perform cov with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").cov(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).cov(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def corr(self, other: DatetimeColumn) -> float:
if not isinstance(other, DatetimeColumn):
raise TypeError(
f"cannot perform corr with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").corr(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).corr(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def quantile(
self,
Expand All @@ -554,7 +556,7 @@ def quantile(
exact: bool,
return_scalar: bool,
) -> ColumnBase:
result = self.as_numerical_column("int64").quantile(
result = self.astype("int64").quantile(
q=q,
interpolation=interpolation,
exact=exact,
Expand Down Expand Up @@ -645,12 +647,12 @@ def indices_of(
) -> cudf.core.column.NumericalColumn:
value = column.as_column(
pd.to_datetime(value), dtype=self.dtype
).as_numerical_column("int64")
return self.as_numerical_column("int64").indices_of(value)
).astype("int64")
return self.astype("int64").indices_of(value)

@property
def is_unique(self) -> bool:
return self.as_numerical_column("int64").is_unique
return self.astype("int64").is_unique

def isin(self, values: Sequence) -> ColumnBase:
return cudf.core.tools.datetimes._isin_datetimelike(self, values)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def normalize_binop_value(self, other):
"Decimal columns only support binary operations with "
"integer numerical columns."
)
other = other.as_decimal_column(
other = other.astype(
self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0)
)
elif not isinstance(other, DecimalBaseColumn):
Expand Down
26 changes: 11 additions & 15 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import cudf
from cudf.core.column import StructColumn
from cudf.core.dtypes import CategoricalDtype, IntervalDtype
from cudf.core.dtypes import IntervalDtype


class IntervalColumn(StructColumn):
Expand Down Expand Up @@ -87,20 +87,16 @@ def copy(self, deep=True):

def as_interval_column(self, dtype):
if isinstance(dtype, IntervalDtype):
if isinstance(self.dtype, CategoricalDtype):
new_struct = self._get_decategorized_column()
return IntervalColumn.from_struct_column(new_struct)
else:
return IntervalColumn(
size=self.size,
dtype=dtype,
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=tuple(
child.astype(dtype.subtype) for child in self.children
),
)
return IntervalColumn(
size=self.size,
dtype=dtype,
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=tuple(
child.astype(dtype.subtype) for child in self.children
),
)
else:
raise ValueError("dtype must be IntervalDtype")

Expand Down
34 changes: 19 additions & 15 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __contains__(self, item: DatetimeLikeScalar) -> bool:
# np.timedelta64 raises ValueError, hence `item`
# cannot exist in `self`.
return False
return item.view("int64") in self.as_numerical_column("int64")
return item.view("int64") in cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
)

@property
def values(self):
Expand All @@ -132,9 +134,7 @@ def to_arrow(self) -> pa.Array:
self.mask_array_view(mode="read").copy_to_host()
)
data = pa.py_buffer(
self.as_numerical_column("int64")
.data_array_view(mode="read")
.copy_to_host()
self.astype("int64").data_array_view(mode="read").copy_to_host()
)
pa_dtype = np_to_pa_dtype(self.dtype)
return pa.Array.from_buffers(
Expand Down Expand Up @@ -295,13 +295,17 @@ def as_timedelta_column(

def mean(self, skipna=None, dtype: Dtype = np.float64) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").mean(skipna=skipna, dtype=dtype),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).mean(skipna=skipna, dtype=dtype),
unit=self.time_unit,
).as_unit(self.time_unit)

def median(self, skipna: bool | None = None) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").median(skipna=skipna),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).median(skipna=skipna),
unit=self.time_unit,
).as_unit(self.time_unit)

Expand All @@ -315,7 +319,7 @@ def quantile(
exact: bool,
return_scalar: bool,
) -> ColumnBase:
result = self.as_numerical_column("int64").quantile(
result = self.astype("int64").quantile(
q=q,
interpolation=interpolation,
exact=exact,
Expand All @@ -337,7 +341,7 @@ def sum(
# Since sum isn't overridden in Numerical[Base]Column, mypy only
# sees the signature from Reducible (which doesn't have the extra
# parameters from ColumnBase._reduce) so we have to ignore this.
self.as_numerical_column("int64").sum( # type: ignore
self.astype("int64").sum( # type: ignore
skipna=skipna, min_count=min_count, dtype=dtype
),
unit=self.time_unit,
Expand All @@ -351,7 +355,7 @@ def std(
ddof: int = 1,
) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").std(
cast("cudf.core.column.NumericalColumn", self.astype("int64")).std(
skipna=skipna, min_count=min_count, ddof=ddof, dtype=dtype
),
unit=self.time_unit,
Expand All @@ -362,18 +366,18 @@ def cov(self, other: TimeDeltaColumn) -> float:
raise TypeError(
f"cannot perform cov with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").cov(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).cov(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def corr(self, other: TimeDeltaColumn) -> float:
if not isinstance(other, TimeDeltaColumn):
raise TypeError(
f"cannot perform corr with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").corr(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).corr(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def components(self) -> dict[str, ColumnBase]:
"""
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,7 +2404,7 @@ def scatter_by_map(
if isinstance(map_index, cudf.core.column.StringColumn):
cat_index = cast(
cudf.core.column.CategoricalColumn,
map_index.as_categorical_column("category"),
map_index.astype("category"),
)
map_index = cat_index.codes
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ def to_pandas(self) -> pd.IntervalDtype:
def __eq__(self, other):
if isinstance(other, str):
# This means equality isn't transitive but mimics pandas
return other == self.name
return other in (self.name, str(self))
return (
type(self) == type(other)
and self.subtype == other.subtype
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def from_arrow(cls, data: pa.Table) -> Self:
# of column is 0 (i.e., empty) then we will have an
# int8 column in result._data[name] returned by libcudf,
# which needs to be type-casted to 'category' dtype.
result[name] = result[name].as_categorical_column("category")
result[name] = result[name].astype("category")
elif (
pandas_dtypes.get(name) == "empty"
and np_dtypes.get(name) == "object"
Expand All @@ -936,7 +936,7 @@ def from_arrow(cls, data: pa.Table) -> Self:
# is specified as 'empty' and np_dtypes as 'object',
# hence handling this special case to type-cast the empty
# float column to str column.
result[name] = result[name].as_string_column(cudf.dtype("str"))
result[name] = result[name].astype(cudf.dtype("str"))
elif name in data.column_names and isinstance(
data[name].type,
(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/indexing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def parse_row_iloc_indexer(key: Any, n: int) -> IndexingSpec:
else:
key = cudf.core.column.as_column(key)
if isinstance(key, cudf.core.column.CategoricalColumn):
key = key.as_numerical_column(key.codes.dtype)
key = key.astype(key.codes.dtype)
if is_bool_dtype(key.dtype):
return MaskIndexer(BooleanMask(key, n))
elif len(key) == 0:
Expand Down
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3107,10 +3107,12 @@ def value_counts(
# Pandas returns an IntervalIndex as the index of res
# this condition makes sure we do too if bins is given
if bins is not None and len(res) == len(res.index.categories):
int_index = IntervalColumn.as_interval_column(
res.index._column, res.index.categories.dtype
interval_col = IntervalColumn.from_struct_column(
res.index._column._get_decategorized_column()
)
res.index = cudf.IntervalIndex._from_data(
{res.index.name: interval_col}
)
res.index = int_index
res.name = result_name
return res

Expand Down
14 changes: 7 additions & 7 deletions python/cudf/cudf/core/tools/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def to_numeric(arg, errors="raise", downcast=None):
dtype = col.dtype

if is_datetime_dtype(dtype) or is_timedelta_dtype(dtype):
col = col.as_numerical_column(cudf.dtype("int64"))
col = col.astype(cudf.dtype("int64"))
elif isinstance(dtype, CategoricalDtype):
cat_dtype = col.dtype.type
if _is_non_decimal_numeric_dtype(cat_dtype):
col = col.as_numerical_column(cat_dtype)
col = col.astype(cat_dtype)
else:
try:
col = _convert_str_col(
Expand All @@ -146,8 +146,8 @@ def to_numeric(arg, errors="raise", downcast=None):
raise ValueError("Unrecognized datatype")

# str->float conversion may require lower precision
if col.dtype == cudf.dtype("f"):
col = col.as_numerical_column("d")
if col.dtype == cudf.dtype("float32"):
col = col.astype("float64")

if downcast:
if downcast == "float":
Expand Down Expand Up @@ -205,7 +205,7 @@ def _convert_str_col(col, errors, _downcast=None):

is_integer = libstrings.is_integer(col)
if is_integer.all():
return col.as_numerical_column(dtype=cudf.dtype("i8"))
return col.astype(dtype=cudf.dtype("i8"))

col = _proc_inf_empty_strings(col)

Expand All @@ -218,9 +218,9 @@ def _convert_str_col(col, errors, _downcast=None):
"limited by float32 precision."
)
)
return col.as_numerical_column(dtype=cudf.dtype("f"))
return col.astype(dtype=cudf.dtype("float32"))
else:
return col.as_numerical_column(dtype=cudf.dtype("d"))
return col.astype(dtype=cudf.dtype("float64"))
else:
if errors == "coerce":
col = libcudf.string_casting.stod(col)
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,9 @@ def test_from_pandas_intervaldtype():
result = cudf.from_pandas(dtype)
expected = cudf.IntervalDtype("int64", closed="left")
assert_eq(result, expected)


def test_intervaldtype_eq_string_with_attributes():
    """Check that an IntervalDtype compares equal to its string spellings."""
    left_closed = cudf.IntervalDtype("int64", closed="left")
    # Bare spelling, without the subtype or closed side.
    assert left_closed == "interval"
    # Spelling that carries the subtype and the closed side.
    assert left_closed == "interval[int64, left]"
Loading
Loading