Skip to content

Commit

Permalink
Flesh out more container stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed May 9, 2024
1 parent 8ec9471 commit 3d7e9d8
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 12 deletions.
6 changes: 5 additions & 1 deletion python/cudf_polars/cudf_polars/containers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@

from __future__ import annotations

__all__: list[str] = []
__all__: list[str] = ["DataFrame", "Column", "Scalar"]

from cudf_polars.containers.column import Column
from cudf_polars.containers.dataframe import DataFrame
from cudf_polars.containers.scalar import Scalar
28 changes: 20 additions & 8 deletions python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import cudf._lib.pylibcudf as plc

if TYPE_CHECKING:
from typing_extensions import Self

__all__: list[str] = ["Column"]


Expand All @@ -25,14 +30,22 @@ def __init__(self, column: plc.Column, name: str):
self.name = name
self.is_sorted = plc.types.Sorted.NO

def with_metadata(self, *, like: Column) -> Self:
    """
    Copy sortedness metadata from another column onto this one.

    Parameters
    ----------
    like
        Column whose ``is_sorted``, ``order`` and ``null_order``
        attributes are read.

    Returns
    -------
    Self, modified in place.
    """
    # Attributes are copied in a fixed order: sortedness flag first,
    # then the ordering details.
    for attr in ("is_sorted", "order", "null_order"):
        setattr(self, attr, getattr(like, attr))
    return self

def set_sorted(
self,
*,
is_sorted: plc.types.Sorted,
order: plc.types.Order,
null_order: plc.types.NullOrder,
) -> Column:
) -> Self:
"""
Return a new column sharing data with sortedness set.
Modify sortedness metadata in place.
Parameters
----------
Expand All @@ -45,10 +58,9 @@ def set_sorted(
Returns
-------
New column sharing data.
Self with metadata set.
"""
obj = Column(self.obj, self.name)
obj.is_sorted = is_sorted
obj.order = order
obj.null_order = null_order
return obj
self.is_sorted = is_sorted
self.order = order
self.null_order = null_order
return self
73 changes: 70 additions & 3 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
from __future__ import annotations

import itertools
from functools import cached_property
from typing import TYPE_CHECKING

import cudf._lib.pylibcudf as plc

from cudf_polars.containers.column import Column
from cudf_polars.containers.scalar import Scalar

if TYPE_CHECKING:
from cudf_polars.containers.column import Column
from cudf_polars.containers.scalar import Scalar
from typing_extensions import Self

import cudf


__all__: list[str] = ["DataFrame"]

Expand All @@ -35,7 +41,7 @@ def __init__(self, columns: list[Column], scalars: list[Scalar]) -> None:
self.columns = columns
self.scalars = scalars
if len(scalars) == 0:
self.table = plc.Table(columns)
self.table = plc.Table([c.obj for c in columns])
else:
self.table = None

Expand All @@ -48,3 +54,64 @@ def __getitem__(self, name: str) -> Column | Scalar:
return self.scalars[i]
else:
return self.columns[i]

@cached_property
def num_rows(self):
    """
    Number of rows in the frame's table.

    Raises
    ------
    ValueError
        If the frame has no table (``self.table`` is ``None``).
    """
    table = self.table
    if table is not None:
        return table.num_rows()
    raise ValueError("Number of rows of frame with scalars makes no sense")

@classmethod
def from_cudf(cls, df: cudf.DataFrame) -> Self:
    """
    Build a frame from a cudf dataframe.

    Each ``(name, column)`` pair in ``df._data`` is converted with
    ``to_pylibcudf(mode="read")`` and wrapped in a named ``Column``.
    The resulting frame has no scalars.
    """
    wrapped = [
        Column(series.to_pylibcudf(mode="read"), label)
        for label, series in df._data.items()
    ]
    return cls(wrapped, [])

def with_columns(self, *columns: Column | Scalar) -> Self:
    """
    Return a new dataframe with extra columns appended.

    Positional arguments are sorted by type: ``Column`` instances are
    appended after the existing columns and ``Scalar`` instances after
    the existing scalars. Anything else is ignored.

    Data is shared.
    """
    all_columns = list(self.columns)
    all_scalars = list(self.scalars)
    for item in columns:
        # Two independent checks mirror the separate Column / Scalar filters.
        if isinstance(item, Column):
            all_columns.append(item)
        if isinstance(item, Scalar):
            all_scalars.append(item)
    return type(self)(all_columns, all_scalars)

def discard_columns(self, names: set[str]) -> Self:
    """
    Drop columns by name.

    Parameters
    ----------
    names
        Names of the columns to remove. Names that match no column
        are ignored.

    Returns
    -------
    New dataframe without the named columns; scalars are carried
    over unchanged.
    """
    # Compare on the column's name: ``names`` holds strings, not Column objects.
    kept = [c for c in self.columns if c.name not in names]
    return type(self)(kept, self.scalars)

def replace_columns(self, *columns: Column) -> Self:
    """
    Return a new dataframe with columns replaced by name, maintaining order.

    Raises
    ------
    ValueError
        If a replacement is named like one of ``self.scalar_names``,
        or is not among ``self.names``.
    """
    replacements = {col.name: col for col in columns}
    requested = set(replacements)
    if not requested.isdisjoint(self.scalar_names):
        raise ValueError("Cannot replace scalars")
    if requested.difference(self.names):
        raise ValueError("Cannot replace with non-existing names")
    # Walk the existing columns so the original ordering is preserved.
    swapped = [replacements.get(old.name, old) for old in self.columns]
    return type(self)(swapped, self.scalars)

def rename_columns(self, mapping: dict[str, str]) -> Self:
    """
    Rename some columns.

    Parameters
    ----------
    mapping
        Mapping from old column name to new column name. Columns whose
        name is not a key keep their current name.

    Returns
    -------
    New dataframe whose columns wrap the same underlying objects, with
    sortedness metadata copied from the originals. Scalars are carried
    over unchanged.
    """
    new_columns = [
        # Wrap the underlying pylibcudf column (``c.obj``), not the Column
        # wrapper itself, matching every other Column construction site.
        Column(c.obj, mapping.get(c.name, c.name)).with_metadata(like=c)
        for c in self.columns
    ]
    return type(self)(new_columns, self.scalars)

def select_columns(self, names: set[str]) -> list[Column]:
    """
    Select columns by name.

    Returns the columns whose name is in ``names``, in the frame's
    existing column order. Names matching no column are ignored.
    """
    return list(filter(lambda col: col.name in names, self.columns))

def filter(self, mask: Column) -> Self:
    """
    Return a filtered table given a mask.

    The mask's underlying object is handed to
    ``plc.stream_compaction.apply_boolean_mask`` together with
    ``self.table``. Each resulting column is re-wrapped under the name
    of the column at the same position, with that column's metadata
    copied over. The returned frame has no scalars.
    """
    # NOTE(review): ``self.table`` is passed through unchecked; confirm
    # callers never filter a frame whose table is None.
    compacted = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
    pairs = zip(self.columns, compacted.columns())
    filtered = [
        Column(fresh, original.name).with_metadata(like=original)
        for original, fresh in pairs
    ]
    return type(self)(filtered, [])
8 changes: 8 additions & 0 deletions python/cudf_polars/cudf_polars/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Utilities."""

from __future__ import annotations

__all__: list[str] = []
89 changes: 89 additions & 0 deletions python/cudf_polars/cudf_polars/utils/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Datatype utilities."""

from __future__ import annotations

from functools import cache

from typing_extensions import assert_never

import polars as pl

import cudf._lib.pylibcudf as plc


@cache
def from_polars(dtype: pl.DataType) -> plc.DataType:
    """
    Convert a polars datatype to a pylibcudf one.

    Parameters
    ----------
    dtype
        Polars dtype to convert

    Returns
    -------
    Matching pylibcudf DataType object.

    Raises
    ------
    NotImplementedError
        For unsupported conversions.
    """
    # Boolean must be tested against pl.Boolean; testing pl.Int8 here
    # would map Int8 to BOOL8 and make the INT8 branch unreachable.
    if isinstance(dtype, pl.Boolean):
        return plc.DataType(plc.TypeId.BOOL8)
    elif isinstance(dtype, pl.Int8):
        return plc.DataType(plc.TypeId.INT8)
    elif isinstance(dtype, pl.Int16):
        return plc.DataType(plc.TypeId.INT16)
    elif isinstance(dtype, pl.Int32):
        return plc.DataType(plc.TypeId.INT32)
    elif isinstance(dtype, pl.Int64):
        return plc.DataType(plc.TypeId.INT64)
    elif isinstance(dtype, pl.UInt8):
        return plc.DataType(plc.TypeId.UINT8)
    elif isinstance(dtype, pl.UInt16):
        return plc.DataType(plc.TypeId.UINT16)
    elif isinstance(dtype, pl.UInt32):
        return plc.DataType(plc.TypeId.UINT32)
    elif isinstance(dtype, pl.UInt64):
        return plc.DataType(plc.TypeId.UINT64)
    elif isinstance(dtype, pl.Float32):
        return plc.DataType(plc.TypeId.FLOAT32)
    elif isinstance(dtype, pl.Float64):
        return plc.DataType(plc.TypeId.FLOAT64)
    elif isinstance(dtype, pl.Date):
        return plc.DataType(plc.TypeId.TIMESTAMP_DAYS)
    elif isinstance(dtype, pl.Time):
        raise NotImplementedError("Time of day dtype not implemented")
    elif isinstance(dtype, pl.Datetime):
        if dtype.time_zone is not None:
            raise NotImplementedError("Time zone support")
        if dtype.time_unit == "ms":
            return plc.DataType(plc.TypeId.TIMESTAMP_MILLISECONDS)
        elif dtype.time_unit == "us":
            return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
        elif dtype.time_unit == "ns":
            return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
        else:
            # Exhaustiveness guard for the type checker.
            assert dtype.time_unit is not None
            assert_never(dtype.time_unit)
    elif isinstance(dtype, pl.Duration):
        if dtype.time_unit == "ms":
            return plc.DataType(plc.TypeId.DURATION_MILLISECONDS)
        elif dtype.time_unit == "us":
            return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
        elif dtype.time_unit == "ns":
            return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
        else:
            # Exhaustiveness guard for the type checker.
            assert dtype.time_unit is not None
            assert_never(dtype.time_unit)
    elif isinstance(dtype, pl.String):
        return plc.DataType(plc.TypeId.STRING)
    elif isinstance(dtype, pl.Null):
        # TODO: Hopefully EMPTY is the right representation for a polars
        # Null column; confirm.
        return plc.DataType(plc.TypeId.EMPTY)
    else:
        raise NotImplementedError(f"{dtype=} conversion not supported")

0 comments on commit 3d7e9d8

Please sign in to comment.