using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
namespace Microsoft.EntityFrameworkCore
{
    /// <summary>
    /// Query helpers that turn "key is one of these values" into a balanced OR tree of
    /// parameterized equality checks. The value list is padded to a power-of-two length so the
    /// number of distinct SQL shapes (and therefore cached query plans) stays small.
    /// </summary>
    public static class IQueryableExtensions
    {
        // SQL Server allows at most 2100 parameters per command; 2048 keeps some headroom.
        private const int MaxParameterCount = 2048;

        /// <summary>
        /// Filters <paramref name="queryable"/> to elements whose key, as selected by
        /// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
        /// </summary>
        /// <typeparam name="TKey">Type of the key being compared.</typeparam>
        /// <typeparam name="TQuery">Element type of the query.</typeparam>
        /// <param name="queryable">The query to filter.</param>
        /// <param name="values">Candidate key values; duplicates are ignored. Enumerated exactly once.</param>
        /// <param name="keySelector">Selects the key to compare from each element.</param>
        /// <returns>
        /// The filtered query, or <c>queryable.Take(0)</c> (an empty result) when
        /// <paramref name="values"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="values"/> contains more than 2048 distinct values.
        /// </exception>
        public static IQueryable<TQuery> In<TKey, TQuery>(
            this IQueryable<TQuery> queryable,
            IEnumerable<TKey> values,
            Expression<Func<TQuery, TKey>> keySelector)
        {
            ArgumentNullException.ThrowIfNull(queryable);
            ArgumentNullException.ThrowIfNull(values);
            ArgumentNullException.ThrowIfNull(keySelector);

            // Bucketize enumerates 'values' exactly once; a separate Any() call beforehand would
            // run a lazy sequence (or another query) twice.
            var distinctValues = Bucketize(values);
            if (distinctValues.Length == 0)
            {
                return queryable.Take(0);
            }

            if (distinctValues.Length > MaxParameterCount)
            {
                throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
            }

            var expr = CreateBalancedORExpression(distinctValues, keySelector.Body, 0, distinctValues.Length - 1);
            var clause = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);
            return queryable.Where(clause);
        }

        /// <summary>
        /// Builds <c>key == values[start] || ... || key == values[end]</c> as a balanced binary
        /// tree, so nesting depth grows with log2 of the count rather than linearly.
        /// </summary>
        private static BinaryExpression CreateBalancedORExpression<TKey>(TKey[] values, Expression keySelectorBody, int start, int end)
        {
            if (start == end)
            {
                return EqualsCapturedValue(keySelectorBody, values[start]);
            }

            int mid = start + (end - start) / 2;
            return Expression.OrElse(
                CreateBalancedORExpression(values, keySelectorBody, start, mid),
                CreateBalancedORExpression(values, keySelectorBody, mid + 1, end));
        }

        /// <summary>Builds <c>keySelectorBody == value</c> with <paramref name="value"/> captured in a closure.</summary>
        private static BinaryExpression EqualsCapturedValue<TKey>(Expression keySelectorBody, TKey value)
        {
            // Capturing the value (instead of Expression.Constant) lets EF turn it into a SQL
            // parameter rather than an inlined literal, so the SQL text can be reused.
            Expression<Func<TKey>> captured = () => value;
            return Expression.Equal(keySelectorBody, captured.Body);
        }

        /// <summary>
        /// De-duplicates <paramref name="values"/> and pads the result to the next power-of-two
        /// length by repeating its last element.
        /// </summary>
        /// <returns>
        /// The distinct values, possibly padded. Empty input, and input larger than the parameter
        /// budget (which the caller rejects), is returned unpadded.
        /// </returns>
        private static TKey[] Bucketize<TKey>(IEnumerable<TKey> values)
        {
            var distinctValues = new HashSet<TKey>(values).ToArray();
            var originalLength = distinctValues.Length;
            if (originalLength == 0 || originalLength > MaxParameterCount)
            {
                return distinctValues;
            }

            // Exact integer rounding: Math.Log(n, 2) is a floating-point quotient and is not
            // guaranteed to be an exact integer even when n is a power of two.
            var bucket = 1;
            while (bucket < originalLength)
            {
                bucket <<= 1;
            }

            if (bucket == originalLength)
            {
                return distinctValues;
            }

            var lastValue = distinctValues[originalLength - 1];
            Array.Resize(ref distinctValues, bucket);
            distinctValues.AsSpan().Slice(originalLength).Fill(lastValue);
            return distinctValues;
        }
    }
}
There is no license, and no license restrictions - use as you see fit!
This is absolutely brilliant; the buckets idea is genius. Thank you for sharing your solution!
@blogcraft ???
@ErikEJ Great piece of code! 😄
Is it possible to add this extension method to the EF.Functions package? You can also see the Microsoft documentation here.
I think it would be even nicer if it could be implemented like the EF.Functions.Like()
method, to allow usage like this:
// Proposed syntax only (feature request): EF.Functions.In is hypothetical. The extension in this
// file is instead called on the query itself, as query.In(values, e => e.Id).
var result = dbContext.MyEntities
.Where(e => EF.Functions.In(values => e.Id))
.ToList();
I'm not sure what you're trying to achieve with the bucketizing, but I think it can be much simpler, and it only inlines values once they exceed the parameter limit:
// SQL Server accepts at most 2100 parameters per command; stop at 2048 to keep some headroom.
// WhereIn parameterizes the first MaxSqlParameterCount values and inlines the rest as constants;
// Bucketize only pads the OR chain while the count is below this limit.
private const int MaxSqlParameterCount = 2048;
// https://github.com/dotnet/efcore/issues/13617
/// <summary>
/// Filters <paramref name="queryable"/> to elements whose key, selected by
/// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
/// The first <c>MaxSqlParameterCount</c> distinct values become SQL parameters; any further
/// values are inlined as constants. Below that limit the OR chain is padded (via
/// <c>Bucketize</c>) to a power-of-two length to limit the number of distinct query plans.
/// </summary>
/// <returns>The filtered query, or <c>queryable.Take(0)</c> when <paramref name="values"/> is empty.</returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> or <paramref name="keySelector"/> is null.</exception>
/// <remarks>
/// NOTE(review): the predicate is a left-deep OrElse chain whose depth grows with the number of
/// values, so very large inputs may strain recursive expression visitors.
/// </remarks>
public static IQueryable<TQuery> WhereIn<TKey, TQuery>(
    this IQueryable<TQuery> queryable,
    IEnumerable<TKey> values,
    Expression<Func<TQuery, TKey>> keySelector)
{
    if (values == null) throw new ArgumentNullException(nameof(values));
    if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));

    // Materialize once so a lazy 'values' sequence is not enumerated twice (Any() + Distinct()).
    var distinctValues = values.Distinct().ToList();
    if (distinctValues.Count == 0) return queryable.Take(0);

    Expression expr = null;
    Expression<Func<TKey>> valueAsExpression = null;
    for (var i = 0; i < distinctValues.Count; i++)
    {
        var id = distinctValues[i];
        // Capturing 'id' in a closure lets EF emit a SQL parameter rather than a literal.
        valueAsExpression = () => id;

        // Beyond the parameter budget, inline the value instead. The constant is explicitly typed
        // as TKey: Expression.Constant(object) uses the runtime type (e.g. int for an int? key,
        // or object for null). That can mismatch keySelector.Body's type and make Expression.Equal throw.
        Expression valueExpression = i < MaxSqlParameterCount
            ? valueAsExpression.Body
            : Expression.Constant(id, typeof(TKey));

        var equality = Expression.Equal(keySelector.Body, valueExpression);
        expr = expr == null ? equality : Expression.OrElse(expr, equality);
    }

    // Pad the chain to a power-of-two length by repeating the last parameterized comparison.
    Bucketize(distinctValues.Count, () =>
        expr = Expression.OrElse(expr, Expression.Equal(keySelector.Body, valueAsExpression.Body)));

    var lambda = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);
    return queryable.Where(lambda);
}
/// <summary>
/// Pads an OR chain of <paramref name="parameterCount"/> comparisons up to the next power of two
/// by invoking <paramref name="padExpression"/> once per missing slot. This keeps the number of
/// distinct query shapes (and plans) small. Counts at or above <c>MaxSqlParameterCount</c> are
/// left untouched.
/// </summary>
/// <param name="parameterCount">Number of comparisons already in the chain.</param>
/// <param name="padExpression">Callback run once per padding slot (WhereIn uses it to append one comparison).</param>
private static void Bucketize(int parameterCount, Action padExpression)
{
    if (parameterCount < MaxSqlParameterCount)
    {
        var nextPowerOfTwo = 1;
        while (nextPowerOfTwo < parameterCount)
        {
            nextPowerOfTwo <<= 1;
        }

        for (var missing = nextPowerOfTwo - parameterCount; missing > 0; missing--)
        {
            padExpression();
        }
    }
}
Updated to bucketize and limit the number of plans
@h0wXD Buckets are important because they limit the number of query plans; see my blog post: https://erikej.github.io/efcore/sqlserver/2020/03/30/ef-core-cache-pollution.html
Couldn't you get rid of the PairWise function and replace the loop:
// Reduce the predicate list level by level: each pass ORs the pairs produced by PairWise
// (defined elsewhere; presumably non-overlapping adjacent pairs, so verify) until one remains.
while (predicates.Count > 1)
{
predicates = PairWise(predicates).Select(p => Expression.OrElse(p.Item1, p.Item2)).ToList();
}
var body = predicates.Single();
with:
// Left fold: yields ((p0 || p1) || p2) || ..., a chain whose nesting depth grows linearly with the
// number of predicates. That may be deep enough to strain recursive expression visitors for large inputs.
var body = predicates.Aggregate(Expression.OrElse);
Or, if you're worried about the deep nesting of conditions:
/// <summary>
/// A lightweight, non-copying view over a contiguous range of an <see cref="IReadOnlyList{T}"/>
/// that can be split into two halves. Useful for building balanced binary trees over a list.
/// </summary>
/// <typeparam name="T">Element type of the underlying list.</typeparam>
private readonly struct HalfList<T>
{
    private readonly IReadOnlyList<T> _list;
    private readonly int _startIndex;

    private HalfList(IReadOnlyList<T> list, int startIndex, int count)
    {
        _list = list ?? throw new ArgumentNullException(nameof(list));
        _startIndex = startIndex;
        Count = count;
    }

    /// <summary>Creates a view covering all of <paramref name="list"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
    public HalfList(IReadOnlyList<T> list) : this(list, 0, list?.Count ?? 0)
    {
        // 'list?.Count' lets a null list reach the delegated constructor's ArgumentNullException;
        // a plain 'list.Count' would throw NullReferenceException before that check runs.
    }

    /// <summary>Number of elements in this view.</summary>
    public int Count { get; }

    /// <summary>The only element of a one-element view.</summary>
    /// <exception cref="InvalidOperationException">The view does not contain exactly one element.</exception>
    public T Item => Count == 1 ? _list[_startIndex] : throw new InvalidOperationException();

    /// <summary>
    /// Splits the view into a left half of <c>Count / 2</c> elements (rounded down) and a right
    /// half holding the remaining elements.
    /// </summary>
    /// <exception cref="InvalidOperationException">The view has fewer than two elements.</exception>
    public (HalfList<T> left, HalfList<T> right) Split()
    {
        if (Count < 2) throw new InvalidOperationException();
        int pivot = Count >> 1;
        var left = new HalfList<T>(_list, _startIndex, pivot);
        var right = new HalfList<T>(_list, _startIndex + pivot, Count - pivot);
        return (left, right);
    }
}
/// <summary>
/// Combines <paramref name="parts"/> into a single expression with <paramref name="fn"/>, arranged
/// as a balanced binary tree so nesting depth grows with log2 of the count rather than linearly.
/// </summary>
/// <param name="parts">Expressions to combine; must contain at least one element.</param>
/// <param name="fn">Binary combiner, e.g. <c>Expression.OrElse</c>.</param>
/// <returns>The single element when there is only one; otherwise the combined tree.</returns>
/// <exception cref="ArgumentNullException"><paramref name="parts"/> or <paramref name="fn"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="parts"/> is empty.</exception>
private static Expression CombinePredicates(IReadOnlyList<Expression> parts, Func<Expression, Expression, Expression> fn)
{
    if (parts == null) throw new ArgumentNullException(nameof(parts));
    if (fn == null) throw new ArgumentNullException(nameof(fn));
    if (parts.Count == 0) throw new ArgumentException("At least one part is required.", nameof(parts));

    return Combine(0, parts.Count);

    // Combines parts[start .. start + count). The left half gets count / 2 items (rounded down)
    // and is built before the right half.
    Expression Combine(int start, int count)
    {
        if (count == 1) return parts[start];
        var pivot = count >> 1;
        var left = Combine(start, pivot);
        var right = Combine(start + pivot, count - pivot);
        return fn(left, right);
    }
}
// Combine the equality predicates into one balanced OR tree (halving at each level) instead of a linear chain.
var body = CombinePredicates(predicates, Expression.OrElse);
I had to use IQueryableExtensions.In() to support older SQL Server versions,
and I ran into an issue when using a library that parses the table name from the generated query (I am using multiple interceptors). The issue occurs when IEnumerable<TKey> values
is empty.
The method checks if (!values.Any())
and returns queryable.Take(0).
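When that call is used, EF Core generates a subquery, and the library I use (EFCoreSecondLevelCacheInterceptor) fails when trying to fetch the table name from the query.
I resorted to returning queryable.Where(x => true)
instead, and that seems to work. (Note: Where(x => true) returns every row, whereas an empty value list should match none. Where(x => false) presumably also avoids the Take(0) subquery while keeping the result empty.) Posting this in case someone else has this problem.
Here is the updated code of that extension method, reflecting that change: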
/// <summary>
/// Filters <paramref name="queryable"/> to elements whose key, selected by
/// <paramref name="keySelector"/>, equals one of <paramref name="values"/>. Parameterized
/// equality checks are OR-ed together pairwise into a single predicate.
/// </summary>
/// <returns>
/// The filtered query. When <paramref name="values"/> is empty, the result is empty.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> or <paramref name="keySelector"/> is null.</exception>
/// <exception cref="ArgumentException">More than 2048 values remain after bucketizing.</exception>
public static IQueryable<TQuery> In<TKey, TQuery>(
    this IQueryable<TQuery> queryable,
    IEnumerable<TKey> values,
    Expression<Func<TQuery, TKey>> keySelector)
{
    if (values == null)
    {
        throw new ArgumentNullException(nameof(values));
    }
    if (keySelector == null)
    {
        throw new ArgumentNullException(nameof(keySelector));
    }

    // Enumerate the caller's sequence only once; it may be lazy or expensive to re-run.
    var materialized = values as IReadOnlyCollection<TKey> ?? values.ToArray();
    if (materialized.Count == 0)
    {
        // An empty value list must match nothing. Where(x => false) is used instead of Take(0),
        // which makes EF Core generate a subquery that breaks interceptors parsing the SQL.
        // Note: Where(x => true) would return every row instead of none.
        return queryable.Where(x => false);
    }

    var distinctValues = Bucketize(materialized);
    if (distinctValues.Length > 2048)
    {
        throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
    }

    var predicates = distinctValues
        .Select(v =>
        {
            // Create an expression that captures the variable so EF can turn this into a parameterized SQL query
            Expression<Func<TKey>> valueAsExpression = () => v;
            return Expression.Equal(keySelector.Body, valueAsExpression.Body);
        })
        .ToList();

    // Collapse the predicates by OR-ing the pairs produced by PairWise until one remains.
    while (predicates.Count > 1)
    {
        predicates = PairWise(predicates).Select(p => Expression.OrElse(p.Item1, p.Item2)).ToList();
    }
    var body = predicates.Single();
    var clause = Expression.Lambda<Func<TQuery, bool>>(body, keySelector.Parameters);
    return queryable.Where(clause);
}
@ErikEJ I believe you may find my project useful. It solves this problem in a flexible way. I have been using a similar strategy in my work with acceptable results, so I put in some effort and made a generic version of it for everyone to use. Please feel free to review the code if you have time. Ideas are welcome.
Hey @joelmandell. Please take a look ☝️ too 🙂.
A few updates; it performs a bit better.
Further improvements are possible, but they would require implementing caching, which defeats the purpose: there would be too many cache layers.
/// <summary>
/// Query helpers that turn "key is one of these values" into a balanced OR tree of
/// parameterized equality checks. The value list is padded to a power-of-two length so the
/// number of distinct SQL shapes (and therefore cached query plans) stays small.
/// </summary>
public static class IQueryableExtensions
{
    // SQL Server allows at most 2100 parameters per command; 2048 keeps some headroom.
    private const int MaxParameterCount = 2048;

    /// <summary>
    /// Filters <paramref name="queryable"/> to elements whose key, as selected by
    /// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
    /// </summary>
    /// <typeparam name="TKey">Type of the key being compared.</typeparam>
    /// <typeparam name="TQuery">Element type of the query.</typeparam>
    /// <param name="queryable">The query to filter.</param>
    /// <param name="values">Candidate key values; duplicates are ignored. Enumerated exactly once.</param>
    /// <param name="keySelector">Selects the key to compare from each element.</param>
    /// <returns>
    /// The filtered query, or <c>queryable.Take(0)</c> (an empty result) when
    /// <paramref name="values"/> is empty.
    /// </returns>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="values"/> contains more than 2048 distinct values.
    /// </exception>
    public static IQueryable<TQuery> In<TKey, TQuery>(
        this IQueryable<TQuery> queryable,
        IEnumerable<TKey> values,
        Expression<Func<TQuery, TKey>> keySelector)
    {
        if (queryable == null)
        {
            throw new ArgumentNullException(nameof(queryable));
        }
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }
        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        // Bucketize enumerates 'values' exactly once; a separate Any() call beforehand would
        // run a lazy sequence (or another query) twice.
        var distinctValues = Bucketize(values);
        if (distinctValues.Length == 0)
        {
            return queryable.Take(0);
        }

        if (distinctValues.Length > MaxParameterCount)
        {
            throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
        }

        var expr = CreateBalancedORExpression(distinctValues, keySelector.Body, 0, distinctValues.Length - 1);
        var clause = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);
        return queryable.Where(clause);
    }

    /// <summary>
    /// Builds <c>key == values[start] || ... || key == values[end]</c> as a balanced binary
    /// tree, so nesting depth grows with log2 of the count rather than linearly.
    /// </summary>
    private static BinaryExpression CreateBalancedORExpression<TKey>(TKey[] values, Expression keySelectorBody, int start, int end)
    {
        if (start == end)
        {
            return EqualsCapturedValue(keySelectorBody, values[start]);
        }

        int mid = start + (end - start) / 2;
        return Expression.OrElse(
            CreateBalancedORExpression(values, keySelectorBody, start, mid),
            CreateBalancedORExpression(values, keySelectorBody, mid + 1, end));
    }

    /// <summary>Builds <c>keySelectorBody == value</c> with <paramref name="value"/> captured in a closure.</summary>
    private static BinaryExpression EqualsCapturedValue<TKey>(Expression keySelectorBody, TKey value)
    {
        // Capturing the value (instead of Expression.Constant) lets EF turn it into a SQL
        // parameter rather than an inlined literal, so the SQL text can be reused.
        Expression<Func<TKey>> captured = () => value;
        return Expression.Equal(keySelectorBody, captured.Body);
    }

    /// <summary>
    /// De-duplicates <paramref name="values"/> and pads the result to the next power-of-two
    /// length by repeating its last element.
    /// </summary>
    /// <returns>
    /// The distinct values, possibly padded. Empty input, and input larger than the parameter
    /// budget (which the caller rejects), is returned unpadded.
    /// </returns>
    private static TKey[] Bucketize<TKey>(IEnumerable<TKey> values)
    {
        var distinctValues = new HashSet<TKey>(values).ToArray();
        var originalLength = distinctValues.Length;
        if (originalLength == 0 || originalLength > MaxParameterCount)
        {
            return distinctValues;
        }

        // Exact integer rounding: Math.Log(n, 2) is a floating-point quotient and is not
        // guaranteed to be an exact integer even when n is a power of two.
        var bucket = 1;
        while (bucket < originalLength)
        {
            bucket <<= 1;
        }

        if (bucket == originalLength)
        {
            return distinctValues;
        }

        var lastValue = distinctValues[originalLength - 1];
        Array.Resize(ref distinctValues, bucket);
        distinctValues.AsSpan().Slice(originalLength).Fill(lastValue);
        return distinctValues;
    }
}
Great work! Could you specify the license of this piece of code?