Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Report all unsupported operations for a query in cudf.polars #16960

Draft
wants to merge 4 commits into
base: branch-24.12
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def execute_with_cudf(
device = config.device
memory_resource = config.memory_resource
raise_on_fail = config.config.get("raise_on_fail", False)
if unsupported := (config.config.keys() - {"raise_on_fail"}):
debug_mode = config.config.get("debug_mode", False)
if unsupported := (config.config.keys() - {"raise_on_fail", "debug_mode"}):
raise ValueError(
f"Engine configuration contains unsupported settings {unsupported}"
)
Expand All @@ -183,7 +184,10 @@ def execute_with_cudf(
nt.set_udf(
partial(
_callback,
translate_ir(nt),
translate_ir(
nt,
debug_mode=1 if debug_mode else 0,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass a more general config object?

),
device=device,
memory_resource=memory_resource,
)
Expand All @@ -197,3 +201,5 @@ def execute_with_cudf(
)
if raise_on_fail:
raise
if debug_mode:
raise
5 changes: 5 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
) # pragma: no cover


@dataclasses.dataclass
class ErrorNode(IR):
    """
    IR node recording an operation that could not be translated.

    The translation code constructs this node (with the schema and an
    error message) in place of raising ``NotImplementedError`` when its
    ``debug_mode`` flag is set.
    """

    # Message describing the unsupported operation.
    error: str


@dataclasses.dataclass
class PythonScan(IR):
"""Representation of input from a python function."""
Expand Down
135 changes: 113 additions & 22 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@
__all__ = ["translate_ir", "translate_named_expr"]


def debug(func):
    """
    Wrap a translation function to report unsupported operations.

    Parameters
    ----------
    func
        Callable to wrap. Its first positional argument must expose
        ``get_schema()`` whenever ``debug_mode`` is requested.

    Returns
    -------
    Wrapper carrying the metadata of ``func``. It forwards all
    arguments unchanged; if ``func`` raises ``NotImplementedError`` and
    the call was made with a truthy ``debug_mode`` keyword argument, an
    ``ir.ErrorNode`` holding the error message is returned instead.

    Raises
    ------
    NotImplementedError
        Re-raised from ``func`` when ``debug_mode`` is absent or falsy.

    Notes
    -----
    NOTE(review): ``singledispatch.register`` returns the function it
    was given, so stacking ``@debug`` above ``@_translate_ir.register``
    wraps only the module-level name, not the registered
    implementation -- confirm the intended decorator order.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except NotImplementedError as e:
            # Only a keyword ``debug_mode`` is detected here; a value
            # passed positionally is not inspected.
            if kwargs.get("debug_mode", False):
                # NOTE(review): assumes args[0] has ``get_schema()`` (true
                # for the visitor given to ``translate_ir``) -- confirm for
                # any other wrapped callable.
                # ``ErrorNode.error`` is declared as ``str``, so store the
                # message rather than the exception object.
                return ir.ErrorNode(args[0].get_schema(), str(e))
            raise

    return wrapper


class set_node(AbstractContextManager[None]):
"""
Run a block with current node set in the visitor.
Expand Down Expand Up @@ -64,16 +76,24 @@ def __exit__(self, *args: Any) -> None:

@singledispatch
def _translate_ir(
node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: Any,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
raise NotImplementedError(
f"Translation for {type(node).__name__}"
) # pragma: no cover
e = f"Translation for {type(node).__name__}"
if debug_mode:
return ir.ErrorNode(schema, e)
raise NotImplementedError(e) # pragma: no cover


@debug
@_translate_ir.register
def _(
node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.PythonScan,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
scan_fn, with_columns, source_type, predicate, nrows = node.options
options = (scan_fn, with_columns, source_type, nrows)
Expand All @@ -83,9 +103,13 @@ def _(
return ir.PythonScan(schema, options, predicate)


@debug
@_translate_ir.register
def _(
node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Scan,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
typ, *options = node.scan_type
if typ == "ndjson":
Expand Down Expand Up @@ -120,16 +144,24 @@ def _(
)


@debug
@_translate_ir.register
def _(
node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Cache,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Cache(schema, node.id_, translate_ir(visitor, n=node.input))


@debug
@_translate_ir.register
def _(
node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.DataFrameScan,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.DataFrameScan(
schema,
Expand All @@ -141,19 +173,27 @@ def _(
)


@debug
@_translate_ir.register
def _(
node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Select,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
exprs = [translate_named_expr(visitor, n=e) for e in node.expr]
return ir.Select(schema, inp, exprs, node.should_broadcast)


@debug
@_translate_ir.register
def _(
node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.GroupBy,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -169,9 +209,13 @@ def _(
)


@debug
@_translate_ir.register
def _(
node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Join,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
# Join key dtypes are dependent on the schema of the left and
# right inputs, so these must be translated with the relevant
Expand All @@ -185,29 +229,41 @@ def _(
return ir.Join(schema, inp_left, inp_right, left_on, right_on, node.options)


@debug
@_translate_ir.register
def _(
node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.HStack,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
exprs = [translate_named_expr(visitor, n=e) for e in node.exprs]
return ir.HStack(schema, inp, exprs, node.should_broadcast)


@debug
@_translate_ir.register
def _(
node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Reduce,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
exprs = [translate_named_expr(visitor, n=e) for e in node.expr]
return ir.Reduce(schema, inp, exprs)


@debug
@_translate_ir.register
def _(
node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Distinct,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Distinct(
schema,
Expand All @@ -216,45 +272,63 @@ def _(
)


@debug
@_translate_ir.register
def _(
node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Sort,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
by = [translate_named_expr(visitor, n=e) for e in node.by_column]
return ir.Sort(schema, inp, by, node.sort_options, node.slice)


@debug
@_translate_ir.register
def _(
node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Slice,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Slice(schema, translate_ir(visitor, n=node.input), node.offset, node.len)


@debug
@_translate_ir.register
def _(
node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Filter,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
mask = translate_named_expr(visitor, n=node.predicate)
return ir.Filter(schema, inp, mask)


@debug
@_translate_ir.register
def _(
node: pl_ir.SimpleProjection,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Projection(schema, translate_ir(visitor, n=node.input))


@debug
@_translate_ir.register
def _(
node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.MapFunction,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
name, *options = node.function
return ir.MapFunction(
Expand All @@ -266,23 +340,37 @@ def _(
)


@debug
@_translate_ir.register
def _(
node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Union,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Union(
schema, [translate_ir(visitor, n=n) for n in node.inputs], node.options
)


@debug
@_translate_ir.register
def _(
node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.HConcat,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.HConcat(schema, [translate_ir(visitor, n=n) for n in node.inputs])


def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
@debug
def translate_ir(
visitor: NodeTraverser,
*,
n: int | None = None,
debug_mode: int = 0,
) -> ir.IR:
"""
Translate a polars-internal IR node to our representation.

Expand All @@ -293,6 +381,9 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
n
Optional node to start traversing from, if not provided uses
current polars-internal node.
debug_mode
Optional: If true returns an ErrorNode in the IR that is used to
report unsupported operations in the query

Returns
-------
Expand All @@ -319,7 +410,7 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
polars_schema = visitor.get_schema()
node = visitor.view_current_node()
schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()}
result = _translate_ir(node, visitor, schema)
result = _translate_ir(node, visitor, schema, debug_mode=debug_mode)
if any(
isinstance(dtype, pl.Null)
for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values())
Expand Down
Loading