Skip to content

Instantly share code, notes, and snippets.

@eliottrobson
Last active November 12, 2021 16:44
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eliottrobson/2a75aa0ef4def5e167c44b662ad55b70 to your computer and use it in GitHub Desktop.
Save eliottrobson/2a75aa0ef4def5e167c44b662ad55b70 to your computer and use it in GitHub Desktop.
Boolean Expression Reducer
/// <summary>
/// Reduces boolean expressions to improve performance with generated queries.
/// </summary>
/// <remarks>
/// Boolean member accesses that do not depend on a lambda parameter (captured
/// variables and static members) are evaluated once and replaced by constants.
/// Logical and/or nodes, conditionals, negations and binary operations whose
/// operands became constants are then simplified.
/// </remarks>
/// <example>
/// For example:
/// <code>
/// Expression&lt;Func&lt;User, FollowingData&gt;&gt; projection = u =&gt; new FollowingData
/// {
///     UserId = u.Id,
///     FollowersCount = u.Followers.Count,
///     IsFollowing = currentUserId.HasValue &amp;&amp; u.Following.Any(f =&gt; f.UserId == currentUserId.Value)
/// };
///
/// return await _databaseContext.Users
///     .Select(new BooleanExpressionReducer().Reduce(projection))
///     .ToListAsync();
/// </code>
/// will omit the <c>IsFollowing</c> query (and return false) if <c>currentUserId</c> is null.
/// </example>
public class BooleanExpressionReducer : ExpressionVisitor
{
    /// <summary>
    /// Returns a copy of <paramref name="expr"/> with its constant boolean parts reduced.
    /// </summary>
    /// <typeparam name="TSource">The lambda's parameter type.</typeparam>
    /// <typeparam name="TDest">The lambda's return type.</typeparam>
    /// <param name="expr">The lambda to reduce.</param>
    /// <returns>The reduced lambda, or <c>null</c> when <paramref name="expr"/> is <c>null</c>.</returns>
    public Expression<Func<TSource, TDest>> Reduce<TSource, TDest>(
        Expression<Func<TSource, TDest>> expr)
    {
        return (Expression<Func<TSource, TDest>>) Visit(expr);
    }

    /// <summary>
    /// Replaces a conditional whose test reduced to a constant with the selected branch.
    /// </summary>
    protected override Expression VisitConditional(ConditionalExpression node)
    {
        var test = Visit(node.Test);

        // The test is now a constant, so only the selected branch can ever be used.
        if (test is ConstantExpression constantTest && constantTest.Value is bool outcome)
        {
            return Visit(outcome ? node.IfTrue : node.IfFalse);
        }

        // Reuse the already visited test instead of letting the base class visit it again.
        return node.Update(test, Visit(node.IfTrue), Visit(node.IfFalse));
    }

    /// <summary>
    /// Evaluates boolean member accesses that are rooted in a constant (a captured
    /// variable) or in a static member, and replaces them with their value.
    /// </summary>
    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Type != typeof(bool))
        {
            return base.VisitMember(node);
        }

        // Walk down the member chain to whatever it is ultimately read from.
        Expression root = node;
        while (root is MemberExpression member)
        {
            root = member.Expression;
        }

        // A null root means the chain starts at a static member. Anything other than
        // that or a constant (e.g. a lambda parameter) cannot be evaluated here.
        if (root != null && root.NodeType != ExpressionType.Constant)
        {
            return base.VisitMember(node);
        }

        return Expression.Constant(Evaluate(node), node.Type);
    }

    /// <summary>
    /// Short-circuits boolean and/or nodes with a constant operand and folds
    /// supported binary operations whose operands are both constants.
    /// </summary>
    protected override Expression VisitBinary(BinaryExpression node)
    {
        var isAnd = node.NodeType == ExpressionType.And || node.NodeType == ExpressionType.AndAlso;
        var isOr = node.NodeType == ExpressionType.Or || node.NodeType == ExpressionType.OrElse;

        // Only plain bool logic is short-circuited: integer bitwise operators, lifted
        // (nullable) operators and user-defined operators keep their normal handling.
        var isBooleanLogic = (isAnd || isOr) && node.Type == typeof(bool) && node.Method == null;

        // The operand value that decides the result on its own: false for "and", true for "or".
        var absorbing = isOr;

        var left = Visit(node.Left);
        if (isBooleanLogic && IsBooleanConstant(left, absorbing))
        {
            // The right operand is deliberately not visited: it may only be safe to
            // evaluate when the left operand lets it through.
            return Expression.Constant(absorbing);
        }

        var right = Visit(node.Right);
        if (isBooleanLogic)
        {
            if (IsBooleanConstant(right, absorbing))
            {
                return Expression.Constant(absorbing);
            }

            // The other constant value is neutral, so the remaining operand is the result.
            if (IsBooleanConstant(left, !absorbing))
            {
                return right;
            }

            if (IsBooleanConstant(right, !absorbing))
            {
                return left;
            }
        }

        // Rebuild from the operands visited above so each subtree is visited only once.
        var updated = node.Update(
            left, VisitAndConvert(node.Conversion, nameof(VisitBinary)), right);

        if (left is ConstantExpression && right is ConstantExpression && IsFoldable(node.NodeType))
        {
            // Evaluating the node itself keeps expression-tree semantics and the node's type.
            return Expression.Constant(Evaluate(updated), node.Type);
        }

        return updated;
    }

    /// <summary>
    /// Folds a <see cref="ExpressionType.Not"/> applied to a constant operand.
    /// </summary>
    protected override Expression VisitUnary(UnaryExpression node)
    {
        var operand = Visit(node.Operand);
        var updated = node.Update(operand);

        if (node.NodeType == ExpressionType.Not && operand is ConstantExpression)
        {
            return Expression.Constant(Evaluate(updated), node.Type);
        }

        return updated;
    }

    /// <summary>
    /// Tells whether <paramref name="expression"/> is a constant holding the boolean <paramref name="expected"/>.
    /// </summary>
    private static bool IsBooleanConstant(Expression expression, bool expected)
    {
        return expression is ConstantExpression constant
               && constant.Value is bool actual
               && actual == expected;
    }

    /// <summary>
    /// Tells whether a binary node of the given type is folded when both operands are constants.
    /// </summary>
    private static bool IsFoldable(ExpressionType type)
    {
        switch (type)
        {
            case ExpressionType.Add:
            case ExpressionType.Divide:
            case ExpressionType.Modulo:
            case ExpressionType.Multiply:
            case ExpressionType.Power:
            case ExpressionType.Subtract:
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
                return true;
            default:
                return false;
        }
    }

    /// <summary>
    /// Compiles and runs a parameterless expression, returning its (boxed) value.
    /// </summary>
    private static object Evaluate(Expression expression)
    {
        var boxed = Expression.Convert(expression, typeof(object));
        return Expression.Lambda<Func<object>>(boxed).Compile().Invoke();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment