@ErikEJ
Last active February 29, 2024 12:16
Replacement for EF Core .Contains, that avoids SQL Server plan cache pollution
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

namespace Microsoft.EntityFrameworkCore
{
    public static class IQueryableExtensions
    {
        public static IQueryable<TQuery> In<TKey, TQuery>(
            this IQueryable<TQuery> queryable,
            IEnumerable<TKey> values,
            Expression<Func<TQuery, TKey>> keySelector)
        {
            ArgumentNullException.ThrowIfNull(values);
            ArgumentNullException.ThrowIfNull(keySelector);

            // An empty list matches nothing.
            if (!values.Any())
            {
                return queryable.Take(0);
            }

            // De-duplicate, then pad to the next power of two so that only a few
            // distinct parameter counts (and therefore query plans) can occur.
            var distinctValues = Bucketize(values);

            if (distinctValues.Length > 2048)
            {
                throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
            }

            var expr = CreateBalancedORExpression(distinctValues, keySelector.Body, 0, distinctValues.Length - 1);
            var clause = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);

            return queryable.Where(clause);
        }

        // Recursively combines "key == value" comparisons into a balanced tree of ORs.
        private static BinaryExpression CreateBalancedORExpression<TKey>(TKey[] values, Expression keySelectorBody, int start, int end)
        {
            if (start == end)
            {
                // Capture the value in a closure so that EF Core emits a parameter instead of a constant.
                var v1 = values[start];
                return Expression.Equal(keySelectorBody, ((Expression<Func<TKey>>)(() => v1)).Body);
            }
            else if (start + 1 == end)
            {
                var v1 = values[start];
                var v2 = values[end];

                return Expression.OrElse(
                    Expression.Equal(keySelectorBody, ((Expression<Func<TKey>>)(() => v1)).Body),
                    Expression.Equal(keySelectorBody, ((Expression<Func<TKey>>)(() => v2)).Body));
            }
            else
            {
                int mid = (start + end) / 2;
                return Expression.OrElse(
                    CreateBalancedORExpression(values, keySelectorBody, start, mid),
                    CreateBalancedORExpression(values, keySelectorBody, mid + 1, end));
            }
        }

        private static TKey[] Bucketize<TKey>(IEnumerable<TKey> values)
        {
            var distinctValues = new HashSet<TKey>(values).ToArray();
            var originalLength = distinctValues.Length;

            // Round the count up to the next power of two.
            int bucket = (int)Math.Pow(2, Math.Ceiling(Math.Log(originalLength, 2)));

            if (originalLength == bucket) return distinctValues;

            // Fill the extra slots by repeating the last value; the duplicates do not change the result.
            var lastValue = distinctValues[originalLength - 1];
            Array.Resize(ref distinctValues, bucket);
            distinctValues.AsSpan().Slice(originalLength).Fill(lastValue);

            return distinctValues;
        }
    }
}

@moredatapls

Great work! Could you specify the license of this piece of code?

@ErikEJ
Author

ErikEJ commented Jun 2, 2020

There is no license, and no license restrictions - use as you see fit!

@alb-xss

alb-xss commented Dec 27, 2020

This is absolutely brilliant, the buckets idea is genius. Thank you for sharing your solution

@ErikEJ
Author

ErikEJ commented Feb 17, 2021

@OculiViridi

OculiViridi commented Jun 4, 2021

@ErikEJ Great piece of code! 😄

Is it possible to add this extension method to EF.Functions? Here you can also see the Microsoft documentation.

I think it would be even nicer if it could be implemented like the EF.Functions.Like() method, to get usage like this:

var result = dbContext.MyEntities
    .Where(e => EF.Functions.In(values => e.Id))
    .ToList();

@h0wXD

h0wXD commented Aug 13, 2021

I'm not sure what you're trying to achieve with Bucketize, but I think it can be much simpler, and it can inline values as constants only once the parameter limit is exceeded:

// SQL Server max parameter count is 2100, keep space
private const int MaxSqlParameterCount = 2048;

// https://github.com/dotnet/efcore/issues/13617
public static IQueryable<TQuery> WhereIn<TKey, TQuery>(
    this IQueryable<TQuery> queryable,
    IEnumerable<TKey> values,
    Expression<Func<TQuery, TKey>> keySelector)
{
    if (values == null) throw new ArgumentNullException(nameof(values));
    if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
    if (!values.Any()) return queryable.Take(0);

    var distinctValues = values.Distinct().ToList();
    Expression expr = null;
    Expression<Func<TKey>> valueAsExpression = null;

    for (var i = 0; i < distinctValues.Count; i++)
    {
        var id = distinctValues[i];

        valueAsExpression = () => id;
        expr = expr == null
            ? Expression.Equal(keySelector.Body, valueAsExpression.Body)
            : i < MaxSqlParameterCount
                ? Expression.OrElse(expr, Expression.Equal(keySelector.Body, valueAsExpression.Body))
                : Expression.OrElse(expr, Expression.Equal(keySelector.Body, Expression.Constant(id)));
    }

    Bucketize(distinctValues.Count, () =>
        expr = Expression.OrElse(expr, Expression.Equal(keySelector.Body, valueAsExpression.Body)));

    var lambda = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);

    return queryable.Where(lambda);
}

private static void Bucketize(int parameterCount, Action padExpression)
{
    if (parameterCount >= MaxSqlParameterCount)
    {
        return;
    }

    var bucket = 1;
    while (parameterCount > bucket)
    {
        bucket *= 2;
    }

    for (var i = parameterCount; i < bucket; i++)
    {
        padExpression();
    }
}

Updated to bucketize and limit the number of plans

@ErikEJ
Author

ErikEJ commented Aug 13, 2021

@h0wXD Buckets are important in order to limit the number of plans, see my blog post here: https://erikej.github.io/efcore/sqlserver/2020/03/30/ef-core-cache-pollution.html

@RichardD2

Couldn't you get rid of the PairWise function, and replace the loop:

while (predicates.Count > 1)
{
    predicates = PairWise(predicates).Select(p => Expression.OrElse(p.Item1, p.Item2)).ToList();
}

var body = predicates.Single();

with:

var body = predicates.Aggregate(Expression.OrElse);

@RichardD2

RichardD2 commented Aug 24, 2021

Or, if you're worried about the deep nesting of conditions:

private readonly struct HalfList<T>
{
    private readonly IReadOnlyList<T> _list;
    private readonly int _startIndex;

    private HalfList(IReadOnlyList<T> list, int startIndex, int count)
    {
        _list = list ?? throw new ArgumentNullException(nameof(list));
        _startIndex = startIndex;
        Count = count;
    }

    public HalfList(IReadOnlyList<T> list) : this(list, 0, list.Count)
    {
    }

    public int Count { get; }
    
    public T Item => Count == 1 ? _list[_startIndex] : throw new InvalidOperationException();

    public (HalfList<T> left, HalfList<T> right) Split()
    {
        if (Count < 2) throw new InvalidOperationException();

        int pivot = Count >> 1;
        var left = new HalfList<T>(_list, _startIndex, pivot);
        var right = new HalfList<T>(_list, _startIndex + pivot, Count - pivot);
        return (left, right);
    }
}

private static Expression CombinePredicates(IReadOnlyList<Expression> parts, Func<Expression, Expression, Expression> fn)
{
    if (parts.Count == 0) throw new ArgumentException("At least one part is required.", nameof(parts));
    if (parts.Count == 1) return parts[0];
    
    var segment = new HalfList<Expression>(parts);
    return CombineCore(segment.Split(), fn);

    static Expression CombineCore((HalfList<Expression> left, HalfList<Expression> right) x, Func<Expression, Expression, Expression> fn)
    {
        var left = x.left.Count == 1 ? x.left.Item : CombineCore(x.left.Split(), fn);
        var right = x.right.Count == 1 ? x.right.Item : CombineCore(x.right.Split(), fn);
        return fn(left, right);
    }
}
var body = CombinePredicates(predicates, Expression.OrElse);
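
For anyone comparing the two shapes discussed above, here is a minimal sketch that measures how deeply the OR nodes nest when the same predicates are combined with Aggregate versus a balanced split. The class and method names (TreeShapeDemo, MakePredicates, Chain, Balanced, Depth) are made up for illustration and are not part of the gist, and the dummy predicates use constants rather than captured parameters to keep it short.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class TreeShapeDemo
{
    // Builds n dummy predicates of the form (x == i) over a shared parameter.
    public static List<Expression> MakePredicates(int n)
    {
        var x = Expression.Parameter(typeof(int), "x");
        return Enumerable.Range(0, n)
            .Select(i => (Expression)Expression.Equal(x, Expression.Constant(i)))
            .ToList();
    }

    // Left-deep chain: (((p0 || p1) || p2) || ...), which is what Aggregate produces.
    public static Expression Chain(IReadOnlyList<Expression> parts) =>
        parts.Aggregate((left, right) => Expression.OrElse(left, right));

    // Balanced tree: split the range in half and OR the two halves together.
    public static Expression Balanced(IReadOnlyList<Expression> parts, int start, int end)
    {
        if (start == end) return parts[start];
        var mid = (start + end) / 2;
        return Expression.OrElse(
            Balanced(parts, start, mid),
            Balanced(parts, mid + 1, end));
    }

    // Number of OrElse nodes on the longest path from the root down to a leaf predicate.
    public static int Depth(Expression e) =>
        e is BinaryExpression { NodeType: ExpressionType.OrElse } b
            ? 1 + Math.Max(Depth(b.Left), Depth(b.Right))
            : 0;
}
```

With 1024 predicates, Depth(Chain(parts)) is 1023 while Depth(Balanced(parts, 0, parts.Count - 1)) is 10. Whether the deeper chain causes trouble in practice depends on how the consumer walks the tree; this sketch only shows the difference in shape.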

@joelmandell

I had to use IQueryableExtensions.In() to support older SQL Server versions, and ran into an issue with a library that parses the table name from the generated query (I am using multiple interceptors). The issue occurs when IEnumerable<TKey> values is empty.

The method checks if (!values.Any()) and returns queryable.Take(0).
For that call EF Core generates a subquery, and the library that I use (EFCoreSecondLevelCacheInterceptor) fails when trying to extract the table name from the query.

I resorted to returning queryable.Where(x => false) instead, and that seems to work. The predicate has to be false so that an empty list still matches no rows, as Take(0) did; x => true would return every row. Posting this in case someone else has this problem.

Updated code of that extension method, according to that change:

public static IQueryable<TQuery> In<TKey, TQuery>(
            this IQueryable<TQuery> queryable,
            IEnumerable<TKey> values,
            Expression<Func<TQuery, TKey>> keySelector)
        {
            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            if (keySelector == null)
            {
                throw new ArgumentNullException(nameof(keySelector));
            }

            if (!values.Any())
            {
                // .Where instead of .Take(0), because Take(0) seems to produce "funky" SQL (a subquery).
                // The predicate is always false, so no rows are returned for an empty list.
                return queryable.Where(x => false);
            }

            var distinctValues = Bucketize(values);

            if (distinctValues.Length > 2048)
            {
                throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
            }

            var predicates = distinctValues
                .Select(v =>
                {
                    // Create an expression that captures the variable so EF can turn this into a parameterized SQL query
                    Expression<Func<TKey>> valueAsExpression = () => v;
                    return Expression.Equal(keySelector.Body, valueAsExpression.Body);
                })
                .ToList();

            while (predicates.Count > 1)
            {
                predicates = PairWise(predicates).Select(p => Expression.OrElse(p.Item1, p.Item2)).ToList();
            }

            var body = predicates.Single();

            var clause = Expression.Lambda<Func<TQuery, bool>>(body, keySelector.Parameters);

            return queryable.Where(clause);
        }

@yv989c

yv989c commented Dec 29, 2021

@ErikEJ I believe you may find my project useful. It solves this problem in a flexible way. I have been using a similar strategy in my work with acceptable results, so I put in some effort and made a generic version of it for everyone to use. Please feel free to code review it if you have time. Ideas are welcome.

@yv989c

yv989c commented Dec 29, 2021

Hey @joelmandell. Please take a look ☝️ too 🙂.

@fiseni

fiseni commented Mar 12, 2023

A few updates; it performs a bit better.
We could improve it further, but that would require caching, which defeats the purpose: too many cache layers.

public static class IQueryableExtensions
{
    public static IQueryable<TQuery> In<TKey, TQuery>(
        this IQueryable<TQuery> queryable,
        IEnumerable<TKey> values,
        Expression<Func<TQuery, TKey>> keySelector)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        if (!values.Any())
        {
            return queryable.Take(0);
        }

        var distinctValues = Bucketize(values);

        if (distinctValues.Length > 2048)
        {
            throw new ArgumentException("Too many parameters for SQL Server, reduce the number of parameters", nameof(values));
        }

        var expr = CreateBalancedORExpression(distinctValues, keySelector.Body, 0, distinctValues.Length - 1);

        var clause = Expression.Lambda<Func<TQuery, bool>>(expr, keySelector.Parameters);

        return queryable.Where(clause);
    }

    private static BinaryExpression CreateBalancedORExpression<TKey>(TKey[] values, Expression keySelectorBody, int start, int end)
    {
        if (start == end)
        {
            var v1 = values[start];
            return Expression.Equal(keySelectorBody, ((Expression<Func<TKey>>)(() => v1)).Body);
        }
        else if (start + 1 == end)
        {
            var v1 = values[start];
            var v2 = values[end];

            return Expression.OrElse(
                Expression.Equal(keySelectorBody, ((Expression<Func<TKey>>)(() => v1)).Body),
                Expression.Equal(keySelectorBody, ((Expression<Func<TKey>>)(() => v2)).Body));
        }
        else
        {
            int mid = (start + end) / 2;
            return Expression.OrElse(
                CreateBalancedORExpression(values, keySelectorBody, start, mid),
                CreateBalancedORExpression(values, keySelectorBody, mid + 1, end));
        }
    }

    private static TKey[] Bucketize<TKey>(IEnumerable<TKey> values)
    {
        var distinctValues = new HashSet<TKey>(values).ToArray();
        var originalLength = distinctValues.Length;

        int bucket = (int)Math.Pow(2, Math.Ceiling(Math.Log(originalLength, 2)));

        if (originalLength == bucket) return distinctValues;

        var lastValue = distinctValues[originalLength - 1];
        Array.Resize(ref distinctValues, bucket);
        distinctValues.AsSpan().Slice(originalLength).Fill(lastValue);

        return distinctValues;
    }
}

@juliowh

juliowh commented Feb 1, 2024

Hey @ErikEJ!
Sorry for reviving this gist, I'm having trouble understanding why Bucketize is designed this way.
I would just like to know why it performs better to create buckets in sizes of 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024 and 2048 instead of something like 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, [...], 2048.
I tried, but did not understand the concept of "plans".
Sorry if this is not the right place to ask, and for the inconvenience.

@ErikEJ
Author

ErikEJ commented Feb 2, 2024

@juliowh I think I answered that question already?

@juliowh

juliowh commented Feb 2, 2024

I read that answer, but did not understand why a bucket size of 32 performs better than one of size 20.
Is a SQL plan always "created" in sizes that are powers of 2?

@ErikEJ
Author

ErikEJ commented Feb 2, 2024

@juliowh It is the number of unique query plans that matters, and a factor of 2 seemed to give suitable bucket sizes.
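
To make the trade-off concrete, here is a rough sketch that counts how many different parameter counts can occur for 1 to 2048 input values under each bucketing scheme. It assumes that every distinct parameter count produces a distinct SQL text and therefore its own cached plan. The names BucketDemo, NextPowerOfTwo, NextMultiple and DistinctShapes are made up for this illustration and do not appear in the gist.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class BucketDemo
{
    // Rounds count up to the next power of two, using integer arithmetic only.
    public static int NextPowerOfTwo(int count)
    {
        if (count < 1) throw new ArgumentOutOfRangeException(nameof(count));
        var bucket = 1;
        while (bucket < count) bucket *= 2;
        return bucket;
    }

    // Rounds count up to the next multiple of step (for example 2, 4, 6, ...).
    public static int NextMultiple(int count, int step) =>
        ((count + step - 1) / step) * step;

    // Counts the distinct bucket sizes that inputs of 1..max values can map to.
    // Under the stated assumption, this is the number of plans the query can leave in the cache.
    public static int DistinctShapes(int max, Func<int, int> bucketize) =>
        Enumerable.Range(1, max).Select(bucketize).Distinct().Count();
}
```

DistinctShapes(2048, BucketDemo.NextPowerOfTwo) gives 12, while DistinctShapes(2048, n => BucketDemo.NextMultiple(n, 2)) gives 1024. Powers of two keep the number of shapes logarithmic in the maximum size, at the price of padding a list to less than twice its length. A bucket of 32 is not faster than one of 20 for a single query; there are simply far fewer such buckets overall.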

@clement911

That's very cool @ErikEJ !

We use a lot of composite keys, and I wonder if it would be possible to create another overload that works with composite keys.

So I guess the signature might be something like this:

public static IQueryable<TQuery> In<TKey1, TKey2, TQuery>(
    this IQueryable<TQuery> queryable,
    IEnumerable<Tuple<TKey1, TKey2>> values,
    Expression<Func<TQuery, Tuple<TKey1, TKey2>>> keySelector)

And we might use it like this:

[PrimaryKey(nameof(State), nameof(LicensePlate))]
internal class Car
{
    public string State { get; set; }
    public string LicensePlate { get; set; }

    public string Make { get; set; }
    public string Model { get; set; }
}

var keys = new[] { ("state1", "license1"), ("state2", "license2") /* etc. */ };
var cars = context.Cars.In(keys, c => (c.State, c.LicensePlate));

It's a bit more complicated because the predicate needs to operate on two separate columns to generate something like this:

SELECT ...
FROM ...
WHERE (State = @pState1 AND LicensePlate = @pLicensePlate1)
OR    (State = @pState2 AND LicensePlate = @pLicensePlate2)
OR ...
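
One possible shape for such an overload is sketched below. It is not part of the gist: the class name CompositeKeyExtensions, the helpers Combine and ParameterReplacer, and the choice of two separate key selectors are all assumptions made for illustration. Two selectors are used because, as far as I know, the C# compiler rejects tuple literals such as c => (c.State, c.LicensePlate) inside expression trees. Each pair consumes two parameters, so the sketch caps the padded list at 1024 pairs.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class CompositeKeyExtensions
{
    // Hypothetical two-column variant of In(); a sketch, not a tested EF Core implementation.
    public static IQueryable<TQuery> In<TQuery, TKey1, TKey2>(
        this IQueryable<TQuery> queryable,
        IEnumerable<(TKey1, TKey2)> values,
        Expression<Func<TQuery, TKey1>> keySelector1,
        Expression<Func<TQuery, TKey2>> keySelector2)
    {
        ArgumentNullException.ThrowIfNull(values);
        ArgumentNullException.ThrowIfNull(keySelector1);
        ArgumentNullException.ThrowIfNull(keySelector2);

        var distinct = values.Distinct().ToArray();
        if (distinct.Length == 0) return queryable.Take(0);

        // Pad to the next power of two by repeating the last pair (same idea as Bucketize).
        var bucket = 1;
        while (bucket < distinct.Length) bucket *= 2;
        if (bucket * 2 > 2048)
            throw new ArgumentException("Too many parameters for SQL Server, reduce the number of pairs", nameof(values));
        var padded = distinct
            .Concat(Enumerable.Repeat(distinct[distinct.Length - 1], bucket - distinct.Length))
            .ToArray();

        // Both selectors must refer to the same lambda parameter in the final predicate.
        var parameter = keySelector1.Parameters[0];
        var body1 = keySelector1.Body;
        var body2 = new ParameterReplacer(keySelector2.Parameters[0], parameter).Visit(keySelector2.Body);

        var predicates = padded.Select(pair =>
        {
            // Captured locals become closure accesses, which EF Core can turn into parameters.
            var (v1, v2) = pair;
            Expression<Func<TKey1>> e1 = () => v1;
            Expression<Func<TKey2>> e2 = () => v2;
            return (Expression)Expression.AndAlso(
                Expression.Equal(body1, e1.Body),
                Expression.Equal(body2, e2.Body));
        }).ToArray();

        var body = Combine(predicates, 0, predicates.Length - 1);
        return queryable.Where(Expression.Lambda<Func<TQuery, bool>>(body, parameter));
    }

    // ORs the (k1 == v1 && k2 == v2) terms together as a balanced tree.
    private static Expression Combine(Expression[] parts, int start, int end)
    {
        if (start == end) return parts[start];
        var mid = (start + end) / 2;
        return Expression.OrElse(Combine(parts, start, mid), Combine(parts, mid + 1, end));
    }

    // Swaps one lambda parameter for another inside an expression body.
    private sealed class ParameterReplacer : ExpressionVisitor
    {
        private readonly ParameterExpression _source;
        private readonly ParameterExpression _target;

        public ParameterReplacer(ParameterExpression source, ParameterExpression target)
        {
            _source = source;
            _target = target;
        }

        protected override Expression VisitParameter(ParameterExpression node) =>
            node == _source ? _target : base.VisitParameter(node);
    }
}
```

With the Car entity above, the call would read context.Cars.In(keys, c => c.State, c => c.LicensePlate), where keys is a sequence of (string, string) tuples. The sketch has only been reasoned through against in-memory IQueryable sources; whether EF Core parameterizes it exactly like the WHERE clause above is an assumption.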

@ErikEJ
Author

ErikEJ commented Feb 23, 2024

@clement911 Feel free to do with this snippet whatever you want. It is just a sample.
