Skip to content

Instantly share code, notes, and snippets.

@ErikEJ
Last active November 24, 2023 15:58
  • Star 31 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
Star You must be signed in to star a gist
Embed
What would you like to do?
Replacement for EF Core .Contains, that avoids SQL Server plan cache pollution
using System.Linq.Expressions;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Query extensions that replace <c>values.Contains(key)</c> filters with a parameterized OR tree
/// whose size is always a power of two, to limit SQL Server plan cache pollution.
/// </summary>
public static class IQueryableExtensions
{
    // SQL Server accepts at most 2100 parameters per command; keep headroom for the rest of the query.
    private const int MaxParameterCount = 2048;

    /// <summary>
    /// Filters <paramref name="queryable"/> to the elements whose key, as selected by
    /// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
    /// </summary>
    /// <remarks>
    /// Each distinct value is captured in a closure so EF Core sends it as a SQL parameter, and the
    /// list is padded to the next power of two (repeating the last value), so the number of distinct
    /// query shapes grows only logarithmically with the number of values.
    /// </remarks>
    /// <param name="queryable">The query to filter.</param>
    /// <param name="values">Candidate keys; duplicates are ignored. Enumerated exactly once.</param>
    /// <param name="keySelector">Selects the key to compare from each element.</param>
    /// <returns>The filtered query; an empty query when <paramref name="values"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> or <paramref name="keySelector"/> is null.</exception>
    /// <exception cref="ArgumentException">More than 2048 distinct values were supplied.</exception>
    public static IQueryable<TQuery> In<TKey, TQuery>(
        this IQueryable<TQuery> queryable,
        IEnumerable<TKey> values,
        Expression<Func<TQuery, TKey>> keySelector)
    {
        ArgumentNullException.ThrowIfNull(values);
        ArgumentNullException.ThrowIfNull(keySelector);

        // Materialize once: calling Any() and then building the set would enumerate a lazy
        // source (for example another query) twice.
        var distinctValues = new HashSet<TKey>(values).ToArray();
        if (distinctValues.Length == 0)
        {
            return queryable.Take(0);
        }

        // Check before padding so an oversized input is rejected without allocating the padded array.
        if (distinctValues.Length > MaxParameterCount)
        {
            throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
        }

        var bucketed = Bucketize(distinctValues);
        var expr = CreateBalancedORExpression(bucketed, keySelector.Body, 0, bucketed.Length - 1);
        var clause = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);
        return queryable.Where(clause);
    }

    /// <summary>
    /// Builds <c>key == values[start] || ... || key == values[end]</c> as a balanced binary tree,
    /// so the nesting depth is about log2(n) rather than n.
    /// </summary>
    private static BinaryExpression CreateBalancedORExpression<TKey>(TKey[] values, Expression keySelectorBody, int start, int end)
    {
        if (start == end)
        {
            return Expression.Equal(keySelectorBody, CaptureValue(values[start]));
        }

        if (start + 1 == end)
        {
            return Expression.OrElse(
                Expression.Equal(keySelectorBody, CaptureValue(values[start])),
                Expression.Equal(keySelectorBody, CaptureValue(values[end])));
        }

        int mid = (start + end) / 2;
        return Expression.OrElse(
            CreateBalancedORExpression(values, keySelectorBody, start, mid),
            CreateBalancedORExpression(values, keySelectorBody, mid + 1, end));
    }

    /// <summary>
    /// Returns an expression that reads <paramref name="value"/> from a closure. EF Core turns such
    /// closure reads into SQL parameters instead of inlining literals.
    /// </summary>
    private static Expression CaptureValue<TKey>(TKey value)
    {
        Expression<Func<TKey>> accessor = () => value;
        return accessor.Body;
    }

    /// <summary>
    /// Pads <paramref name="distinctValues"/> (non-empty) up to the next power of two by repeating
    /// its last element; returns the same array when its length already is a power of two.
    /// </summary>
    private static TKey[] Bucketize<TKey>(TKey[] distinctValues)
    {
        var originalLength = distinctValues.Length;

        // Integer doubling instead of Math.Pow/Math.Log avoids floating-point rounding in log2.
        var bucket = 1;
        while (bucket < originalLength)
        {
            bucket <<= 1;
        }

        if (bucket == originalLength)
        {
            return distinctValues;
        }

        var padded = new TKey[bucket];
        distinctValues.CopyTo(padded, 0);
        padded.AsSpan(originalLength).Fill(distinctValues[originalLength - 1]);
        return padded;
    }
}
}
@moredatapls
Copy link

Great work! Could you specify the license of this piece of code?

@ErikEJ
Copy link
Author

ErikEJ commented Jun 2, 2020

There is no license, and no license restrictions - use as you see fit!

@alb-xss
Copy link

alb-xss commented Dec 27, 2020

This is absolutely brilliant, the buckets idea is genius. Thank you for sharing your solution

@ErikEJ
Copy link
Author

ErikEJ commented Feb 17, 2021

@OculiViridi
Copy link

OculiViridi commented Jun 4, 2021

@ErikEJ Great piece of code! 😄

Would it be possible to add this extension method to the EF.Functions package? The Microsoft documentation for it is here.

I think it would be even nicer if it could be implemented like the EF.Functions.Like() method, enabling usage like this:

var result = dbContext.MyEntities
    .Where(e => EF.Functions.In(values => e.Id))
    .ToList();

@h0wXD
Copy link

h0wXD commented Aug 13, 2021

I'm not sure what you're trying to achieve with Bucketize, but I think it can be much simpler, and values only need to be inlined once the parameter limit is exceeded:

// SQL Server's max parameter count is 2100; this smaller budget keeps some space spare.
// In WhereIn, values at index >= this count are inlined as constants rather than parameters,
// and Bucketize does no padding once the count reaches this value.
private const int MaxSqlParameterCount = 2048;

// https://github.com/dotnet/efcore/issues/13617
/// <summary>
/// Filters <paramref name="queryable"/> to the elements whose key, as selected by
/// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
/// </summary>
/// <remarks>
/// The first <c>MaxSqlParameterCount</c> distinct values are captured in closures so EF Core sends
/// them as SQL parameters; any further values are inlined as constants. Below that limit the OR
/// chain is padded (repeating the last value) up to the next power of two via <c>Bucketize</c>.
/// </remarks>
/// <returns>The filtered query; an empty query when <paramref name="values"/> is empty.</returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> or <paramref name="keySelector"/> is null.</exception>
public static IQueryable<TQuery> WhereIn<TKey, TQuery>(
    this IQueryable<TQuery> queryable,
    IEnumerable<TKey> values,
    Expression<Func<TQuery, TKey>> keySelector)
{
    if (values == null) throw new ArgumentNullException(nameof(values));
    if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));

    // Enumerate `values` only once; Any() followed by Distinct() would run a lazy source twice.
    var distinctValues = values.Distinct().ToList();
    if (distinctValues.Count == 0) return queryable.Take(0);

    Expression expr = null;
    Expression<Func<TKey>> valueAsExpression = null;

    for (var i = 0; i < distinctValues.Count; i++)
    {
        var id = distinctValues[i];

        // Reading `id` through a closure makes EF Core emit a SQL parameter instead of a literal.
        valueAsExpression = () => id;

        // Past the parameter budget the value is inlined. The explicit typeof(TKey) matters:
        // Expression.Constant(object) uses the runtime type (int for a boxed int?, object for null),
        // which makes Expression.Equal throw for nullable or polymorphic keys.
        var equality = i < MaxSqlParameterCount
            ? Expression.Equal(keySelector.Body, valueAsExpression.Body)
            : Expression.Equal(keySelector.Body, Expression.Constant(id, typeof(TKey)));

        expr = expr == null ? equality : Expression.OrElse(expr, equality);
    }

    // Pad with repeats of the last parameterized comparison so the chain length is a power of two.
    Bucketize(distinctValues.Count, () =>
        expr = Expression.OrElse(expr, Expression.Equal(keySelector.Body, valueAsExpression.Body)));

    var lambda = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);

    return queryable.Where(lambda);
}

/// <summary>
/// Invokes <paramref name="padExpression"/> once for every slot between
/// <paramref name="parameterCount"/> and the smallest power of two that is at least that count.
/// Does nothing when the count has already reached <c>MaxSqlParameterCount</c>.
/// </summary>
private static void Bucketize(int parameterCount, Action padExpression)
{
    if (parameterCount < MaxSqlParameterCount)
    {
        // Smallest power of two >= parameterCount (1 when the count is 1 or less).
        var target = 1;
        while (target < parameterCount)
        {
            target *= 2;
        }

        // One callback per missing slot.
        var filled = parameterCount;
        while (filled < target)
        {
            padExpression();
            filled++;
        }
    }
}

Updated to bucketize and limit the number of plans

@ErikEJ
Copy link
Author

ErikEJ commented Aug 13, 2021

@h0wXD Buckets are important in order to limit the number of plans, see my blog post here: https://erikej.github.io/efcore/sqlserver/2020/03/30/ef-core-cache-pollution.html

@RichardD2
Copy link

Couldn't you get rid of the PairWise function, and replace the loop:

// Repeatedly ORs the predicates together in pairs until a single expression remains.
// PairWise is not shown here; presumably it yields consecutive (Item1, Item2) pairs — verify.
while (predicates.Count > 1)
{
    predicates = PairWise(predicates).Select(p => Expression.OrElse(p.Item1, p.Item2)).ToList();
}

// Single() throws unless exactly one combined predicate is left.
var body = predicates.Single();

with:

// Folds left to right into ((p0 || p1) || p2) || ...: nesting depth grows linearly with the
// number of predicates, and Aggregate throws InvalidOperationException on an empty list.
var body = predicates.Aggregate(Expression.OrElse);

@RichardD2
Copy link

RichardD2 commented Aug 24, 2021

Or, if you're worried about the deep nesting of conditions:

/// <summary>
/// A non-allocating view over a contiguous range of an <see cref="IReadOnlyList{T}"/> that can be
/// split in two, used to combine items as a balanced binary tree.
/// </summary>
private readonly struct HalfList<T>
{
    private readonly IReadOnlyList<T> _list;
    private readonly int _startIndex;

    private HalfList(IReadOnlyList<T> list, int startIndex, int count)
    {
        _list = list ?? throw new ArgumentNullException(nameof(list));
        _startIndex = startIndex;
        Count = count;
    }

    /// <summary>Creates a view covering all of <paramref name="list"/>.</summary>
    /// <remarks>
    /// The null check sits in the first constructor-initializer argument. Arguments are evaluated
    /// left to right, so it fires before <c>list.Count</c> is read. Otherwise a null list would
    /// surface as a <see cref="NullReferenceException"/>.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
    public HalfList(IReadOnlyList<T> list)
        : this(list ?? throw new ArgumentNullException(nameof(list)), 0, list.Count)
    {
    }

    /// <summary>Number of items in this view.</summary>
    public int Count { get; }

    /// <summary>The only item of a single-item view; throws <see cref="InvalidOperationException"/> otherwise.</summary>
    public T Item => Count == 1 ? _list[_startIndex] : throw new InvalidOperationException();

    /// <summary>
    /// Splits the view into a left half of <c>Count / 2</c> items and a right half with the rest.
    /// </summary>
    /// <exception cref="InvalidOperationException">The view has fewer than two items.</exception>
    public (HalfList<T> left, HalfList<T> right) Split()
    {
        if (Count < 2) throw new InvalidOperationException();

        int pivot = Count >> 1;
        var left = new HalfList<T>(_list, _startIndex, pivot);
        var right = new HalfList<T>(_list, _startIndex + pivot, Count - pivot);
        return (left, right);
    }
}

/// <summary>
/// Combines <paramref name="parts"/> pairwise with <paramref name="fn"/> as a balanced binary tree:
/// each range is split into a left half of <c>length / 2</c> items and a right half with the rest.
/// </summary>
/// <returns>The single part itself when there is only one; otherwise the root of the tree.</returns>
/// <exception cref="ArgumentException"><paramref name="parts"/> is empty.</exception>
private static Expression CombinePredicates(IReadOnlyList<Expression> parts, Func<Expression, Expression, Expression> fn)
{
    if (parts.Count == 0)
    {
        throw new ArgumentException("At least one part is required.", nameof(parts));
    }

    if (parts.Count == 1)
    {
        return parts[0];
    }

    return Fold(0, parts.Count);

    // Builds the tree for parts[offset .. offset + length); left subtree first, then right.
    Expression Fold(int offset, int length)
    {
        if (length == 1)
        {
            return parts[offset];
        }

        var leftLength = length >> 1;
        var leftTree = Fold(offset, leftLength);
        var rightTree = Fold(offset + leftLength, length - leftLength);
        return fn(leftTree, rightTree);
    }
}
// ORs all predicates together as a balanced tree instead of a left-deep chain.
var body = CombinePredicates(predicates, Expression.OrElse);

@joelmandell
Copy link

I had to use the IQueryableExtensions.In(), to support older SQL Server versions.
I also had an issue when using a library that parses the table name from the generated query (I am using multiple interceptors). The issue occurs when IEnumerable<TKey> values is empty.

It has the check if(!values.Any()) and returns queryable.Take(0).
When that call is used, EF Core generates a subquery, and the library I use (EFCoreSecondLevelCacheInterceptor) fails when trying to get the table name from the query.

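I resorted to returning queryable.Where(x => true) instead, and that seems to work. Posting this in case someone else has this problem. (Note: Where(x => true) matches every row, so an empty values list then returns the whole table instead of nothing. Where(x => false) keeps the empty-result semantics of Take(0).)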

Updated code of that extension method, according to that change:

/// <summary>
/// Filters <paramref name="queryable"/> to the elements whose key, as selected by
/// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
/// </summary>
/// <returns>The filtered query; a query matching no rows when <paramref name="values"/> is empty.</returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> or <paramref name="keySelector"/> is null.</exception>
/// <exception cref="ArgumentException">The bucketed value list exceeds 2048 entries.</exception>
public static IQueryable<TQuery> In<TKey, TQuery>(
    this IQueryable<TQuery> queryable,
    IEnumerable<TKey> values,
    Expression<Func<TQuery, TKey>> keySelector)
{
    if (values == null)
    {
        throw new ArgumentNullException(nameof(values));
    }

    if (keySelector == null)
    {
        throw new ArgumentNullException(nameof(keySelector));
    }

    // Enumerate `values` only once; a lazy source (e.g. another query) would otherwise run twice.
    var materialized = values as ICollection<TKey> ?? values.ToList();

    if (materialized.Count == 0)
    {
        // An empty IN list must match nothing. Use an always-false predicate instead of Take(0),
        // whose subquery broke table-name parsing in EFCoreSecondLevelCacheInterceptor.
        // Where(x => true) would instead return every row.
        return queryable.Where(x => false);
    }

    var distinctValues = Bucketize(materialized);

    if (distinctValues.Length > 2048)
    {
        throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
    }

    // One equality per value. Reading the value through a closure lets EF turn it into a SQL parameter.
    var predicates = distinctValues
        .Select(v =>
        {
            Expression<Func<TKey>> valueAsExpression = () => v;
            return (Expression)Expression.Equal(keySelector.Body, valueAsExpression.Body);
        })
        .ToList();

    // OR neighbouring predicates together until one remains, giving a tree of depth ~log2(n).
    // An odd leftover is carried to the next round unchanged.
    while (predicates.Count > 1)
    {
        var next = new List<Expression>((predicates.Count + 1) / 2);
        for (var i = 0; i < predicates.Count; i += 2)
        {
            next.Add(i + 1 < predicates.Count
                ? Expression.OrElse(predicates[i], predicates[i + 1])
                : predicates[i]);
        }

        predicates = next;
    }

    var clause = Expression.Lambda<Func<TQuery, bool>>(predicates[0], keySelector.Parameters);

    return queryable.Where(clause);
}

@yv989c
Copy link

yv989c commented Dec 29, 2021

@ErikEJ I believe you may find my project useful. It solves this problem in a flexible way. I have been using a similar strategy in my work with acceptable results, so I put some effort and made a generic version of it for everyone to use. Please, feel free to code review it if you have time. Ideas are welcome.

@yv989c
Copy link

yv989c commented Dec 29, 2021

Hey @joelmandell. Please take a look ☝️ too 🙂.

@fiseni
Copy link

fiseni commented Mar 12, 2023

A few updates; it performs a bit better.
Further improvements are possible, but they would need caching, which defeats the purpose: too many cache layers.

/// <summary>
/// Query extensions that replace <c>values.Contains(key)</c> filters with a parameterized OR tree
/// whose size is always a power of two, to limit SQL Server plan cache pollution.
/// </summary>
public static class IQueryableExtensions
{
    // SQL Server accepts at most 2100 parameters per command; keep headroom for the rest of the query.
    private const int MaxParameterCount = 2048;

    /// <summary>
    /// Filters <paramref name="queryable"/> to the elements whose key, as selected by
    /// <paramref name="keySelector"/>, equals one of <paramref name="values"/>.
    /// </summary>
    /// <remarks>
    /// Each distinct value is captured in a closure so EF Core sends it as a SQL parameter, and the
    /// list is padded to the next power of two (repeating the last value), so the number of distinct
    /// query shapes grows only logarithmically with the number of values.
    /// </remarks>
    /// <param name="queryable">The query to filter.</param>
    /// <param name="values">Candidate keys; duplicates are ignored. Enumerated exactly once.</param>
    /// <param name="keySelector">Selects the key to compare from each element.</param>
    /// <returns>The filtered query; an empty query when <paramref name="values"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> or <paramref name="keySelector"/> is null.</exception>
    /// <exception cref="ArgumentException">More than 2048 distinct values were supplied.</exception>
    public static IQueryable<TQuery> In<TKey, TQuery>(
        this IQueryable<TQuery> queryable,
        IEnumerable<TKey> values,
        Expression<Func<TQuery, TKey>> keySelector)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        // Materialize once: calling Any() and then building the set would enumerate a lazy
        // source (for example another query) twice.
        var distinctValues = new HashSet<TKey>(values).ToArray();
        if (distinctValues.Length == 0)
        {
            return queryable.Take(0);
        }

        // Check before padding so an oversized input is rejected without allocating the padded array.
        if (distinctValues.Length > MaxParameterCount)
        {
            throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
        }

        var bucketed = Bucketize(distinctValues);
        var expr = CreateBalancedORExpression(bucketed, keySelector.Body, 0, bucketed.Length - 1);
        var clause = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);
        return queryable.Where(clause);
    }

    /// <summary>
    /// Builds <c>key == values[start] || ... || key == values[end]</c> as a balanced binary tree,
    /// so the nesting depth is about log2(n) rather than n.
    /// </summary>
    private static BinaryExpression CreateBalancedORExpression<TKey>(TKey[] values, Expression keySelectorBody, int start, int end)
    {
        if (start == end)
        {
            return Expression.Equal(keySelectorBody, CaptureValue(values[start]));
        }

        if (start + 1 == end)
        {
            return Expression.OrElse(
                Expression.Equal(keySelectorBody, CaptureValue(values[start])),
                Expression.Equal(keySelectorBody, CaptureValue(values[end])));
        }

        int mid = (start + end) / 2;
        return Expression.OrElse(
            CreateBalancedORExpression(values, keySelectorBody, start, mid),
            CreateBalancedORExpression(values, keySelectorBody, mid + 1, end));
    }

    /// <summary>
    /// Returns an expression that reads <paramref name="value"/> from a closure. EF Core turns such
    /// closure reads into SQL parameters instead of inlining literals.
    /// </summary>
    private static Expression CaptureValue<TKey>(TKey value)
    {
        Expression<Func<TKey>> accessor = () => value;
        return accessor.Body;
    }

    /// <summary>
    /// Pads <paramref name="distinctValues"/> (non-empty) up to the next power of two by repeating
    /// its last element; returns the same array when its length already is a power of two.
    /// </summary>
    private static TKey[] Bucketize<TKey>(TKey[] distinctValues)
    {
        var originalLength = distinctValues.Length;

        // Integer doubling instead of Math.Pow/Math.Log avoids floating-point rounding in log2.
        var bucket = 1;
        while (bucket < originalLength)
        {
            bucket <<= 1;
        }

        if (bucket == originalLength)
        {
            return distinctValues;
        }

        var padded = new TKey[bucket];
        distinctValues.CopyTo(padded, 0);
        padded.AsSpan(originalLength).Fill(distinctValues[originalLength - 1]);
        return padded;
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment