Created
June 23, 2010 10:50
-
-
Save nrk/449765 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Collections.ObjectModel; | |
using System.Linq; | |
using System.Linq.Expressions; | |
namespace Mono.Linq.Expressions { | |
public static class PredicateBuilder { | |
public static Expression<Func<T, bool>> True<T> () { return f => true; } | |
public static Expression<Func<T, bool>> False<T> () { return f => false; } | |
public static Expression<Func<T, bool>> Or<T> (this Expression<Func<T, bool>> expr1, Expression<Func<T, bool>> expr2) | |
{ | |
var parameter = CreateParameterFrom (expr1, expr2); | |
return Expression.Lambda<Func<T, bool>> ( | |
Expression.OrElse ( | |
Rewrite (expr1.Body, parameter), | |
Rewrite (expr2.Body, parameter)), | |
parameter); | |
} | |
public static Expression<Func<T, bool>> And<T> (this Expression<Func<T, bool>> expr1, Expression<Func<T, bool>> expr2) | |
{ | |
var parameter = CreateParameterFrom (expr1, expr2); | |
return Expression.Lambda<Func<T, bool>> ( | |
Expression.AndAlso ( | |
Rewrite (expr1.Body, parameter), | |
Rewrite (expr2.Body, parameter)), | |
parameter); | |
} | |
static ParameterExpression CreateParameterFrom (LambdaExpression left, LambdaExpression right) | |
{ | |
var leftParameters = left.Parameters; | |
var rightParameters = right.Parameters; | |
if (leftParameters.Count != right.Parameters.Count) | |
throw new ArgumentException (); | |
if (leftParameters.Count != 1) | |
throw new ArgumentException (); | |
if (leftParameters [0].Type != rightParameters [0].Type) | |
throw new ArgumentException (); | |
return Expression.Parameter (leftParameters [0].Type, leftParameters [0].Name); | |
} | |
static Expression Rewrite (Expression expression, ParameterExpression parameter) | |
{ | |
return new ParameterRewriter (parameter).Transform (expression); | |
} | |
class ParameterRewriter : ExpressionTransformer { | |
ParameterExpression parameter; | |
public ParameterRewriter (ParameterExpression parameter) | |
{ | |
this.parameter = parameter; | |
} | |
protected override Expression VisitParameter (ParameterExpression expression) | |
{ | |
return parameter; | |
} | |
} | |
} | |
abstract class ExpressionTransformer { | |
public Expression Transform (Expression expression) | |
{ | |
return Visit (expression); | |
} | |
protected virtual Expression Visit (Expression exp) | |
{ | |
if (exp == null) return exp; | |
switch (exp.NodeType) { | |
case ExpressionType.Negate: | |
case ExpressionType.NegateChecked: | |
case ExpressionType.Not: | |
case ExpressionType.Convert: | |
case ExpressionType.ConvertChecked: | |
case ExpressionType.ArrayLength: | |
case ExpressionType.Quote: | |
case ExpressionType.TypeAs: | |
case ExpressionType.UnaryPlus: | |
return this.VisitUnary ((UnaryExpression) exp); | |
case ExpressionType.Add: | |
case ExpressionType.AddChecked: | |
case ExpressionType.Subtract: | |
case ExpressionType.SubtractChecked: | |
case ExpressionType.Multiply: | |
case ExpressionType.MultiplyChecked: | |
case ExpressionType.Divide: | |
case ExpressionType.Power: | |
case ExpressionType.Modulo: | |
case ExpressionType.And: | |
case ExpressionType.AndAlso: | |
case ExpressionType.Or: | |
case ExpressionType.OrElse: | |
case ExpressionType.LessThan: | |
case ExpressionType.LessThanOrEqual: | |
case ExpressionType.GreaterThan: | |
case ExpressionType.GreaterThanOrEqual: | |
case ExpressionType.Equal: | |
case ExpressionType.NotEqual: | |
case ExpressionType.Coalesce: | |
case ExpressionType.ArrayIndex: | |
case ExpressionType.RightShift: | |
case ExpressionType.LeftShift: | |
case ExpressionType.ExclusiveOr: | |
return this.VisitBinary ((BinaryExpression) exp); | |
case ExpressionType.TypeIs: | |
return this.VisitTypeIs ((TypeBinaryExpression) exp); | |
case ExpressionType.Conditional: | |
return this.VisitConditional ((ConditionalExpression) exp); | |
case ExpressionType.Constant: | |
return this.VisitConstant ((ConstantExpression) exp); | |
case ExpressionType.Parameter: | |
return this.VisitParameter ((ParameterExpression) exp); | |
case ExpressionType.MemberAccess: | |
return this.VisitMemberAccess ((MemberExpression) exp); | |
case ExpressionType.Call: | |
return this.VisitMethodCall ((MethodCallExpression) exp); | |
case ExpressionType.Lambda: | |
return this.VisitLambda ((LambdaExpression) exp); | |
case ExpressionType.New: | |
return this.VisitNew ((NewExpression) exp); | |
case ExpressionType.NewArrayInit: | |
case ExpressionType.NewArrayBounds: | |
return this.VisitNewArray ((NewArrayExpression) exp); | |
case ExpressionType.Invoke: | |
return this.VisitInvocation ((InvocationExpression) exp); | |
case ExpressionType.MemberInit: | |
return this.VisitMemberInit ((MemberInitExpression) exp); | |
case ExpressionType.ListInit: | |
return this.VisitListInit ((ListInitExpression) exp); | |
default: | |
throw new Exception (string.Format ("Unhandled expression type: '{0}'", exp.NodeType)); | |
} | |
} | |
protected virtual MemberBinding VisitBinding (MemberBinding binding) | |
{ | |
switch (binding.BindingType) { | |
case MemberBindingType.Assignment: | |
return this.VisitMemberAssignment ((MemberAssignment) binding); | |
case MemberBindingType.MemberBinding: | |
return this.VisitMemberMemberBinding ((MemberMemberBinding) binding); | |
case MemberBindingType.ListBinding: | |
return this.VisitMemberListBinding ((MemberListBinding) binding); | |
default: | |
throw new Exception (string.Format ("Unhandled binding type '{0}'", binding.BindingType)); | |
} | |
} | |
protected virtual ElementInit VisitElementInitializer (ElementInit initializer) | |
{ | |
ReadOnlyCollection<Expression> arguments = this.VisitExpressionList (initializer.Arguments); | |
if (arguments != initializer.Arguments) return Expression.ElementInit (initializer.AddMethod, arguments); | |
return initializer; | |
} | |
protected virtual Expression VisitUnary (UnaryExpression u) | |
{ | |
Expression operand = this.Visit (u.Operand); | |
if (operand != u.Operand) return Expression.MakeUnary (u.NodeType, operand, u.Type, u.Method); | |
return u; | |
} | |
protected virtual Expression VisitBinary (BinaryExpression b) | |
{ | |
Expression left = this.Visit (b.Left); | |
Expression right = this.Visit (b.Right); | |
Expression conversion = this.Visit (b.Conversion); | |
if (left != b.Left || right != b.Right || conversion != b.Conversion) { | |
if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null) { | |
return Expression.Coalesce (left, right, conversion as LambdaExpression); | |
} else { | |
return Expression.MakeBinary (b.NodeType, left, right, b.IsLiftedToNull, b.Method); | |
} | |
} | |
return b; | |
} | |
protected virtual Expression VisitTypeIs (TypeBinaryExpression b) | |
{ | |
Expression expr = this.Visit (b.Expression); | |
if (expr != b.Expression) { | |
return Expression.TypeIs (expr, b.TypeOperand); | |
} | |
return b; | |
} | |
protected virtual Expression VisitConstant (ConstantExpression c) | |
{ | |
return c; | |
} | |
protected virtual Expression VisitConditional (ConditionalExpression c) | |
{ | |
Expression test = this.Visit (c.Test); | |
Expression ifTrue = this.Visit (c.IfTrue); | |
Expression ifFalse = this.Visit (c.IfFalse); | |
if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) { | |
return Expression.Condition (test, ifTrue, ifFalse); | |
} | |
return c; | |
} | |
protected virtual Expression VisitParameter (ParameterExpression p) | |
{ | |
return p; | |
} | |
protected virtual Expression VisitMemberAccess (MemberExpression m) | |
{ | |
Expression exp = this.Visit (m.Expression); | |
if (exp != m.Expression) { | |
return Expression.MakeMemberAccess (exp, m.Member); | |
} | |
return m; | |
} | |
protected virtual Expression VisitMethodCall (MethodCallExpression m) | |
{ | |
Expression obj = this.Visit (m.Object); | |
IEnumerable<Expression> args = this.VisitExpressionList (m.Arguments); | |
if (obj != m.Object || args != m.Arguments) { | |
return Expression.Call (obj, m.Method, args); | |
} | |
return m; | |
} | |
protected virtual ReadOnlyCollection<Expression> VisitExpressionList (ReadOnlyCollection<Expression> original) | |
{ | |
var list = VisitList (original, Visit); | |
if (list == null) return original; | |
return new ReadOnlyCollection<Expression> (list); | |
} | |
protected virtual MemberAssignment VisitMemberAssignment (MemberAssignment assignment) | |
{ | |
Expression e = this.Visit (assignment.Expression); | |
if (e != assignment.Expression) return Expression.Bind (assignment.Member, e); | |
return assignment; | |
} | |
protected virtual MemberMemberBinding VisitMemberMemberBinding (MemberMemberBinding binding) | |
{ | |
IEnumerable<MemberBinding> bindings = this.VisitBindingList (binding.Bindings); | |
if (bindings != binding.Bindings) return Expression.MemberBind (binding.Member, bindings); | |
return binding; | |
} | |
protected virtual MemberListBinding VisitMemberListBinding (MemberListBinding binding) | |
{ | |
IEnumerable<ElementInit> initializers = this.VisitElementInitializerList (binding.Initializers); | |
if (initializers != binding.Initializers) return Expression.ListBind (binding.Member, initializers); | |
return binding; | |
} | |
protected virtual IEnumerable<MemberBinding> VisitBindingList (ReadOnlyCollection<MemberBinding> original) | |
{ | |
return VisitList (original, VisitBinding); | |
} | |
protected virtual IEnumerable<ElementInit> VisitElementInitializerList (ReadOnlyCollection<ElementInit> original) | |
{ | |
return VisitList (original, VisitElementInitializer); | |
} | |
private IList<TElement> VisitList<TElement> (ReadOnlyCollection<TElement> original, Func<TElement, TElement> visit) | |
{ | |
List<TElement> list = null; | |
for (int i = 0, n = original.Count; i < n; i++) { | |
TElement element = visit (original [i]); | |
if (list != null) { | |
list.Add (element); | |
} else if (!EqualityComparer<TElement>.Default.Equals (element, original [i])) { | |
list = new List<TElement> (n); | |
for (int j = 0; j < i; j++) { | |
list.Add (original [j]); | |
} | |
list.Add (element); | |
} | |
} | |
if (list != null) | |
return list; | |
return original; | |
} | |
protected virtual Expression VisitLambda (LambdaExpression lambda) | |
{ | |
Expression body = this.Visit (lambda.Body); | |
if (body != lambda.Body) return Expression.Lambda (lambda.Type, body, lambda.Parameters); | |
return lambda; | |
} | |
protected virtual NewExpression VisitNew (NewExpression nex) | |
{ | |
IEnumerable<Expression> args = this.VisitExpressionList (nex.Arguments); | |
if (args != nex.Arguments) { | |
if (nex.Members != null) | |
return Expression.New (nex.Constructor, args, nex.Members); | |
else | |
return Expression.New (nex.Constructor, args); | |
} | |
return nex; | |
} | |
protected virtual Expression VisitMemberInit (MemberInitExpression init) | |
{ | |
NewExpression n = this.VisitNew (init.NewExpression); | |
IEnumerable<MemberBinding> bindings = this.VisitBindingList (init.Bindings); | |
if (n != init.NewExpression || bindings != init.Bindings) return Expression.MemberInit (n, bindings); | |
return init; | |
} | |
protected virtual Expression VisitListInit (ListInitExpression init) | |
{ | |
NewExpression n = this.VisitNew (init.NewExpression); | |
IEnumerable<ElementInit> initializers = this.VisitElementInitializerList (init.Initializers); | |
if (n != init.NewExpression || initializers != init.Initializers) return Expression.ListInit (n, initializers); | |
return init; | |
} | |
protected virtual Expression VisitNewArray (NewArrayExpression na) | |
{ | |
IEnumerable<Expression> exprs = this.VisitExpressionList (na.Expressions); | |
if (exprs != na.Expressions) { | |
if (na.NodeType == ExpressionType.NewArrayInit) { | |
return Expression.NewArrayInit (na.Type.GetElementType (), exprs); | |
} else { | |
return Expression.NewArrayBounds (na.Type.GetElementType (), exprs); | |
} | |
} | |
return na; | |
} | |
protected virtual Expression VisitInvocation (InvocationExpression iv) | |
{ | |
IEnumerable<Expression> args = this.VisitExpressionList (iv.Arguments); | |
Expression expr = this.Visit (iv.Expression); | |
if (args != iv.Arguments || expr != iv.Expression) return Expression.Invoke (expr, args); | |
return iv; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
this fucking shit never works