Skip to content

Instantly share code, notes, and snippets.

@lasandell
Created April 10, 2013 17:03
Show Gist options
  • Save lasandell/5356474 to your computer and use it in GitHub Desktop.
Save lasandell/5356474 to your computer and use it in GitHub Desktop.
Class to compare two lambda expressions so that methods taking expressions as arguments can be mocked with NSubstitute.
/// <summary>
/// Entry point for matching lambda-expression arguments, so that methods
/// which take an expression as an argument can be mocked with NSubstitute.
/// </summary>
/// <example>
/// _ordersRepository.Find(Expr.Is&lt;Order&gt;(o => o.Id == 2)).Returns(order2)
/// </example>
public static class Expr
{
    /// <summary>
    /// Builds an NSubstitute argument specification that accepts any received
    /// expression which ExpressionComparer considers equivalent to
    /// <paramref name="expression"/>.
    /// </summary>
    /// <param name="expression">The expression a received argument must match.</param>
    /// <returns>The value produced by NSubstitute's Arg.Is for this predicate.</returns>
    public static Expression<Func<TArg, bool>> Is<TArg>(Expression<Func<TArg, bool>> expression)
    {
        // Hand the equivalence check to NSubstitute's Arg.Is(); the predicate
        // is evaluated against each expression the substitute receives.
        return Arg.Is(
            (Expression<Func<TArg, bool>> candidate) => ExpressionComparer.Equals(expression, candidate));
    }
}
/// <summary>
/// Helper to determine if two expressions are equivalent.
/// Currently only understands a subset of the Expression types,
/// and does not know about commutativity, associativity, etc. It only
/// compares the inherent properties of each node itself and not
/// its children, since they will be compared separately.
/// </summary>
internal static class ExpressionComparer
{
    /// <summary>
    /// Determines whether two lambda expressions are structurally equivalent.
    /// Both bodies are reduced (captured values become constants) and flattened,
    /// then compared node by node at matching positions. Parameters are matched
    /// by position rather than by name, so <c>o => o.Id == 2</c> equals
    /// <c>x => x.Id == 2</c>.
    /// </summary>
    /// <param name="left">First expression; may be null.</param>
    /// <param name="right">Second expression; may be null.</param>
    /// <returns>
    /// True when both are null, are the same instance, or are equivalent node by
    /// node; otherwise false.
    /// </returns>
    /// <exception cref="NotSupportedException">
    /// A node type this comparer does not understand was reached before any
    /// mismatch was found.
    /// </exception>
    public static bool Equals<TDelegate>(Expression<TDelegate> left, Expression<TDelegate> right)
    {
        if (object.ReferenceEquals(left, right))
        {
            return true;
        }
        if (left == null || right == null)
        {
            // NSubstitute may hand a null argument to the matcher predicate.
            return false;
        }

        var leftNodes = ExpressionExplorer.GetSubExpressions(ConstantReducer.Reduce(left.Body));
        var rightNodes = ExpressionExplorer.GetSubExpressions(ConstantReducer.Reduce(right.Body));
        if (leftNodes.Count != rightNodes.Count)
        {
            return false;
        }

        var parameters = new ParameterMap(left.Parameters, right.Parameters);
        for (var i = 0; i < leftNodes.Count; i++)
        {
            // Stop at the first mismatch, so unsupported nodes after it never throw.
            if (!NodeEquals(leftNodes[i], rightNodes[i], parameters))
            {
                return false;
            }
        }
        return true;
    }

    /// <summary>
    /// Compares the inherent properties of one pair of nodes taken from the
    /// same position in both flattened trees. Children are not inspected.
    /// </summary>
    private static bool NodeEquals(Expression left, Expression right, ParameterMap parameters)
    {
        // The explorer records null for an absent child (e.g. the instance of a
        // static method call); both sides must agree on that absence.
        if (left == null || right == null)
        {
            return left == right;
        }
        if (left.NodeType != right.NodeType || left.Type != right.Type)
        {
            return false;
        }

        var constantLeft = left as ConstantExpression;
        if (constantLeft != null)
        {
            return object.Equals(constantLeft.Value, ((ConstantExpression)right).Value);
        }

        var binaryLeft = left as BinaryExpression;
        if (binaryLeft != null)
        {
            var binaryRight = (BinaryExpression)right;
            // A conversion lambda (coalesce) is visited as a child and compared
            // on its own; only whether one is present is inherent to this node.
            return binaryLeft.Method == binaryRight.Method
                && binaryLeft.IsLiftedToNull == binaryRight.IsLiftedToNull
                && (binaryLeft.Conversion == null) == (binaryRight.Conversion == null);
        }

        var unaryLeft = left as UnaryExpression;
        if (unaryLeft != null)
        {
            var unaryRight = (UnaryExpression)right;
            return unaryLeft.Method == unaryRight.Method
                && unaryLeft.IsLiftedToNull == unaryRight.IsLiftedToNull;
        }

        var memberLeft = left as MemberExpression;
        if (memberLeft != null)
        {
            return memberLeft.Member == ((MemberExpression)right).Member;
        }

        var callLeft = left as MethodCallExpression;
        if (callLeft != null)
        {
            // The instance and arguments are children and are compared separately.
            return callLeft.Method == ((MethodCallExpression)right).Method;
        }

        var typeTestLeft = left as TypeBinaryExpression;
        if (typeTestLeft != null)
        {
            return typeTestLeft.TypeOperand == ((TypeBinaryExpression)right).TypeOperand;
        }

        var newLeft = left as NewExpression;
        if (newLeft != null)
        {
            return newLeft.Constructor == ((NewExpression)right).Constructor;
        }

        var paramLeft = left as ParameterExpression;
        if (paramLeft != null)
        {
            return parameters.TryMatch(paramLeft, (ParameterExpression)right);
        }

        // Lambdas and conditionals carry nothing beyond their type (compared
        // above) and their children.
        if (left.NodeType == ExpressionType.Lambda || left.NodeType == ExpressionType.Conditional)
        {
            return true;
        }

        throw new NotSupportedException("Don't know how to compare these expressions.");
    }

    /// <summary>
    /// Tracks a one-to-one correspondence between parameters of the left tree
    /// and parameters of the right tree. The map is seeded with the outer lambda
    /// parameters by position. Parameters of nested lambdas are paired the
    /// first time they are seen.
    /// </summary>
    private sealed class ParameterMap
    {
        private readonly Dictionary<ParameterExpression, ParameterExpression> _leftToRight =
            new Dictionary<ParameterExpression, ParameterExpression>();

        private readonly Dictionary<ParameterExpression, ParameterExpression> _rightToLeft =
            new Dictionary<ParameterExpression, ParameterExpression>();

        public ParameterMap(
            ReadOnlyCollection<ParameterExpression> leftParams,
            ReadOnlyCollection<ParameterExpression> rightParams)
        {
            // Both lambdas share TDelegate, so their parameter counts agree.
            for (var i = 0; i < leftParams.Count; i++)
            {
                TryMatch(leftParams[i], rightParams[i]);
            }
        }

        /// <summary>
        /// Returns true if the two parameters correspond, recording the pairing
        /// when neither has been seen yet.
        /// </summary>
        public bool TryMatch(ParameterExpression leftParam, ParameterExpression rightParam)
        {
            ParameterExpression mapped;
            if (_leftToRight.TryGetValue(leftParam, out mapped))
            {
                return mapped == rightParam;
            }
            if (_rightToLeft.ContainsKey(rightParam))
            {
                // rightParam is already paired with a different left parameter.
                return false;
            }
            _leftToRight.Add(leftParam, rightParam);
            _rightToLeft.Add(rightParam, leftParam);
            return true;
        }
    }
}
/// <summary>
/// Helper to "flatten" an expression tree into a list of the nodes it
/// contains. Nodes are listed in pre-order: each node appears before its
/// children, and the children follow in the order ExpressionVisitor visits
/// them. When the base visitor hands over a null child (for example the
/// target of a static member access), a null entry is recorded in its place.
/// </summary>
internal class ExpressionExplorer : ExpressionVisitor
{
    private readonly List<Expression> _nodes = new List<Expression>();

    private ExpressionExplorer()
    {
    }

    /// <summary>
    /// Flattens <paramref name="expression"/> and every node beneath it.
    /// </summary>
    /// <param name="expression">Root of the tree to flatten.</param>
    /// <returns>The visited nodes, root first.</returns>
    public static List<Expression> GetSubExpressions(Expression expression)
    {
        var explorer = new ExpressionExplorer();
        explorer.Visit(expression);
        return explorer._nodes;
    }

    /// <summary>
    /// Records the node (possibly null), then lets the base visitor descend
    /// into its children.
    /// </summary>
    public override Expression Visit(Expression node)
    {
        _nodes.Add(node);
        return base.Visit(node);
    }
}
/// <summary>
/// Helper to replace subexpressions that can be evaluated
/// immediately with equivalent constants. This is needed because an
/// expression like o => o.Id == orderId will reference orderId on a
/// generated closure type rather than the constant value of orderId
/// when the expression was created, which would cause our comparisons
/// to fail.
/// </summary>
/// <remarks>
/// Lambda and Quote nodes are never collapsed. Evaluating them would produce a
/// freshly compiled delegate (or expression instance) that can never compare
/// equal to its counterpart, so nested lambdas stay as trees and only their
/// bodies are reduced.
/// </remarks>
internal class ConstantReducer : ExpressionVisitor
{
    private ConstantReducer()
    {
    }

    /// <summary>
    /// Returns <paramref name="expression"/> with every subexpression that
    /// compiles and evaluates successfully on its own replaced by a constant
    /// holding its value. Evaluation runs the subexpression, so any side effects
    /// of closed method calls happen during this call.
    /// </summary>
    /// <param name="expression">The expression to reduce; may be null.</param>
    public static Expression Reduce(Expression expression)
    {
        return new ConstantReducer().Visit(expression);
    }

    // Presumably marked non-user code so the debugger (Just My Code) does not
    // stop on the exceptions that are expected and swallowed below.
    [DebuggerNonUserCode]
    public override Expression Visit(Expression expr)
    {
        // ExpressionVisitor passes null for absent children, e.g. the instance
        // of a static method call; there is nothing to reduce.
        if (expr == null)
        {
            return null;
        }

        if (IsEvaluationCandidate(expr))
        {
            try
            {
                // Compile subexpression to a nullary delegate and evaluate it
                var value = Expression.Lambda(expr).Compile().DynamicInvoke();
                return Expression.Constant(value, expr.Type);
            }
            catch
            {
                // Best effort: the subexpression either refers to a lambda
                // parameter (so it cannot compile as a nullary delegate) or threw
                // while being evaluated. Keep it as a tree and reduce its children.
            }
        }
        return base.Visit(expr);
    }

    /// <summary>
    /// Returns false for nodes that must not, or cannot usefully, be replaced by
    /// an evaluated constant.
    /// </summary>
    private static bool IsEvaluationCandidate(Expression expr)
    {
        switch (expr.NodeType)
        {
            case ExpressionType.Constant:  // already a constant
            case ExpressionType.Parameter: // never compiles without its lambda
            case ExpressionType.Lambda:    // would become an incomparable delegate
            case ExpressionType.Quote:     // would become an incomparable expression instance
                return false;
            default:
                return true;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment