Last active
November 12, 2021 16:44
-
-
Save eliottrobson/2a75aa0ef4def5e167c44b662ad55b70 to your computer and use it in GitHub Desktop.
Boolean Expression Reducer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Reduces boolean expressions to improve performance with generated queries.
/// </summary>
/// <remarks>
/// <para>
/// Boolean member accesses whose member chain starts at a constant (such as a compiler-generated closure
/// holding a captured local) or at a static member are evaluated while the tree is visited and replaced by
/// constants. Non-boolean members are left untouched. Values are therefore read at the time
/// <see cref="Reduce{TSource, TDest}"/> is called.
/// </para>
/// <para>
/// The constants are then propagated as follows:
/// <list type="bullet">
/// <item>Logical negation of a constant bool is folded.</item>
/// <item><c>And</c>, <c>AndAlso</c>, <c>Or</c> and <c>OrElse</c> on bool operands are simplified when either
/// side is constant.</item>
/// <item>Binary operators with two constant operands are folded.</item>
/// <item>Conditionals with a constant test are replaced by the surviving branch.</item>
/// </list>
/// Operands made irrelevant by a constant are dropped, so this is intended for side-effect-free query
/// expressions.
/// </para>
/// </remarks>
/// <example>
/// For example:
/// <code><![CDATA[
/// Expression<Func<User, FollowingData>> projection = u => new FollowingData
/// {
///     UserId = u.Id,
///     FollowersCount = u.Followers.Count,
///     IsFollowing = currentUserId.HasValue && u.Following.Any(f => f.UserId == currentUserId.Value)
/// };
///
/// return await _databaseContext.Users
///     .Select(new BooleanExpressionReducer().Reduce(projection))
///     .ToListAsync();
/// ]]></code>
/// will omit the <c>IsFollowing</c> query (and return <c>false</c>) if <c>currentUserId</c> is null.
/// </example>
public class BooleanExpressionReducer : ExpressionVisitor
{
    /// <summary>
    /// Returns a copy of <paramref name="expr"/> in which the boolean sub-expressions this visitor can
    /// decide have been replaced by their values.
    /// </summary>
    /// <typeparam name="TSource">Type of the lambda's parameter.</typeparam>
    /// <typeparam name="TDest">Return type of the lambda.</typeparam>
    /// <param name="expr">The lambda to reduce.</param>
    /// <returns>The reduced lambda, with the same delegate type as <paramref name="expr"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expr"/> is <c>null</c>.</exception>
    public Expression<Func<TSource, TDest>> Reduce<TSource, TDest>(
        Expression<Func<TSource, TDest>> expr)
    {
        if (expr == null)
            throw new ArgumentNullException(nameof(expr));

        // VisitLambda<T> rebuilds an Expression<T> of the same delegate type, so a hard cast cannot fail.
        // A failure here should surface, not be hidden behind a null return.
        return (Expression<Func<TSource, TDest>>) Visit(expr);
    }

    /// <summary>
    /// Replaces a conditional whose test reduces to a constant with the branch that would be taken.
    /// Otherwise the conditional is rebuilt from its reduced parts.
    /// </summary>
    protected override Expression VisitConditional(ConditionalExpression node)
    {
        var test = Visit(node.Test);

        if (test is ConstantExpression testConst && testConst.Value is bool testValue)
        {
            // Only the surviving branch is visited; the other one can never be evaluated.
            var branch = Visit(testValue ? node.IfTrue : node.IfFalse);

            // A conditional built with an explicit type may have branches of a different type. Keep the
            // conditional node in that case so the parent still sees node.Type.
            if (branch.Type == node.Type)
                return branch;
        }

        // The test is not constant: reuse the already-visited test instead of visiting it a second time.
        return node.Update(test, Visit(node.IfTrue), Visit(node.IfFalse));
    }

    /// <summary>
    /// Evaluates a boolean member access (for example <c>currentUserId.HasValue</c> on a captured local)
    /// whose chain starts at a constant or a static member, and replaces it with a constant.
    /// </summary>
    protected override Expression VisitMember(MemberExpression node)
    {
        // Only bool members are pre-evaluated; other captured values are left for the query provider.
        if (node.Type != typeof(bool) || !IsRootedInConstantOrStatic(node))
            return base.VisitMember(node);

        var getter = Expression.Lambda<Func<bool>>(node).Compile();

        bool value;
        try
        {
            value = getter();
        }
        catch (NullReferenceException)
        {
            // A reference in the chain is null. This is typically behind a guard this visitor cannot fold
            // (e.g. `user != null && user.IsAdmin` with a captured null `user`). Keep the member access so
            // the guard still protects it when the expression is compiled or translated.
            return base.VisitMember(node);
        }

        return Expression.Constant(value, typeof(bool));
    }

    /// <summary>
    /// Simplifies boolean logic with a constant operand and folds operators whose operands are both constant.
    /// </summary>
    protected override Expression VisitBinary(BinaryExpression node)
    {
        // Special optimisations for boolean expressions.
        var optimised = OptimiseBooleanBinaryExpression(node.NodeType, node.Left, node.Right);
        if (optimised.Result != null)
            return optimised.Result;

        var left = optimised.LeftVisit;
        var right = optimised.RightVisit;

        if (left is ConstantExpression leftConst && right is ConstantExpression rightConst)
        {
            var folded = FoldConstantBinary(node, leftConst.Value, rightConst.Value);
            if (folded != null)
                return folded;
        }

        // Rebuild from the operands visited above. base.VisitBinary would visit them (and re-run every
        // member getter) a second time.
        var conversion = VisitAndConvert(node.Conversion, nameof(VisitBinary));
        return node.Update(left, conversion, right);
    }

    /// <summary>
    /// Folds logical negation of a constant bool; every other unary node is rebuilt from its visited operand.
    /// </summary>
    protected override Expression VisitUnary(UnaryExpression node)
    {
        var operand = Visit(node.Operand);

        // ExpressionType.Not on an integral operand is a bitwise complement, so only bool operands are folded.
        if (node.NodeType == ExpressionType.Not
            && operand is ConstantExpression operandConst
            && operandConst.Value is bool operandValue)
        {
            return Expression.Constant(!operandValue, node.Type);
        }

        return node.Update(operand);
    }

    /// <summary>
    /// Returns true when <paramref name="node"/> is a chain of member accesses that starts at a constant
    /// or at a static member, i.e. it does not involve any lambda parameter.
    /// </summary>
    private static bool IsRootedInConstantOrStatic(MemberExpression node)
    {
        Expression root = node;
        while (root is MemberExpression member)
            root = member.Expression;

        // A null root means the innermost member of the chain is static.
        return root == null || root.NodeType == ExpressionType.Constant;
    }

    /// <summary>
    /// Evaluates <paramref name="node"/> with two constant operands.
    /// </summary>
    /// <returns>
    /// A constant of type <c>node.Type</c>, or <c>null</c> when the operator is not folded or the computed
    /// value cannot be represented as <c>node.Type</c>.
    /// </returns>
    private static Expression FoldConstantBinary(BinaryExpression node, object leftValue, object rightValue)
    {
        // dynamic dispatch reuses C#'s operator semantics, including user-defined operators.
        dynamic left = leftValue;
        dynamic right = rightValue;
        object result;

        switch (node.NodeType)
        {
            case ExpressionType.Add:
                result = left + right;
                break;
            case ExpressionType.Subtract:
                result = left - right;
                break;
            case ExpressionType.Multiply:
                result = left * right;
                break;
            case ExpressionType.Divide:
                result = left / right;
                break;
            case ExpressionType.Modulo:
                result = left % right;
                break;
            case ExpressionType.Power:
                // ExpressionType.Power is exponentiation; C#'s ^ operator is exclusive-or.
                result = leftValue == null || rightValue == null
                    ? null
                    : (object) Math.Pow((double) left, (double) right);
                break;
            case ExpressionType.And:
                result = left & right;
                break;
            case ExpressionType.Or:
                result = left | right;
                break;
            case ExpressionType.ExclusiveOr:
                result = left ^ right;
                break;
            case ExpressionType.AndAlso:
                result = left && right;
                break;
            case ExpressionType.OrElse:
                result = left || right;
                break;
            case ExpressionType.Equal:
                result = left == right;
                break;
            case ExpressionType.NotEqual:
                result = left != right;
                break;
            case ExpressionType.GreaterThan:
                result = left > right;
                break;
            case ExpressionType.GreaterThanOrEqual:
                result = left >= right;
                break;
            case ExpressionType.LessThan:
                result = left < right;
                break;
            case ExpressionType.LessThanOrEqual:
                result = left <= right;
                break;
            default:
                return null;
        }

        // Keep node.Type (e.g. a lifted int?) so the parent node can still be rebuilt around the constant.
        var fits = result == null
            ? !node.Type.IsValueType || Nullable.GetUnderlyingType(node.Type) != null
            : node.Type.IsInstanceOfType(result);

        return fits ? Expression.Constant(result, node.Type) : null;
    }

    /// <summary>
    /// Visits the operands of a binary node. For bool And/AndAlso/Or/OrElse, it short-circuits when a constant
    /// operand decides the result or makes that operand redundant.
    /// </summary>
    private OptimisedBooleanBinary OptimiseBooleanBinaryExpression(
        ExpressionType type, Expression left, Expression right)
    {
        Expression leftVisited = null;
        Expression rightVisited = null;

        // Operands are visited lazily and at most once, so a side made irrelevant by the other side is never
        // visited (and its members are never evaluated).
        Expression VisitLeft() => leftVisited ?? (leftVisited = Visit(left));
        Expression VisitRight() => rightVisited ?? (rightVisited = Visit(right));
        bool? LeftConstant() => (VisitLeft() as ConstantExpression)?.Value as bool?;
        bool? RightConstant() => (VisitRight() as ConstantExpression)?.Value as bool?;

        // These shortcuts are only valid for plain bool logic; on integral or enum operands, And/Or are
        // bitwise and must not be compared against true/false.
        if (left.Type == typeof(bool) && right.Type == typeof(bool))
        {
            switch (type)
            {
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                    if (LeftConstant() == false || RightConstant() == false)
                        return new OptimisedBooleanBinary(false);
                    if (LeftConstant() == true)
                        return new OptimisedBooleanBinary(VisitRight());
                    if (RightConstant() == true)
                        return new OptimisedBooleanBinary(VisitLeft());
                    break;

                case ExpressionType.Or:
                case ExpressionType.OrElse:
                    if (LeftConstant() == true || RightConstant() == true)
                        return new OptimisedBooleanBinary(true);
                    if (LeftConstant() == false)
                        return new OptimisedBooleanBinary(VisitRight());
                    if (RightConstant() == false)
                        return new OptimisedBooleanBinary(VisitLeft());
                    break;
            }
        }

        return new OptimisedBooleanBinary
        {
            LeftVisit = VisitLeft(),
            RightVisit = VisitRight()
        };
    }

    /// <summary>
    /// Outcome of <see cref="OptimiseBooleanBinaryExpression"/>. It holds either a final
    /// <see cref="Result"/> or the visited operands for the caller to fold or rebuild.
    /// </summary>
    private class OptimisedBooleanBinary
    {
        public OptimisedBooleanBinary()
        {
        }

        public OptimisedBooleanBinary(Expression result)
        {
            Result = result;
        }

        public OptimisedBooleanBinary(bool result)
        {
            Result = Expression.Constant(result);
        }

        /// <summary>Replacement for the whole binary node, or null when no shortcut applied.</summary>
        public Expression Result { get; set; }

        /// <summary>The visited left operand (set when <see cref="Result"/> is null).</summary>
        public Expression LeftVisit { get; set; }

        /// <summary>The visited right operand (set when <see cref="Result"/> is null).</summary>
        public Expression RightVisit { get; set; }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment