Skip to content

Commit

Permalink
WIP: More fleshing out
Browse files Browse the repository at this point in the history
Still need to port the expression eval
  • Loading branch information
wence- committed May 9, 2024
1 parent da37a36 commit a37cdbb
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 55 deletions.
15 changes: 9 additions & 6 deletions python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ def __init__(self, column: plc.Column, name: str):
self.name = name
self.is_sorted = plc.types.Sorted.NO

def with_metadata(self, *, like: Column) -> Self:
"""Copy metadata from a column onto self."""
self.is_sorted = like.is_sorted
self.order = like.order
self.null_order = like.null_order
return self
def rename(self, name: str) -> Column:
    """Return a new column wrapping the same data under ``name``."""
    # The data buffer is shared; only the name differs, and the
    # sortedness flags are copied over from self.
    renamed = type(self)(self.obj, name)
    return renamed.with_sorted(like=self)

def with_sorted(self, *, like: Column) -> Self:
    """Copy the sortedness flags of ``like`` onto self, returning self."""
    is_sorted, order, null_order = like.is_sorted, like.order, like.null_order
    return self.set_sorted(is_sorted=is_sorted, order=order, null_order=null_order)

def set_sorted(
self,
Expand Down
69 changes: 56 additions & 13 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def __getitem__(self, name: str) -> Column | Scalar:
else:
return self.columns[i]

@cached_property
def column_names(self) -> list[str]:
    """Return the names of the columns, in column order."""
    return [column.name for column in self.columns]

@cached_property
def num_columns(self) -> int:
    """Number of columns."""
    return len(self.columns)

@cached_property
def num_rows(self):
"""Number of rows."""
Expand All @@ -70,6 +80,22 @@ def from_cudf(cls, df: cudf.DataFrame) -> Self:
[],
)

@classmethod
def from_table(cls, table: plc.Table, names: list[str]) -> Self:
    """
    Create from a pylibcudf table.

    Parameters
    ----------
    table
        Table to wrap.
    names
        Names for the columns, one per column of ``table``.

    Returns
    -------
    New dataframe (with no scalars) wrapping the table's columns.

    Raises
    ------
    ValueError
        If the number of names does not match the table's column count.
    """
    # BUGFIX: plc.Table.num_columns is a method; the bare attribute is a
    # bound-method object, so comparing it ``!=`` to an int was always
    # true and every call raised. It must be called.
    if table.num_columns() != len(names):
        raise ValueError("Mismatching name and table length.")
    return cls([Column(c, name) for c, name in zip(table.columns(), names)], [])

def with_sorted(self, *, like: DataFrame) -> Self:
    """Copy sortedness, column by column, from ``like`` onto self."""
    if self.column_names != like.column_names:
        raise ValueError("Can only copy from identically named frame")
    updated = []
    for mine, template in zip(self.columns, like.columns):
        updated.append(mine.with_sorted(like=template))
    self.columns = updated
    return self

def with_columns(self, *columns: Column | Scalar) -> Self:
"""
Return a new dataframe with extra columns.
Expand All @@ -85,7 +111,7 @@ def discard_columns(self, names: set[str]) -> Self:
return type(self)([c for c in self.columns if c not in names], self.scalars)

def replace_columns(self, *columns: Column) -> Self:
"""Return a new dataframe with columns replaced by name, maintaining order."""
"""Return a new dataframe with columns replaced by name."""
new = {c.name: c for c in columns}
if set(new).intersection(self.scalar_names):
raise ValueError("Cannot replace scalars")
Expand All @@ -95,11 +121,9 @@ def replace_columns(self, *columns: Column) -> Self:

def rename_columns(self, mapping: dict[str, str]) -> Self:
"""Rename some columns."""
new_columns = [
Column(c, mapping.get(c.name, c.name)).with_metadata(like=c)
for c in self.columns
]
return type(self)(new_columns, self.scalars)
return type(self)(
[c.rename(mapping.get(c.name, c.name)) for c in self.columns], self.scalars
)

def select_columns(self, names: set[str]) -> list[Column]:
"""Select columns by name."""
Expand All @@ -108,10 +132,29 @@ def select_columns(self, names: set[str]) -> list[Column]:
def filter(self, mask: Column) -> Self:
"""Return a filtered table given a mask."""
table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
return type(self)(
[
Column(new, old.name).with_metadata(like=old)
for old, new in zip(self.columns, table.columns())
],
[],
)
return type(self).from_table(table, self.column_names).with_sorted(like=self)

def slice(self, zlice: tuple[int, int] | None) -> Self:
    """
    Slice a dataframe.

    Parameters
    ----------
    zlice
        optional, tuple of start and length, negative values of start
        treated as for python indexing. If not provided, returns self.

    Returns
    -------
    New dataframe (if zlice is not None) otherwise self (if it is)
    """
    if zlice is None:
        return self
    start, length = zlice
    if start < 0:
        start += self.num_rows
    # Polars slice takes an arbitrary positive integer and slice
    # to the end of the frame if it is larger.
    end = min(start + length, self.num_rows)
    # plc.copying.slice returns one table per offset pair; we pass a
    # single [start, end) pair so exactly one table comes back.
    (table,) = plc.copying.slice(self.table, [start, end])
    # Slicing preserves row order, so sortedness carries over from self.
    return type(self).from_table(table, self.column_names).with_sorted(like=self)
163 changes: 127 additions & 36 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, ClassVar

import pyarrow as pa
from typing_extensions import assert_never
Expand All @@ -29,7 +29,7 @@

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame
from cudf_polars.utils import dtypes
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -63,7 +63,7 @@
class IR:
schema: dict

def evaluate(self) -> DataFrame:
def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
    """
    Evaluate and return a dataframe.

    Parameters
    ----------
    cache
        Mapping from keys to already-materialised dataframes, shared
        across evaluation of the whole plan (used by Cache nodes).

    Raises
    ------
    NotImplementedError
        Always: concrete IR nodes must override this method.
    """
    raise NotImplementedError

Expand All @@ -88,7 +88,7 @@ def __post_init__(self):
if self.typ not in ("csv", "parquet"):
raise NotImplementedError(f"Unhandled scan type: {self.typ}")

def evaluate(self) -> DataFrame:
def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
options = self.file_options
n_rows = options.n_rows
Expand Down Expand Up @@ -132,14 +132,21 @@ class Cache(IR):
key: int
value: IR

def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
    """Evaluate and return a dataframe, memoising the result in the cache."""
    if self.key in cache:
        return cache[self.key]
    # setdefault so that if evaluating the child recursively populated
    # this key first, the already-stored frame wins.
    return cache.setdefault(self.key, self.value.evaluate(cache=cache))


@dataclass(slots=True)
class DataFrameScan(IR):
df: Any
projection: list[str]
predicate: Expr | None

def evaluate(self) -> DataFrame:
def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
pdf = pl.DataFrame._from_pydf(self.df)
if self.projection is not None:
Expand All @@ -152,14 +159,8 @@ def evaluate(self) -> DataFrame:
# TODO: Nested types
schema = schema.set(i, pa.field(field.name, pa.string()))
table = table.cast(schema)
df = DataFrame(
[
Column(col, name)
for name, col in zip(
self.schema.keys(), plc.interop.from_arrow(table).columns()
)
],
[],
df = DataFrame.from_table(
plc.interop.from_arrow(table), list(self.schema.keys())
)
if self.predicate is not None:
mask = self.predicate.evaluate(df)
Expand All @@ -174,9 +175,9 @@ class Select(IR):
cse: list[Expr]
expr: list[Expr]

def evaluate(self):
def evaluate(self, *, cache: dict[int, DataFrame]):
"""Evaluate and return a dataframe."""
df = self.df.evaluate()
df = self.df.evaluate(cache=cache)
for e in self.cse:
df = df.with_columns(e.evaluate(df))
return DataFrame([e.evaluate(df) for e in self.expr], [])
Expand Down Expand Up @@ -235,7 +236,7 @@ class Join(IR):
options: Any

def __post_init__(self):
"""Raise for unsupported options."""
"""Validate preconditions."""
if self.options[0] == "cross":
raise NotImplementedError("cross join not implemented")

Expand Down Expand Up @@ -279,10 +280,10 @@ def _joiners(
else:
assert_never(how)

def evaluate(self) -> DataFrame:
def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
left = self.left.evaluate()
right = self.right.evaluate()
left = self.left.evaluate(cache=cache)
right = self.right.evaluate(cache=cache)
left_on = DataFrame([e.evaluate(left) for e in self.left_on], [])
right_on = DataFrame([e.evaluate(right) for e in self.right_on], [])
how, join_nulls, zlice, suffix, coalesce = self.options
Expand All @@ -298,24 +299,18 @@ def evaluate(self) -> DataFrame:
lg = join_fn(left_on.table, right_on.table, null_equality)
left = left.replace_columns(*left_on.columns)
table = plc.copying.gather(left.table, lg, left_policy)
result = DataFrame(
[
Column(c, col.name)
for col, c in zip(left_on.columns, table.columns())
],
[],
)
result = DataFrame.from_table(table, left.column_names)
else:
lg, rg = join_fn(left_on, right_on, null_equality)
left = left.replace_columns(*left_on.columns)
right = right.replace_columns(*right_on.columns)
if coalesce and how != "outer":
right = right.discard_columns(set(right_on.names))
left = DataFrame(
plc.copying.gather(left.table, lg, left_policy).columns(), []
left = DataFrame.from_table(
plc.copying.gather(left.table, lg, left_policy), left.column_names
)
right = DataFrame(
plc.copying.gather(right.table, rg, right_policy).columns(), []
right = DataFrame.from_table(
plc.copying.gather(right.table, rg, right_policy), right.column_names
)
if coalesce and how == "outer":
left.replace_columns(
Expand All @@ -335,29 +330,125 @@ def evaluate(self) -> DataFrame:
{name: f"{name}{suffix}" for name in right.names if name in left.names}
)
result = left.with_columns(*right.columns)
if zlice is not None:
raise NotImplementedError("slicing")
else:
return result
return result.slice(zlice)


@dataclass(slots=True)
class HStack(IR):
    """Horizontally stack extra columns, evaluated from expressions, onto a frame."""

    df: IR
    columns: list[Expr]

    def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        frame = self.df.evaluate(cache=cache)
        new_columns = [expression.evaluate(frame) for expression in self.columns]
        return frame.with_columns(*new_columns)


@dataclass(slots=True)
class Distinct(IR):
    """Drop duplicate rows from a dataframe (polars ``unique``)."""

    # Input plan node.
    df: IR
    # Raw options tuple from polars; unpacked in __init__.
    options: Any
    # Which of the duplicate rows to keep.
    keep: plc.stream_compaction.DuplicateKeepOption
    # Column names considered for uniqueness (None means all columns).
    subset: set[str] | None
    # Optional (start, length) slice applied to the result.
    zlice: tuple[int, int] | None
    # Whether input row order must be maintained.
    stable: bool

    # Mapping from polars keep-strategy names to libcudf keep options.
    _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = {
        "first": plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
        "last": plc.stream_compaction.DuplicateKeepOption.KEEP_LAST,
        "none": plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
        "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
    }

    def __init__(self, schema: dict, df: IR, options: Any):
        """Unpack the polars options tuple into typed attributes."""
        self.schema = schema
        self.df = df
        (keep, subset, maintain_order, zlice) = options
        self.keep = Distinct._KEEP_MAP[keep]
        self.subset = set(subset) if subset is not None else None
        self.stable = maintain_order
        self.zlice = zlice

    def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        # Key-column indices for the uniqueness check.
        if self.subset is None:
            indices = list(range(df.num_columns))
        else:
            indices = [i for i, k in enumerate(df.names) if k in self.subset]
        # NOTE(review): this checks sortedness of *all* columns, not just
        # the key columns in `indices` — conservative, so the `unique`
        # fast path (which needs pre-grouped keys) is only taken when safe.
        keys_sorted = all(c.is_sorted for c in df.columns)
        if keys_sorted:
            # Sorted input: `unique` removes adjacent duplicates in one pass.
            table = plc.stream_compaction.unique(
                df.table,
                indices,
                self.keep,
                plc.types.NullEquality.EQUAL,
            )
        else:
            # Unsorted input: use (stable_)distinct, which hashes keys.
            distinct = (
                plc.stream_compaction.stable_distinct
                if self.stable
                else plc.stream_compaction.distinct
            )
            table = distinct(
                df.table,
                indices,
                self.keep,
                plc.types.NullEquality.EQUAL,
                plc.types.NanEquality.ALL_EQUAL,
            )
        result = DataFrame(
            [Column(c, old.name) for c, old in zip(table.columns(), df.columns)], []
        )
        # Row order (and hence sortedness) is only preserved when the input
        # was already sorted or a stable distinct was requested.
        if keys_sorted or self.stable:
            result = result.with_sorted(like=df)
        return result.slice(self.zlice)


@dataclass(slots=True)
class Sort(IR):
    """Sort a dataframe by the given key expressions."""

    # Input plan node.
    df: IR
    # Sort key expressions.
    by: list[Expr]
    # Raw options tuple from polars; unpacked in __init__.
    options: Any
    # Sort implementation (stable or unstable sort-by-key).
    do_sort: Callable[..., plc.Table]
    # Optional (start, length) slice. NOTE(review): never assigned or
    # applied here — presumably pending; confirm against polars Sort options.
    zlice: tuple[int, int] | None
    # Per-sort-key column ordering (one entry per key, in key order).
    order: list[plc.types.Order]
    # Per-sort-key null placement (one entry per key, in key order).
    null_order: list[plc.types.NullOrder]

    def __init__(self, schema: dict, df: IR, by: list[Expr], options: Any):
        """Unpack the polars options tuple and pick the sort implementation."""
        self.schema = schema
        self.df = df
        self.by = by
        stable, nulls_last, descending = options
        # sort_order yields one (order, null_order) entry per sort key,
        # indexed by key position, not by column position.
        self.order, self.null_order = sorting.sort_order(
            descending, nulls_last=nulls_last, num_keys=len(by)
        )
        self.do_sort = (
            plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
        )

    def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        df = self.df.evaluate(cache=cache)
        sort_keys = [k.evaluate(df) for k in self.by]
        # Pairs of (key position, result-column position) for every sort
        # key that *is* one of the result's columns.
        keys_in_result = [
            (key_idx, col_idx)
            for key_idx, k in enumerate(sort_keys)
            if (col_idx := df.names.get(k.name)) is not None
            and k is df.columns[col_idx]
        ]
        table = self.do_sort(
            df.table,
            plc.Table([k.obj for k in sort_keys]),
            self.order,
            self.null_order,
        )
        columns = [Column(c, old.name) for c, old in zip(table.columns(), df.columns)]
        # If a sort key is in the result table, set the sortedness property.
        # BUGFIX: order/null_order hold one entry per *key*, so they must be
        # indexed by key position; the previous code indexed them with the
        # column position, which is wrong (or raises IndexError) whenever a
        # sort key is not among the leading columns of the frame.
        for key_idx, col_idx in keys_in_result:
            columns[col_idx] = columns[col_idx].set_sorted(
                is_sorted=plc.types.Sorted.YES,
                order=self.order[key_idx],
                null_order=self.null_order[key_idx],
            )
        return DataFrame(columns, [])


@dataclass(slots=True)
Expand Down
Loading

0 comments on commit a37cdbb

Please sign in to comment.