Skip to content

Instantly share code, notes, and snippets.

@kroymann
Last active May 10, 2024 13:07
Show Gist options
  • Save kroymann/e57b3b4f30e6056a3465dbf118e5f13d to your computer and use it in GitHub Desktop.
Save kroymann/e57b3b4f30e6056a3465dbf118e5f13d to your computer and use it in GitHub Desktop.
EntityFramework extension that creates re-usable, parameterized queries that are equivalent to .Where(entity => enumeration.Contains(entity.property))
using AutoMapper;
using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
namespace RecNetCore.Utils
{
public static class BatchedQueryableIQueryableExtensions
{
    /// <summary>
    /// Generates a new <see cref="IBatchedQueryable{TQuery}"/> (a collection of queryables) that in aggregate
    /// mimics the effect of <code>queryable.Where(obj => values.Contains(obj.Id))</code>. Duplicate values are
    /// removed, the remaining values are broken up into chunks of the specified batch size, and a queryable is
    /// generated for each chunk. For each chunk a where clause is generated as a sequence of equality comparisons
    /// OR'd together, which allows the values to be passed as SQL parameters rather than being embedded directly in
    /// the SQL. This enables EntityFramework to cache the Linq-to-SQL mapping (which is very expensive to generate),
    /// as well as enabling the SQL server to cache the SQL execution plan (which is also very expensive to generate).
    /// Last but not least, because these are parameterized queries, they show up as the same query when viewing
    /// query analytics, rather than each unique combination of values appearing as a distinct query.
    /// </summary>
    /// <example>Here is an example of how this works. The following call:
    /// <code>
    /// query.BatchedWhereKeyInValues(q => q.Id, values: new[] { 1, 5, 7, 8, 9 }, batchSize: 3)
    /// </code>
    ///
    /// would produce the following two IQueryables:
    /// <code>
    /// query.Where(q => q.Id == 1 || q.Id == 5 || q.Id == 7)
    /// query.Where(q => q.Id == 8 || q.Id == 9 || q.Id == 9)
    /// </code>
    ///
    /// The last batch is padded by repeating the final value, so that every batch contains exactly
    /// <c>batchSize</c> comparisons and therefore has the same shape. Both batches then share a single
    /// parameterized SQL query that looks something like this:
    /// <code>
    /// WHERE Id IN (@p__linq__0, @p__linq__1, @p__linq__2)
    /// </code>
    /// </example>
    /// <param name="queryable">The base query to filter.</param>
    /// <param name="keySelector">Selects the key that is compared against <paramref name="values"/>.</param>
    /// <param name="values">
    /// The key values to match. This is enumerated exactly once, when this method is called. The batch list is built
    /// eagerly, so it is not re-enumerated each time the returned batched queryable is materialized.
    /// </param>
    /// <param name="batchSize">The number of comparisons per generated query. Must be at least 1.</param>
    /// <returns>A batched queryable containing one query per batch (no queries if <paramref name="values"/> is empty).</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="queryable"/>, <paramref name="keySelector"/> or <paramref name="values"/> is null.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is less than 1.</exception>
    public static IBatchedQueryable<TQuery> BatchedWhereKeyInValues<TQuery, TKey>(this IQueryable<TQuery> queryable, Expression<Func<TQuery, TKey>> keySelector, IEnumerable<TKey> values, int batchSize = 10)
    {
        // Validate eagerly. The batch generator may be lazy, in which case bad arguments would otherwise surface
        // only at materialization time, and as unrelated exceptions (e.g. DivideByZeroException for batchSize 0).
        if (queryable == null)
        {
            throw new ArgumentNullException(nameof(queryable));
        }
        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }
        if (batchSize < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "The batch size must be at least 1.");
        }

        // Materialize the batch list once. Otherwise `values` would be re-enumerated and the expression trees
        // rebuilt on every ToList/First/Any/... call against the returned object.
        List<IQueryable<TQuery>> queries = BatchedQueryable<TQuery>.GenerateQueries(queryable, keySelector, values, batchSize).ToList();
        return new BatchedQueryable<TQuery>(queries);
    }

    /// <summary>
    /// Simple helper that converts a queryable into a batched queryable containing that single query.
    /// </summary>
    /// <param name="queryable">The query to wrap; may be null.</param>
    /// <returns>A batched queryable wrapping <paramref name="queryable"/>, or null when <paramref name="queryable"/> is null.</returns>
    public static IBatchedQueryable<TQuery> AsBatchedQueryable<TQuery>(this IQueryable<TQuery> queryable)
    {
        return (queryable != null)
            ? new BatchedQueryable<TQuery>(new List<IQueryable<TQuery>>() { queryable })
            : null;
    }
}
/// <summary>
/// A query that is logically one LINQ query but is physically split into several <see cref="System.Linq.IQueryable{T}"/>
/// instances (for example, one per batch of key values). Builder methods apply the same operator to every underlying
/// query. Materialization methods execute the underlying queries and combine their results.
/// </summary>
/// <typeparam name="TQuery">Element type of the underlying queries.</typeparam>
public interface IBatchedQueryable<TQuery>
{
    // ---- Query composition: each call returns a new batched queryable; nothing executes yet. ----

    /// <summary>Applies <paramref name="predicate"/> to every underlying query.</summary>
    IBatchedQueryable<TQuery> Where(Expression<Func<TQuery, bool>> predicate);

    /// <summary>Projects every underlying query with <paramref name="selector"/>.</summary>
    IBatchedQueryable<TResult> Select<TResult>(Expression<Func<TQuery, TResult>> selector);

    /// <summary>Flattens every underlying query with <paramref name="selector"/>.</summary>
    IBatchedQueryable<TResult> SelectMany<TResult>(Expression<Func<TQuery, IEnumerable<TResult>>> selector);

    /// <summary>Adds an eager-load include path to every underlying query.</summary>
    IBatchedQueryable<TQuery> Include<TProperty>(Expression<Func<TQuery, TProperty>> path);

    /// <summary>
    /// Further restricts every underlying query to rows whose key (from <paramref name="keySelector"/>) is in
    /// <paramref name="values"/>, splitting the values into batches of <paramref name="batchSize"/>.
    /// </summary>
    IBatchedQueryable<TQuery> BatchedWhereKeyInValues<TKey>(Expression<Func<TQuery, TKey>> keySelector, IEnumerable<TKey> values, int batchSize = 10);

    // ---- Materialization: these execute the underlying queries. ----

    /// <summary>Executes all underlying queries and concatenates their results.</summary>
    List<TQuery> ToList();
    /// <summary>Returns the first result found across the underlying queries; throws when there is none.</summary>
    TQuery First();
    /// <summary>Returns the first result found across the underlying queries, or the default value when there is none.</summary>
    TQuery FirstOrDefault();
    /// <summary>Returns the only result across all underlying queries; throws when there are zero or several.</summary>
    TQuery Single();
    /// <summary>Returns the only result across all underlying queries, or the default value; throws when there are several.</summary>
    TQuery SingleOrDefault();

    /// <summary>Asynchronous counterpart of <see cref="ToList"/>.</summary>
    Task<List<TQuery>> ToListAsync();
    /// <summary>Asynchronous counterpart of <see cref="First"/>.</summary>
    Task<TQuery> FirstAsync();
    /// <summary>Asynchronous counterpart of <see cref="FirstOrDefault"/>.</summary>
    Task<TQuery> FirstOrDefaultAsync();
    /// <summary>Asynchronous counterpart of <see cref="Single"/>.</summary>
    Task<TQuery> SingleAsync();
    /// <summary>Asynchronous counterpart of <see cref="SingleOrDefault"/>.</summary>
    Task<TQuery> SingleOrDefaultAsync();

    // The ProjectTo* family projects each underlying query to TResult before executing it.
    // NOTE(review): the projection mechanism lives in the implementation (EntityFrameworkExtensions), not visible here.

    /// <summary>Projects and executes all underlying queries, concatenating the results.</summary>
    List<TResult> ProjectToList<TResult>();
    /// <summary>Projected counterpart of <see cref="First"/>.</summary>
    TResult ProjectToFirst<TResult>();
    /// <summary>Projected counterpart of <see cref="FirstOrDefault"/>.</summary>
    TResult ProjectToFirstOrDefault<TResult>();
    /// <summary>Projected counterpart of <see cref="Single"/>.</summary>
    TResult ProjectToSingle<TResult>();
    /// <summary>Projected counterpart of <see cref="SingleOrDefault"/>.</summary>
    TResult ProjectToSingleOrDefault<TResult>();

    /// <summary>Asynchronous counterpart of <see cref="ProjectToList{TResult}"/>.</summary>
    Task<List<TResult>> ProjectToListAsync<TResult>();
    /// <summary>Asynchronous counterpart of <see cref="ProjectToFirst{TResult}"/>.</summary>
    Task<TResult> ProjectToFirstAsync<TResult>();
    /// <summary>Asynchronous counterpart of <see cref="ProjectToFirstOrDefault{TResult}"/>.</summary>
    Task<TResult> ProjectToFirstOrDefaultAsync<TResult>();
    /// <summary>Asynchronous counterpart of <see cref="ProjectToSingle{TResult}"/>.</summary>
    Task<TResult> ProjectToSingleAsync<TResult>();
    /// <summary>Asynchronous counterpart of <see cref="ProjectToSingleOrDefault{TResult}"/>.</summary>
    Task<TResult> ProjectToSingleOrDefaultAsync<TResult>();

    /// <summary>True when any underlying query returns at least one row.</summary>
    bool Any();
    /// <summary>Asynchronous counterpart of <see cref="Any"/>.</summary>
    Task<bool> AnyAsync();

    /// <summary>Executes all underlying queries and builds a dictionary from the combined results.</summary>
    Dictionary<TKey, TElement> ToDictionary<TKey, TElement>(Func<TQuery, TKey> keySelector, Func<TQuery, TElement> elementSelector);
    /// <summary>Asynchronous counterpart of <see cref="ToDictionary{TKey, TElement}"/>.</summary>
    Task<Dictionary<TKey, TElement>> ToDictionaryAsync<TKey, TElement>(Func<TQuery, TKey> keySelector, Func<TQuery, TElement> elementSelector);
}
/// <summary>
/// Default <see cref="IBatchedQueryable{TQuery}"/> implementation: holds a list of queries, applies builder
/// operators to each of them, and combines their results when materialized.
/// </summary>
/// <typeparam name="TQuery">Element type of the underlying queries.</typeparam>
public class BatchedQueryable<TQuery> : IBatchedQueryable<TQuery>
{
    #region Fields
    // The underlying queries. Materialized once at construction so that lazily composed sources (the builder
    // methods below pass LINQ Select/SelectMany chains) are not re-evaluated on every materialization call.
    private readonly List<IQueryable<TQuery>> queries;
    #endregion

    #region Constructor
    /// <summary>Wraps the given queries.</summary>
    /// <param name="queries">The queries to aggregate. Enumerated once, here.</param>
    /// <exception cref="ArgumentNullException"><paramref name="queries"/> is null.</exception>
    public BatchedQueryable(IEnumerable<IQueryable<TQuery>> queries)
    {
        if (queries == null)
        {
            throw new ArgumentNullException(nameof(queries));
        }
        this.queries = queries.ToList();
    }
    #endregion

    #region Query Builder APIs
    public IBatchedQueryable<TQuery> Where(Expression<Func<TQuery, bool>> predicate)
    {
        return new BatchedQueryable<TQuery>(queries.Select(q => q.Where(predicate)));
    }

    public IBatchedQueryable<TResult> Select<TResult>(Expression<Func<TQuery, TResult>> selector)
    {
        return new BatchedQueryable<TResult>(queries.Select(q => q.Select(selector)));
    }

    public IBatchedQueryable<TResult> SelectMany<TResult>(Expression<Func<TQuery, IEnumerable<TResult>>> selector)
    {
        return new BatchedQueryable<TResult>(queries.Select(q => q.SelectMany(selector)));
    }

    public IBatchedQueryable<TQuery> Include<TProperty>(Expression<Func<TQuery, TProperty>> path)
    {
        return new BatchedQueryable<TQuery>(queries.Select(q => q.Include(path)));
    }

    /// <summary>
    /// Splits every existing query into one query per batch of <paramref name="values"/>. The result is the
    /// cross product of the existing queries and the value batches.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="keySelector"/> or <paramref name="values"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is less than 1.</exception>
    public IBatchedQueryable<TQuery> BatchedWhereKeyInValues<TKey>(Expression<Func<TQuery, TKey>> keySelector, IEnumerable<TKey> values, int batchSize = 10)
    {
        ValidateBatchArguments(keySelector, values, batchSize);

        // De-duplicate once, instead of once per existing query inside GenerateQueries.
        List<TKey> distinctValues = values.Distinct().ToList();
        return new BatchedQueryable<TQuery>(queries.SelectMany(q => GenerateQueries(q, keySelector, distinctValues, batchSize)));
    }
    #endregion

    #region Materialize Result APIs
    public List<TQuery> ToList() => ToList(Enumerable.ToList);
    public TQuery First() => First(Enumerable.ToList);
    public TQuery FirstOrDefault() => FirstOrDefault(Enumerable.ToList);
    public TQuery Single() => Single(Enumerable.ToList);
    public TQuery SingleOrDefault() => SingleOrDefault(Enumerable.ToList);
    public Task<List<TQuery>> ToListAsync() => ToListAsync(QueryableExtensions.ToListAsync);
    public Task<TQuery> FirstAsync() => FirstAsync(QueryableExtensions.ToListAsync);
    public Task<TQuery> FirstOrDefaultAsync() => FirstOrDefaultAsync(QueryableExtensions.ToListAsync);
    public Task<TQuery> SingleAsync() => SingleAsync(QueryableExtensions.ToListAsync);
    public Task<TQuery> SingleOrDefaultAsync() => SingleOrDefaultAsync(QueryableExtensions.ToListAsync);
    public List<TResult> ProjectToList<TResult>() => ToList(EntityFrameworkExtensions.ProjectToList<TResult>);
    public TResult ProjectToFirst<TResult>() => First(EntityFrameworkExtensions.ProjectToList<TResult>);
    public TResult ProjectToFirstOrDefault<TResult>() => FirstOrDefault(EntityFrameworkExtensions.ProjectToList<TResult>);
    public TResult ProjectToSingle<TResult>() => Single(EntityFrameworkExtensions.ProjectToList<TResult>);
    public TResult ProjectToSingleOrDefault<TResult>() => SingleOrDefault(EntityFrameworkExtensions.ProjectToList<TResult>);
    public Task<List<TResult>> ProjectToListAsync<TResult>() => ToListAsync(EntityFrameworkExtensions.ProjectToListAsync<TResult>);
    public Task<TResult> ProjectToFirstAsync<TResult>() => FirstAsync(EntityFrameworkExtensions.ProjectToListAsync<TResult>);
    public Task<TResult> ProjectToFirstOrDefaultAsync<TResult>() => FirstOrDefaultAsync(EntityFrameworkExtensions.ProjectToListAsync<TResult>);
    public Task<TResult> ProjectToSingleAsync<TResult>() => SingleAsync(EntityFrameworkExtensions.ProjectToListAsync<TResult>);
    public Task<TResult> ProjectToSingleOrDefaultAsync<TResult>() => SingleOrDefaultAsync(EntityFrameworkExtensions.ProjectToListAsync<TResult>);

    /// <summary>Runs the underlying queries in order, stopping at the first one that has any rows.</summary>
    public bool Any()
    {
        foreach (var query in queries)
        {
            if (query.Any())
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>Asynchronous counterpart of <see cref="Any"/>; the queries are awaited one at a time.</summary>
    public async Task<bool> AnyAsync()
    {
        foreach (var query in queries)
        {
            if (await query.AnyAsync())
            {
                return true;
            }
        }
        return false;
    }

    public Dictionary<TKey, TElement> ToDictionary<TKey, TElement>(Func<TQuery, TKey> keySelector, Func<TQuery, TElement> elementSelector)
    {
        var list = ToList();
        return list.ToDictionary(keySelector, elementSelector);
    }

    public async Task<Dictionary<TKey, TElement>> ToDictionaryAsync<TKey, TElement>(Func<TQuery, TKey> keySelector, Func<TQuery, TElement> elementSelector)
    {
        var list = await ToListAsync();
        return list.ToDictionary(keySelector, elementSelector);
    }
    #endregion

    #region Internal Helpers
    #region Synchronous
    // Executes each query with toListMethod and concatenates the results in query order.
    private List<TResult> ToList<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        List<TResult> results = new List<TResult>();
        foreach (var query in queries)
        {
            results.AddRange(toListMethod(query));
        }
        return results;
    }

    private TResult First<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        Tuple<bool, TResult> result = FirstHelper(toListMethod);
        if (result.Item1)
        {
            return result.Item2;
        }
        else
        {
            throw new InvalidOperationException("The source sequence is empty!");
        }
    }

    private TResult FirstOrDefault<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        Tuple<bool, TResult> result = FirstHelper(toListMethod);
        return result.Item2;
    }

    // Returns (true, first row) from the first query that yields a row, or (false, default) if none do.
    // Each query is limited to one row via Take(1).
    private Tuple<bool, TResult> FirstHelper<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        foreach (var query in queries)
        {
            var queryResults = toListMethod(query.Take(1));
            if (queryResults.Count == 1)
            {
                return new Tuple<bool, TResult>(true, queryResults[0]);
            }
        }
        return new Tuple<bool, TResult>(false, default(TResult));
    }

    private TResult Single<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        Tuple<bool, TResult> result = SingleHelper(toListMethod);
        if (result.Item1)
        {
            return result.Item2;
        }
        else
        {
            throw new InvalidOperationException("The source sequence is empty!");
        }
    }

    private TResult SingleOrDefault<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        Tuple<bool, TResult> result = SingleHelper(toListMethod);
        return result.Item2;
    }

    // Returns (found, row) for the single row across all queries. Throws if more than one row exists in total.
    // Take(2) is enough to detect a duplicate within a single query.
    private Tuple<bool, TResult> SingleHelper<TResult>(Func<IQueryable<TQuery>, List<TResult>> toListMethod)
    {
        bool foundResultAlready = false;
        TResult result = default(TResult);
        foreach (var query in queries)
        {
            var queryResults = toListMethod(query.Take(2));
            if (queryResults.Count > 0)
            {
                if (foundResultAlready || queryResults.Count > 1)
                {
                    throw new InvalidOperationException("Sequence contains more than one element!");
                }
                foundResultAlready = true;
                result = queryResults[0];
            }
        }
        return new Tuple<bool, TResult>(foundResultAlready, result);
    }
    #endregion

    #region Asynchronous
    // Async counterpart of ToList(toListMethod); the queries are awaited sequentially, never concurrently.
    private async Task<List<TResult>> ToListAsync<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        List<TResult> results = new List<TResult>();
        foreach (var query in queries)
        {
            results.AddRange(await toListMethod(query));
        }
        return results;
    }

    private async Task<TResult> FirstAsync<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        Tuple<bool, TResult> result = await FirstAsyncHelper(toListMethod);
        if (result.Item1)
        {
            return result.Item2;
        }
        else
        {
            throw new InvalidOperationException("The source sequence is empty!");
        }
    }

    private async Task<TResult> FirstOrDefaultAsync<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        Tuple<bool, TResult> result = await FirstAsyncHelper(toListMethod);
        return result.Item2;
    }

    private async Task<Tuple<bool, TResult>> FirstAsyncHelper<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        foreach (var query in queries)
        {
            var queryResults = await toListMethod(query.Take(1));
            if (queryResults.Count == 1)
            {
                return new Tuple<bool, TResult>(true, queryResults[0]);
            }
        }
        return new Tuple<bool, TResult>(false, default(TResult));
    }

    private async Task<TResult> SingleAsync<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        Tuple<bool, TResult> result = await SingleAsyncHelper(toListMethod);
        if (result.Item1)
        {
            return result.Item2;
        }
        else
        {
            throw new InvalidOperationException("The source sequence is empty!");
        }
    }

    private async Task<TResult> SingleOrDefaultAsync<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        Tuple<bool, TResult> result = await SingleAsyncHelper(toListMethod);
        return result.Item2;
    }

    private async Task<Tuple<bool, TResult>> SingleAsyncHelper<TResult>(Func<IQueryable<TQuery>, Task<List<TResult>>> toListMethod)
    {
        bool foundResultAlready = false;
        TResult result = default(TResult);
        foreach (var query in queries)
        {
            var queryResults = await toListMethod(query.Take(2));
            if (queryResults.Count > 0)
            {
                if (foundResultAlready || queryResults.Count > 1)
                {
                    throw new InvalidOperationException("Sequence contains more than one element!");
                }
                foundResultAlready = true;
                result = queryResults[0];
            }
        }
        return new Tuple<bool, TResult>(foundResultAlready, result);
    }
    #endregion

    // Shared argument checks for the batching entry points. Kept outside any iterator so they fire immediately.
    private static void ValidateBatchArguments<TKey>(Expression<Func<TQuery, TKey>> keySelector, IEnumerable<TKey> values, int batchSize)
    {
        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }
        if (batchSize < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "The batch size must be at least 1.");
        }
    }

    /// <summary>
    /// Builds one query per batch of distinct <paramref name="values"/>. Each query filters
    /// <paramref name="queryable"/> with an OR'd chain of <c>key == value</c> comparisons. The final batch is padded
    /// by repeating the last value, so every batch has exactly <paramref name="batchSize"/> comparisons and the same
    /// expression shape. The list is built eagerly: <paramref name="values"/> is enumerated once and argument errors
    /// surface immediately.
    /// </summary>
    /// <returns>The generated queries; empty when <paramref name="values"/> is empty.</returns>
    /// <exception cref="ArgumentNullException">Any reference argument is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is less than 1.</exception>
    internal static IEnumerable<IQueryable<TQuery>> GenerateQueries<TKey>(IQueryable<TQuery> queryable, Expression<Func<TQuery, TKey>> keySelector, IEnumerable<TKey> values, int batchSize)
    {
        if (queryable == null)
        {
            throw new ArgumentNullException(nameof(queryable));
        }
        ValidateBatchArguments(keySelector, values, batchSize);

        var generated = new List<IQueryable<TQuery>>();
        List<TKey> distinctValues = values.Distinct().ToList();
        if (distinctValues.Count == 0)
        {
            // Contains() over an empty set matches nothing, so no query is needed at all.
            return generated;
        }

        // Pad so the count is an exact multiple of batchSize. Every batch then has the same shape, which lets
        // EF (and the SQL server) reuse one cached plan for all batches.
        int lastBatchSize = distinctValues.Count % batchSize;
        if (lastBatchSize != 0)
        {
            distinctValues.AddRange(Enumerable.Repeat(distinctValues[distinctValues.Count - 1], batchSize - lastBatchSize));
        }

        for (int i = 0; i < distinctValues.Count; i += batchSize)
        {
            // GetRange is always in bounds here because of the padding above.
            var body = distinctValues
                .GetRange(i, batchSize)
                .Select(v =>
                {
                    // Create an expression that captures the variable so EF can turn this into a parameterized SQL query
                    Expression<Func<TKey>> valueAsExpression = () => v;
                    return Expression.Equal(keySelector.Body, valueAsExpression.Body);
                })
                .Aggregate((a, b) => Expression.OrElse(a, b));
            var whereClause = Expression.Lambda<Func<TQuery, bool>>(body, keySelector.Parameters);
            generated.Add(queryable.Where(whereClause));
        }
        return generated;
    }
    #endregion
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment