Last active: July 8, 2019 20:42
IQueryable test visitors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Data.Linq.SqlClient; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Reflection; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace Tests
{
    /// <summary>
    /// Expression visitor that swaps LINQ-to-SQL <c>SqlMethods.Like</c> calls for the in-memory
    /// <c>SqlExtensions.Like</c> overloads with the same parameter lists, so a query written against a
    /// data context can be evaluated with LINQ-to-Objects.
    /// </summary>
    public sealed class LikeExpressionVisitor : ExpressionVisitor
    {
        // SqlMethods.Like(string matchExpression, string pattern)
        private readonly MethodInfo sqlMethod;

        // SqlMethods.Like(string matchExpression, string pattern, char escapeCharacter)
        private readonly MethodInfo sqlMethodWithEscape;

        // In-memory counterparts, resolved once here instead of on every matching call.
        private readonly MethodInfo inMemoryMethod;
        private readonly MethodInfo inMemoryMethodWithEscape;

        private LikeExpressionVisitor()
        {
            sqlMethod = typeof(SqlMethods).GetMethod(nameof(SqlMethods.Like), new[] { typeof(string), typeof(string) });
            sqlMethodWithEscape = typeof(SqlMethods).GetMethod(
                nameof(SqlMethods.Like),
                new[] { typeof(string), typeof(string), typeof(char) });
            inMemoryMethod = new Func<string, string, bool>(SqlExtensions.Like).Method;
            inMemoryMethodWithEscape = new Func<string, string, char, bool>(SqlExtensions.Like).Method;
        }

        /// <summary>Shared instance; the visitor only holds read-only reflection data.</summary>
        public static LikeExpressionVisitor Instance { get; } = new LikeExpressionVisitor();

        /// <summary>
        /// Replaces a <c>SqlMethods.Like</c> call with the <c>SqlExtensions.Like</c> overload that has the
        /// same parameters; any other method call is visited normally.
        /// </summary>
        /// <param name="node">The method call being visited.</param>
        /// <returns>The rewritten call, or the result of the base visitor for unrelated calls.</returns>
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            MethodInfo replacement = null;
            if (node.Method == sqlMethodWithEscape)
            {
                replacement = inMemoryMethodWithEscape;
            }
            else if (node.Method == sqlMethod)
            {
                replacement = inMemoryMethod;
            }

            if (replacement == null)
            {
                return base.VisitMethodCall(node);
            }

            // The arguments are visited too: they may themselves contain SqlMethods.Like calls
            // (e.g. inside a subquery lambda), which would otherwise be left untranslated.
            return Expression.Call(replacement, Visit(node.Arguments));
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Text.RegularExpressions; | |
namespace Tests
{
    /// <summary>
    /// In-memory emulation of the T-SQL <c>LIKE</c> operator, used as the replacement for
    /// <c>SqlMethods.Like</c> when a query is evaluated against in-memory data.
    /// </summary>
    /// <remarks>
    /// The SQL pattern is translated into an anchored .NET regular expression:
    /// <c>%</c> matches any run of characters (including none and including line breaks),
    /// <c>_</c> matches exactly one character, <c>[abc]</c> / <c>[a-c]</c> match one character from a set
    /// or range, and <c>[^abc]</c> matches one character outside the set. Every other character, including
    /// regex metacharacters such as <c>.</c>, <c>^</c> and <c>$</c>, is matched literally.
    /// Matching is case-sensitive. NOTE(review): SQL Server's result depends on the column collation,
    /// which is often case-insensitive; confirm this is what the tests expect.
    /// </remarks>
    public static class SqlExtensions
    {
        // Singleline lets '.' (emitted for '%' and '_') also match '\n'.
        private const RegexOptions MatchOptions = RegexOptions.Singleline;

        /// <summary>Evaluates <c>matchExpression LIKE pattern</c>.</summary>
        /// <param name="matchExpression">Text to test; <c>null</c> never matches.</param>
        /// <param name="pattern">LIKE pattern; <c>null</c> never matches.</param>
        /// <returns><c>true</c> when the whole of <paramref name="matchExpression"/> matches <paramref name="pattern"/>.</returns>
        public static bool Like(string matchExpression, string pattern)
        {
            return IsMatch(matchExpression, pattern, null);
        }

        /// <summary>Evaluates <c>matchExpression LIKE pattern ESCAPE escapeCharacter</c>.</summary>
        /// <param name="matchExpression">Text to test; <c>null</c> never matches.</param>
        /// <param name="pattern">LIKE pattern; <c>null</c> never matches.</param>
        /// <param name="escapeCharacter">
        /// Character that makes the following pattern character literal (e.g. with <c>!</c>, <c>!%</c> matches
        /// a percent sign and <c>!!</c> matches <c>!</c>). A pattern ending in a lone escape character is
        /// invalid and never matches.
        /// </param>
        /// <returns><c>true</c> when the whole of <paramref name="matchExpression"/> matches <paramref name="pattern"/>.</returns>
        public static bool Like(string matchExpression, string pattern, char escapeCharacter)
        {
            return IsMatch(matchExpression, pattern, escapeCharacter);
        }

        // Shared implementation of both overloads; escapeCharacter is null when there is no ESCAPE clause.
        private static bool IsMatch(string matchExpression, string pattern, char? escapeCharacter)
        {
            // In SQL, a NULL operand makes LIKE evaluate to UNKNOWN, which a WHERE clause treats as "no match".
            if (matchExpression == null || pattern == null)
            {
                return false;
            }

            string regexPattern = TranslatePattern(pattern, escapeCharacter);

            // The static Regex.IsMatch caches recently used patterns, so repeated calls do not re-parse.
            return regexPattern != null && Regex.IsMatch(matchExpression, regexPattern, MatchOptions);
        }

        // Converts a LIKE pattern into an anchored regex; returns null when the pattern is invalid.
        private static string TranslatePattern(string pattern, char? escapeCharacter)
        {
            var regex = new System.Text.StringBuilder(@"\A", (pattern.Length * 2) + 4);
            int i = 0;
            while (i < pattern.Length)
            {
                char current = pattern[i];
                if (escapeCharacter.HasValue && current == escapeCharacter.Value)
                {
                    if (i + 1 == pattern.Length)
                    {
                        // SQL Server: nothing after the escape character makes the pattern invalid (LIKE is FALSE).
                        return null;
                    }

                    // The escape character is dropped and the next character is taken literally.
                    AppendLiteral(regex, pattern[i + 1]);
                    i += 2;
                }
                else if (current == '%')
                {
                    regex.Append(".*");
                    i++;
                }
                else if (current == '_')
                {
                    regex.Append('.');
                    i++;
                }
                else if (current == '[')
                {
                    int close = pattern.IndexOf(']', i + 1);
                    if (close > i + 1)
                    {
                        AppendCharacterSet(regex, pattern.Substring(i + 1, close - i - 1));
                        i = close + 1;
                    }
                    else
                    {
                        // NOTE(review): an unterminated or empty "[...]" is matched as a literal '['.
                        AppendLiteral(regex, current);
                        i++;
                    }
                }
                else
                {
                    AppendLiteral(regex, current);
                    i++;
                }
            }

            return regex.Append(@"\z").ToString();
        }

        // Appends one character that must match itself, escaping any regex meaning it has.
        private static void AppendLiteral(System.Text.StringBuilder regex, char character)
        {
            regex.Append(Regex.Escape(character.ToString()));
        }

        // Appends a "[...]" set. Inside a set, '%' and '_' are literal, a leading '^' negates the set, and
        // '-' between two characters denotes a range; everything else is literal.
        private static void AppendCharacterSet(System.Text.StringBuilder regex, string set)
        {
            regex.Append('[');
            int start = 0;
            if (set.Length > 1 && set[0] == '^')
            {
                regex.Append('^');
                start = 1;
            }

            for (int k = start; k < set.Length; k++)
            {
                char c = set[k];
                bool isRangeDash = c == '-' && k > start && k < set.Length - 1;
                if (isRangeDash)
                {
                    regex.Append('-');
                }
                else if (c == '\\' || c == '[' || c == ']' || c == '^' || c == '-')
                {
                    // Characters with a special meaning inside a .NET character class.
                    regex.Append('\\').Append(c);
                }
                else
                {
                    regex.Append(c);
                }
            }

            regex.Append(']');
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Diagnostics.CodeAnalysis; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Reflection; | |
namespace Tests
{
    /// <summary>
    /// Rewrites <c>Queryable.Sum</c> calls so that an in-memory query does not silently return 0 where
    /// LINQ-to-Objects would:
    /// a non-nullable Sum over an empty source throws <see cref="InvalidOperationException"/>, and a
    /// nullable Sum ignores nulls and returns <c>null</c> when no non-null value is left.
    /// NOTE(review): presumably this mirrors SQL SUM semantics (NULL for no rows); confirm against the data context.
    /// </summary>
    public sealed class SumExpressionVisitor : ExpressionVisitor
    {
        // Message (Russian): "Summing an empty collection is only possible if the return type is Nullable".
        internal const string EmptyCollectionSumExceptionMessage = "Суммирование пустой коллекции возможно " +
            "только если возвращаемый тип - Nullable";

        // Every Queryable.Sum overload; generic overloads are stored as generic method definitions.
        private readonly HashSet<MethodInfo> sumMethodInfos;
        private readonly MethodInfo aggregateOrDefaultIfEmpty;
        private readonly MethodInfo whereIsNotNull;

        // Queryable.Select<TSource, TResult>(IQueryable<TSource>, Expression<Func<TSource, TResult>>).
        private readonly MethodInfo select;
        private readonly MethodInfo assertIsNotEmpty;

        private SumExpressionVisitor()
        {
            sumMethodInfos = new HashSet<MethodInfo>(
                typeof(Queryable)
                    .GetMethods()
                    .Where(mi => mi.Name == nameof(Queryable.Sum)));
            aggregateOrDefaultIfEmpty = typeof(SumExpressionVisitor).GetMethod(
                nameof(AggregateOrDefaultIfEmpty),
                BindingFlags.Static | BindingFlags.NonPublic);
            whereIsNotNull = typeof(SumExpressionVisitor).GetMethod(
                nameof(WhereIsNotNull),
                BindingFlags.Static | BindingFlags.NonPublic);
            select = typeof(Queryable)
                .GetMethods()
                .Where(mi => mi.Name == nameof(Queryable.Select))
                .Single(mi =>
                {
                    // Tell Func<TSource, TResult> apart from the indexed Func<TSource, int, TResult>.
                    var selectorExpressionType = mi.GetParameters()[1].ParameterType;
                    var selectorFuncType = selectorExpressionType.GetGenericArguments()[0];
                    return selectorFuncType.GetGenericArguments().Length == 2;
                });
            assertIsNotEmpty = typeof(SumExpressionVisitor).GetMethod(
                nameof(ThrowIfEmpty),
                BindingFlags.Static | BindingFlags.NonPublic);
        }

        /// <summary>Shared instance; the visitor only holds read-only reflection data.</summary>
        public static SumExpressionVisitor Instance { get; } = new SumExpressionVisitor();

        /// <summary>
        /// Rewrites a <c>Queryable.Sum</c> call (and any Sum calls nested in its arguments); other calls are
        /// visited normally.
        /// </summary>
        /// <param name="sumMethodCallExpression">The method call being visited.</param>
        /// <returns>
        /// For a non-nullable Sum: the same Sum over <c>ThrowIfEmpty(source)</c>.
        /// For a nullable Sum: <c>AggregateOrDefaultIfEmpty(WhereIsNotNull(values), v =&gt; Queryable.Sum(v))</c>.
        /// </returns>
        protected override Expression VisitMethodCall(MethodCallExpression sumMethodCallExpression)
        {
            if (!IsSumMethod(sumMethodCallExpression.Method))
            {
                return base.VisitMethodCall(sumMethodCallExpression);
            }

            // Visit the source and selector first so that Sum calls nested inside them are rewritten as well.
            var arguments = Visit(sumMethodCallExpression.Arguments);
            var sumMethod = sumMethodCallExpression.Method;

            // Every Sum overload's first parameter is IQueryable<T>; reading T from the declared parameter also
            // works when the argument's static type is a concrete or non-generic IQueryable implementation.
            var entityType = sumMethod.GetParameters()[0].ParameterType.GetGenericArguments()[0];
            var returnType = sumMethod.ReturnType;
            var resultUnderlyingType = Nullable.GetUnderlyingType(returnType);

            if (resultUnderlyingType == null)
            {
                var specializedThrowIfEmpty = assertIsNotEmpty.MakeGenericMethod(entityType);
                return Expression.Call(
                    sumMethod,
                    new Expression[] { Expression.Call(specializedThrowIfEmpty, arguments[0]) }
                        .Concat(arguments.Skip(1)));
            }

            var collectionExpression = GetCollectionExpression(sumMethod, arguments, entityType);

            // Non-generic Sum(IQueryable<R?>) applied to the projected values.
            var nullableSum = typeof(Queryable).GetMethod(
                nameof(Queryable.Sum),
                new[] { typeof(IQueryable<>).MakeGenericType(returnType) });
            var valuesParameter = Expression.Parameter(nullableSum.GetParameters()[0].ParameterType);

            return Expression.Call(
                aggregateOrDefaultIfEmpty.MakeGenericMethod(resultUnderlyingType),
                Expression.Call(
                    whereIsNotNull.MakeGenericMethod(resultUnderlyingType),
                    collectionExpression),
                Expression.Lambda(Expression.Call(nullableSum, valuesParameter), valuesParameter));
        }

        // Returns the sequence of values being summed: the source itself for the non-generic overloads,
        // otherwise Select(source, selector) with the (already visited) Sum arguments.
        private Expression GetCollectionExpression(MethodInfo sumMethod, IList<Expression> arguments, Type entityType)
        {
            if (!sumMethod.IsGenericMethod)
            {
                return arguments[0];
            }

            var specializedSelect = select.MakeGenericMethod(entityType, sumMethod.ReturnType);
            return Expression.Call(specializedSelect, arguments);
        }

        // Returns source unchanged, or throws InvalidOperationException if it has no elements (checked via Any()).
        private static IQueryable<T> ThrowIfEmpty<T>(IQueryable<T> source)
        {
            if (!source.Any())
            {
                throw new InvalidOperationException(EmptyCollectionSumExceptionMessage);
            }

            return source;
        }

        // Filters the null values out of source.
        private static IQueryable<T?> WhereIsNotNull<T>(IQueryable<T?> source) where T : struct
        {
            return source.Where(x => x != null);
        }

        // Returns null for an empty source, otherwise the result of aggregateFunc applied to it.
        private static T? AggregateOrDefaultIfEmpty<T>(IQueryable<T?> source, Func<IQueryable<T?>, T?> aggregateFunc)
            where T : struct
        {
            if (!source.Any())
            {
                return null;
            }

            return aggregateFunc(source);
        }

        // True when methodInfo is (a closed form of) one of the Queryable.Sum overloads.
        private bool IsSumMethod(MethodInfo methodInfo)
        {
            var methodDefinition = !methodInfo.IsGenericMethod ? methodInfo : methodInfo.GetGenericMethodDefinition();
            return sumMethodInfos.Contains(methodDefinition);
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Queryable test double: enumerating it rewrites the query's expression tree with
/// <c>SumExpressionVisitor</c> and <c>LikeExpressionVisitor</c> and then runs it with LINQ-to-Objects
/// through <see cref="EnumerableQuery{T}"/>.
/// </summary>
/// <remarks>
/// This excerpt shows only part of the class: the <c>Expression</c>, <c>Provider</c>, <c>ElementType</c>
/// and non-generic <c>GetEnumerator</c> members required by <see cref="IOrderedQueryable{T}"/> are not shown here.
/// </remarks>
public class StubTableQueryable<TEntity> : IOrderedQueryable<TEntity>
{
    /// <summary>Executes the rewritten query in memory and enumerates its results.</summary>
    public IEnumerator<TEntity> GetEnumerator()
    {
        var inMemoryExpression = ConvertDataContextExpressionToInMemory(Expression);

        // EnumerableQuery<T> implements IEnumerable<T>.GetEnumerator explicitly, hence the interface-typed local.
        IEnumerable<TEntity> inMemoryQuery = new EnumerableQuery<TEntity>(inMemoryExpression);
        return inMemoryQuery.GetEnumerator();
    }

    /// <summary>
    /// Applies the Sum rewrite first and then the Like rewrite to <paramref name="expression"/>.
    /// </summary>
    /// <param name="expression">The query expression to convert.</param>
    /// <returns>The rewritten expression tree.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is <c>null</c>.</exception>
    public Expression ConvertDataContextExpressionToInMemory(Expression expression)
    {
        if (expression is null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        var withReplacedSum = SumExpressionVisitor.Instance.Visit(expression);
        return LikeExpressionVisitor.Instance.Visit(withReplacedSum);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.