Skip to content

Commit

Permalink
Refactor function invocation generation
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Apr 18, 2024
1 parent a214297 commit 70c2d34
Showing 1 changed file with 38 additions and 32 deletions.
70 changes: 38 additions & 32 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,32 @@ def expression_from_sqlglot(self, sqlglot_node):
return self._parse_expression(sqlglot_node)

def _parse_expression(self, expr):
"""Parse a SQLGlot node and return a Substrait expression.
This is the internal implementation, expected to be
invoked in a recursive manner to parse the whole
expression tree.
"""
if isinstance(expr, sqlglot.expressions.Literal):
if expr.is_string:
return ParsedSubstraitExpression(
f"literal_{next(self._counter)}",
f"literal${next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
),
)
elif expr.is_int:
return ParsedSubstraitExpression(
f"literal_{next(self._counter)}",
f"literal${next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
),
)
elif sqlglot.helper.is_float(expr.name):
return ParsedSubstraitExpression(
f"literal_{next(self._counter)}",
f"literal${next(self._counter)}",
proto.Type(fp32=proto.Type.FP32()),
proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
Expand Down Expand Up @@ -144,11 +150,7 @@ def _parse_expression(self, expr):
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name,
argument_parsed_expr.type,
argument_parsed_expr.expression,
)
self._parse_function_invokation(function_name, argument_parsed_expr)
)
result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
Expand All @@ -163,11 +165,7 @@ def _parse_expression(self, expr):
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name,
left_parsed_expr.type,
left_parsed_expr.expression,
right_parsed_expr.type,
right_parsed_expr.expression,
function_name, left_parsed_expr, right_parsed_expr
)
)
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
Expand All @@ -185,24 +183,27 @@ def _parse_expression(self, expr):
)

def _parse_function_invokation(
self, function_name, left_type, left, right_type=None, right=None
self, function_name, argument_parsed_expr, *additional_arguments
):
binary = False
argtypes = [left_type]
if right_type or right:
binary = True
argtypes.append(right_type)
signature = self._functions_catalog.signature(function_name, argtypes)
"""Generates a Substrait function invokation expression.
The function invocation will be generated from the function name
and the arguments as ParsedSubstraitExpression.
Returns the function signature, the return type and the
invokation expression itself.
"""
arguments = [argument_parsed_expr] + list(additional_arguments)
signature = self._functions_catalog.signature(
function_name, proto_argtypes=[arg.type for arg in arguments]
)

try:
function_anchor = self._functions_catalog.function_anchor(signature)
except KeyError:
# No function found with the exact types, try any1_any1 version
# TODO: What about cases like i32_any1? What about any instead of any1?
if binary:
signature = f"{function_name}:any1_any1"
else:
signature = f"{function_name}:any1"
signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}"
function_anchor = self._functions_catalog.function_anchor(signature)

function_return_type = self._functions_catalog.function_return_type(signature)
Expand All @@ -216,20 +217,25 @@ def _parse_function_invokation(
proto.Expression(
scalar_function=proto.Expression.ScalarFunction(
function_reference=function_anchor,
arguments=(
[
proto.FunctionArgument(value=left),
proto.FunctionArgument(value=right),
]
if binary
else [proto.FunctionArgument(value=left)]
),
arguments=[
proto.FunctionArgument(value=arg.expression)
for arg in arguments
],
)
),
)


class ParsedSubstraitExpression:
"""A Substrait expression that was parsed from a SQLGlot node.
This stores the expression itself, with an associated output name
in case it is required to emit projections.
It also stores the type of the expression (i64, string, boolean, etc...)
and the functions that the expression in going to invoke.
"""

def __init__(self, output_name, type, expression, invoked_functions=None):
self.expression = expression
self.output_name = output_name
Expand Down

0 comments on commit 70c2d34

Please sign in to comment.