WIP: More fleshing out evaluation
wence- committed May 9, 2024
1 parent 3d7e9d8 commit da37a36
Showing 3 changed files with 192 additions and 24 deletions.
14 changes: 10 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/expr.py
@@ -16,13 +16,16 @@
from __future__ import annotations

from dataclasses import dataclass
- from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from cudf_polars.containers import Column, DataFrame

__all__ = [
"Expr",
"NamedExpr",
"Literal",
"Column",
"Col",
"BooleanFunction",
"Sort",
"SortBy",
@@ -37,7 +40,10 @@

@dataclass(slots=True)
class Expr:
- pass
# TODO: return type is a lie for Literal
def evaluate(self, context: DataFrame) -> Column:
"""Evaluate this expression given a dataframe for context."""
raise NotImplementedError


@dataclass(slots=True)
@@ -53,7 +59,7 @@ class Literal(Expr):


@dataclass(slots=True)
- class Column(Expr):
class Col(Expr):
name: str


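The expr.py change above turns the empty Expr base class into an evaluate(context) contract that concrete nodes such as Col implement against a DataFrame. As a rough, self-contained sketch of that pattern (toy names only, standing in for the cudf_polars containers rather than their actual API):

from __future__ import annotations

from dataclasses import dataclass


@dataclass(slots=True)
class ToyExpr:
    """Stand-in for Expr: evaluate against a context, return a column."""

    def evaluate(self, context: dict[str, list[int]]) -> list[int]:
        raise NotImplementedError


@dataclass(slots=True)
class ToyCol(ToyExpr):
    """Stand-in for Col: look the named column up in the context."""

    name: str

    def evaluate(self, context: dict[str, list[int]]) -> list[int]:
        return context[self.name]


print(ToyCol("a").evaluate({"a": [1, 2, 3]}))  # [1, 2, 3]
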
198 changes: 179 additions & 19 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -16,13 +16,24 @@
from __future__ import annotations

from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any
from functools import cache
from typing import TYPE_CHECKING, Any, Callable

import pyarrow as pa
from typing_extensions import assert_never

import polars as pl

import cudf
import cudf._lib.pylibcudf as plc

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
from typing import Literal

from cudf_polars.dsl.expr import Expr


Expand Down Expand Up @@ -52,15 +63,16 @@
class IR:
schema: dict

def evaluate(self) -> DataFrame:
"""Evaluate and return a dataframe."""
raise NotImplementedError


@dataclass(slots=True)
class PythonScan(IR):
options: Any
predicate: Expr | None

- def evaluate(self):
- raise NotImplementedError


@dataclass(slots=True)
class Scan(IR):
@@ -70,34 +82,49 @@ class Scan(IR):
predicate: Expr | None

def __post_init__(self):
"""Validate preconditions."""
if self.file_options.n_rows is not None:
raise NotImplementedError("row limit in scan")
if self.typ not in ("csv", "parquet"):
raise NotImplementedError(f"Unhandled scan type: {self.typ}")
- def evaluate(self):

def evaluate(self) -> DataFrame:
"""Evaluate and return a dataframe."""
options = self.file_options
n_rows = options.n_rows
with_columns = options.with_columns
row_index = options.row_index
assert n_rows is None
if self.typ == "csv":
- df = cudf.concat(
- [cudf.read_csv(p, usecols=with_columns) for p in self.paths]
df = DataFrame.from_cudf(
cudf.concat(
[cudf.read_csv(p, usecols=with_columns) for p in self.paths]
)
)
elif self.typ == "parquet":
- df = cudf.read_parquet(self.paths, columns=with_columns)
df = DataFrame.from_cudf(
cudf.read_parquet(self.paths, columns=with_columns)
)
else:
assert_never(self.typ)
if row_index is not None:
name, offset = row_index
- dtype = self.schema[name]
- index = as_column(
- ..., dtype=dtype
dtype = dtypes.from_polars(self.schema[name])
step = plc.interop.from_arrow(pa.scalar(1), data_type=dtype)
init = plc.interop.from_arrow(pa.scalar(offset), data_type=dtype)
index = Column(
plc.filling.sequence(df.num_rows(), init, step), name
).set_sorted(
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.AFTER,
)
df = df.with_columns(index)
if self.predicate is None:
return df
else:
mask = self.predicate.evaluate(df)
return df.filter(mask)


@dataclass(slots=True)
Expand All @@ -112,13 +139,48 @@ class DataFrameScan(IR):
projection: list[str]
predicate: Expr | None

def evaluate(self) -> DataFrame:
"""Evaluate and return a dataframe."""
pdf = pl.DataFrame._from_pydf(self.df)
if self.projection is not None:
pdf = pdf.select(self.projection)
# TODO: goes away when libcudf supports large strings
table = pdf.to_arrow()
schema = table.schema
for i, field in enumerate(schema):
if field.type == pa.large_string():
# TODO: Nested types
schema = schema.set(i, pa.field(field.name, pa.string()))
table = table.cast(schema)
df = DataFrame(
[
Column(col, name)
for name, col in zip(
self.schema.keys(), plc.interop.from_arrow(table).columns()
)
],
[],
)
if self.predicate is not None:
mask = self.predicate.evaluate(df)
return df.filter(mask)
else:
return df

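Worth noting: the schema rewrite in DataFrameScan.evaluate above downcasts Arrow large_string fields to string before handing the table to libcudf (see the TODO). A self-contained illustration of that cast, assuming only pyarrow:

import pyarrow as pa

# A table whose only column uses the large_string type.
table = pa.table({"s": pa.array(["a", "b"], type=pa.large_string())})

# Rewrite the schema field by field, as in DataFrameScan.evaluate above,
# then cast the whole table to the adjusted schema.
schema = table.schema
for i, field in enumerate(schema):
    if field.type == pa.large_string():
        schema = schema.set(i, pa.field(field.name, pa.string()))
table = table.cast(schema)

print(table.schema)  # s: string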

@dataclass(slots=True)
class Select(IR):
df: IR
cse: list[Expr]
expr: list[Expr]

def evaluate(self):
"""Evaluate and return a dataframe."""
df = self.df.evaluate()
for e in self.cse:
df = df.with_columns(e.evaluate(df))
return DataFrame([e.evaluate(df) for e in self.expr], [])


@dataclass(slots=True)
class GroupBy(IR):
@@ -174,11 +236,109 @@ class Join(IR):

def __post_init__(self):
"""Raise for unsupported options."""
how, coalesce = self.options[0], self.options[-1]
if how == "cross":
- if self.options[0] == "cross":
raise NotImplementedError("cross join not implemented")
if how == "outer" and not coalesce:
raise NotImplementedError("non-coalescing outer join")

@staticmethod
@cache
def _joiners(
how: Literal["inner", "left", "outer", "leftsemi", "leftanti"],
) -> tuple[
Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
]:
if how == "inner":
return (
plc.join.inner_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
)
elif how == "left":
return (
plc.join.left_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
plc.copying.OutOfBoundsPolicy.NULLIFY,
)
elif how == "outer":
return (
plc.join.full_join,
plc.copying.OutOfBoundsPolicy.NULLIFY,
plc.copying.OutOfBoundsPolicy.NULLIFY,
)
elif how == "leftsemi":
return (
plc.join.left_semi_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
None,
)
elif how == "leftanti":
return (
plc.join.left_anti_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
None,
)
else:
assert_never(how)

def evaluate(self) -> DataFrame:
"""Evaluate and return a dataframe."""
left = self.left.evaluate()
right = self.right.evaluate()
left_on = DataFrame([e.evaluate(left) for e in self.left_on], [])
right_on = DataFrame([e.evaluate(right) for e in self.right_on], [])
how, join_nulls, zlice, suffix, coalesce = self.options
null_equality = (
plc.types.NullEquality.EQUAL
if join_nulls
else plc.types.NullEquality.UNEQUAL
)
suffix = "_right" if suffix is None else suffix
join_fn, left_policy, right_policy = Join._joiners(how)
if right_policy is None:
# Semi join
lg = join_fn(left_on.table, right_on.table, null_equality)
left = left.replace_columns(*left_on.columns)
table = plc.copying.gather(left.table, lg, left_policy)
result = DataFrame(
[
Column(c, col.name)
for col, c in zip(left.columns, table.columns())
],
[],
)
else:
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
left = left.replace_columns(*left_on.columns)
right = right.replace_columns(*right_on.columns)
if coalesce and how != "outer":
right = right.discard_columns(set(right_on.names))
left = DataFrame(
plc.copying.gather(left.table, lg, left_policy).columns(), []
)
right = DataFrame(
plc.copying.gather(right.table, rg, right_policy).columns(), []
)
if coalesce and how == "outer":
left = left.replace_columns(
*(
Column(
plc.replace.replace_nulls(left_col.obj, right_col.obj),
left_col.name,
)
for left_col, right_col in zip(
left.select_columns(set(left_on.names)),
right.select_columns(set(right_on.names)),
)
)
)
right = right.discard_columns(set(right_on.names))
right = right.rename_columns(
{name: f"{name}{suffix}" for name in right.names if name in left.names}
)
result = left.with_columns(*right.columns)
if zlice is not None:
raise NotImplementedError("slicing")
else:
return result


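One step in Join.evaluate above that is easy to miss: for a coalescing outer join, matching key columns from the two sides are merged by taking the left value unless it is null, via plc.replace.replace_nulls. A toy, GPU-free sketch of that coalescing rule, with plain Python lists standing in for columns and None for nulls:

def coalesce_keys(left_keys, right_keys):
    """Take the left key unless it is null (None), else the right key."""
    return [lv if lv is not None else rv for lv, rv in zip(left_keys, right_keys)]


# Rows matched on only one side carry a null key on the other side.
print(coalesce_keys([1, None, 3], [None, 2, 3]))  # [1, 2, 3]
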
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/translate.py
@@ -16,6 +16,8 @@


class set_node(AbstractContextManager):
"""Run a block with current node set in the visitor."""

__slots__ = ("n", "visitor")

def __init__(self, visitor, n):
@@ -242,7 +244,7 @@ def translate_expr(visitor: Any, *, n: int | pl_expr.PyExprIR) -> expr.Expr:
elif isinstance(node, pl_expr.Cast):
return expr.Cast(node.dtype, translate_expr(visitor, n=node.expr))
elif isinstance(node, pl_expr.Column):
- return expr.Column(node.name)
return expr.Col(node.name)
elif isinstance(node, pl_expr.Agg):
return expr.Agg(
translate_expr(visitor, n=node.arguments),
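For context, set_node above is a context manager that points the plan visitor at a particular node for the duration of a block and restores the previous node afterwards. A toy version of that save-and-restore pattern, with a plain object standing in for the polars visitor (attribute names here are illustrative, not the real visitor API):

from contextlib import AbstractContextManager


class ToyVisitor:
    """Stand-in for the polars plan visitor: tracks a current node id."""

    def __init__(self):
        self.current_node = None


class toy_set_node(AbstractContextManager):
    """Run a block with current_node temporarily set to n."""

    __slots__ = ("n", "visitor", "previous")

    def __init__(self, visitor, n):
        self.visitor = visitor
        self.n = n

    def __enter__(self):
        self.previous = self.visitor.current_node
        self.visitor.current_node = self.n

    def __exit__(self, *exc):
        self.visitor.current_node = self.previous
        return False


v = ToyVisitor()
with toy_set_node(v, 42):
    print(v.current_node)  # 42
print(v.current_node)  # None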
