Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate CASE WHEN expressions lazily #66

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
using System.Linq.Expressions;

using NQuery.Symbols;

namespace NQuery.Tests.Evaluation
{
public sealed class EagerAndLazyTests
{
private static InvocationResult EvaluateAndCountInvocations(string text)
{
var invocationResult = new InvocationResult();
var invocationResultVariable = new VariableSymbol("ir", typeof(InvocationResult), invocationResult);
var nullInt32Function = new InvocationResultFunctionSymbol<int?>("NULL_INT32", NullInt32Function);
var nonNullInt32Function = new InvocationResultFunctionSymbol<int?>("NON_NULL_INT32", NonNullInt32Function);
var dataContext = DataContext.Default
.AddVariables(invocationResultVariable)
.AddFunctions(nullInt32Function, nonNullInt32Function);
var expression = Expression<object>.Create(dataContext, text);
invocationResult.Result = expression.Evaluate();
return invocationResult;
}

[Fact]
public void Evaluation_Conversion_Once()
{
var result = EvaluateAndCountInvocations("CAST(NON_NULL_INT32(ir) AS int64)");
Assert.Equal(42L, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_Unary_Once()
{
var result = EvaluateAndCountInvocations("~NON_NULL_INT32(ir)");
Assert.Equal(~42, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_Binary_EagerOnce()
{
var result = EvaluateAndCountInvocations("NULL_INT32(ir) + NON_NULL_INT32(ir)");
Assert.Null(result.Result);
Assert.Equal(1, result.NullInt32FunctionCount);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_FunctionInvocation_EagerOnce()
{
var result = EvaluateAndCountInvocations("SUBSTRING('abc', NULL_INT32(ir), NON_NULL_INT32(ir))");
Assert.Null(result.Result);
Assert.Equal(1, result.NullInt32FunctionCount);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_MethodInvocation_Instance_Once()
{
var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)");
Assert.Equal(true, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_MethodInvocation_Arguments_EagerOnce()
{
var result = EvaluateAndCountInvocations("''.Substring(NULL_INT32(ir), NON_NULL_INT32(ir))");
Assert.Null(result.Result);
Assert.Equal(1, result.NullInt32FunctionCount);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_PropertyAccess_Once()
{
var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)");
Assert.Equal(true, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_IsNull_Once()
{
var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir) IS NOT NULL");
Assert.Equal(true, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
}

[Fact]
public void Evaluation_CaseWhen_NonNullFunction_LazyOnce()
{
const string text = @"
CASE
WHEN NON_NULL_INT32(ir) = 42 THEN 42
ELSE NULL_INT32(ir)
END";

var result = EvaluateAndCountInvocations(text);
Assert.Equal(42, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
Assert.Equal(0, result.NullInt32FunctionCount);
}

[Fact]
public void Evaluation_CaseWhen_NonNullNestedFunction_LazyOnce()
{
const string text = @"
CASE
WHEN TO_INT32(NON_NULL_INT32(ir)) = 42 THEN 42
WHEN TO_INT32(NON_NULL_INT32(ir)) != 42 THEN 0
ELSE TO_INT32(NULL_INT32(ir))
END";

var result = EvaluateAndCountInvocations(text);
Assert.Equal(42, result.Result);
Assert.Equal(1, result.NonNullInt32FunctionCount);
Assert.Equal(0, result.NullInt32FunctionCount);
}

[Fact]
public void Evaluation_CaseWhen_NullFunction_LazyOnce()
{
const string text = @"
CASE
WHEN NULL_INT32(ir) = 0 THEN 42
ELSE 0
END";

var result = EvaluateAndCountInvocations(text);
Assert.Equal(0, result.Result);
Assert.Equal(1, result.NullInt32FunctionCount);
}

private static int? NullInt32Function(InvocationResult ir)
{
ir.NullInt32FunctionCount++;
return null;
}

private static int? NonNullInt32Function(InvocationResult ir)
{
ir.NonNullInt32FunctionCount++;
return 42;
}

private sealed class InvocationResult
{
public object Result { get; set; }
public int NullInt32FunctionCount { get; set; }
public int NonNullInt32FunctionCount { get; set; }
}

private sealed class InvocationResultFunctionSymbol<TResult> : FunctionSymbol
{
public InvocationResultFunctionSymbol(string name, Func<InvocationResult, TResult> function)
: base(name, typeof(TResult).GetNonNullableType(), new ParameterSymbol("ir", typeof(InvocationResult)))
{
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
{
return Expression.Call(Function.Method, arguments);
}

private Func<InvocationResult, TResult> Function { get; }
}
}
}
42 changes: 26 additions & 16 deletions src/NQuery/Iterators/ExpressionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ public static IteratorPredicate BuildIteratorPredicate(BoundExpression predicate
return BuildExpression<IteratorPredicate>(predicate, typeof(bool), allocation);
}

private static TDelegate BuildExpression<TDelegate>(BoundExpression expression, Type targetType, RowBufferAllocation allocation)
private static TDelegate BuildExpression<TDelegate>(BoundExpression expression, Type targetType, RowBufferAllocation allocation) where TDelegate : Delegate
{
var lambda = BuildExpression(expression, typeof(TDelegate), targetType, allocation);
return (TDelegate)lambda.Compile();
}

private static LambdaExpression BuildExpression(BoundExpression expression, Type delegateType, Type targetType, RowBufferAllocation allocation)
{
var builder = new ExpressionBuilder(allocation);
return builder.BuildExpression<TDelegate>(expression, targetType);
return builder.BuildExpression(expression, delegateType, targetType);
}

private ParameterExpression BuildCachedExpression(BoundExpression expression)
private Expression BuildCachedExpression(BoundExpression expression)
{
var result = BuildExpression(expression);
var liftedExpression = BuildLiftedExpression(result);
Expand All @@ -63,11 +69,6 @@ private static Expression BuildLiftedExpression(Expression result)
: Expression.Convert(result, result.Type.GetNullableType());
}

private Expression BuildLiftedExpression(BoundExpression expression)
{
return BuildLiftedExpression(BuildExpression(expression));
}

private static Expression BuildLoweredExpression(Expression expression)
{
if (!expression.Type.IsNullableOfT())
Expand Down Expand Up @@ -113,7 +114,7 @@ private static Expression BuildInvocation(MethodSymbol methodSymbol, Expression
return
BuildLiftedExpression(
methodSymbol.CreateInvocation(
BuildLoweredExpression(instance),
BuildLoweredExpression(instance),
arguments.Select(BuildLoweredExpression)
)
);
Expand Down Expand Up @@ -144,17 +145,17 @@ private static UnaryExpression BuildNullableTrue()
return Expression.Convert(Expression.Constant(true), typeof(bool?));
}

private TDelegate BuildExpression<TDelegate>(BoundExpression expression, Type targetType)
private LambdaExpression BuildExpression(BoundExpression expression, Type delegateType, Type targetType)
{
var actualExpression = BuildCachedExpression(expression);
var coalescedExpression = targetType.CanBeNull()
? (Expression)actualExpression
? actualExpression
: Expression.Coalesce(actualExpression, Expression.Default(targetType));
var resultExpression = Expression.Convert(coalescedExpression, targetType);
var expressions = _assignments.Concat(new[] { resultExpression });
var body = Expression.Block(_locals, expressions);
var lambda = Expression.Lambda<TDelegate>(body);
return lambda.Compile();
var lambda = Expression.Lambda(delegateType, body);
return lambda;
}

private Expression BuildExpression(BoundExpression expression)
Expand Down Expand Up @@ -455,7 +456,7 @@ private Expression BuildCaseLabel(BoundCaseExpression caseExpression, int caseLa
if (caseLabelIndex == caseExpression.CaseLabels.Length)
return caseExpression.ElseExpression is null
? BuildNullValue(caseExpression.Type)
: BuildLiftedExpression(caseExpression.ElseExpression);
: BuildNestedScopeInvocation(caseExpression.ElseExpression);

var caseLabel = caseExpression.CaseLabels[caseLabelIndex];
var condition = caseLabel.Condition;
Expand All @@ -464,12 +465,21 @@ private Expression BuildCaseLabel(BoundCaseExpression caseExpression, int caseLa
return
Expression.Condition(
Expression.Equal(
BuildLiftedExpression(condition),
BuildNestedScopeInvocation(condition),
BuildNullableTrue()
),
BuildLiftedExpression(result),
BuildNestedScopeInvocation(result),
BuildCaseLabel(caseExpression, caseLabelIndex + 1)
);
}

private Expression BuildNestedScopeInvocation(BoundExpression expression)
{
var targetType = expression.Type;
var delegateType = typeof(Func<>).MakeGenericType(targetType);
var lambda = BuildExpression(expression, delegateType, targetType, _rowBufferAllocation);
var invocation = Expression.Invoke(lambda);
return BuildLiftedExpression(invocation);
}
}
}
Loading