Skip to content

Instantly share code, notes, and snippets.

@aradalvand
Last active July 31, 2023 19:42
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aradalvand/9b70e8bb455f5398affba610f2a25625 to your computer and use it in GitHub Desktop.
Save aradalvand/9b70e8bb455f5398affba610f2a25625 to your computer and use it in GitHub Desktop.
Solves the Hot Chocolate + EF Core + DTO overfetching problem for related entities
namespace Translator;
/// <summary>
/// Extension methods for opting an <see cref="IQueryable{T}"/> into expression rewriting
/// performed by <c>TranslatableVisitor</c> before the query is executed.
/// </summary>
public static class QueryableExtensions
{
    /// <summary>
    /// Wraps <paramref name="source"/> in a query whose provider rewrites the expression tree
    /// before delegating execution to the provider of <paramref name="source"/>.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <param name="source">The query to wrap.</param>
    /// <returns>
    /// <paramref name="source"/> itself when it is already a translatable query; otherwise a new
    /// translatable query over the same expression.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
    public static IQueryable<T> AsTranslatable<T>(this IQueryable<T> source)
    {
        // Fail with a descriptive exception instead of a NullReferenceException on source.Provider.
        if (source is null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        // Already wrapped: avoid stacking a second provider on top of the first.
        if (source is TranslatableQuery<T> query)
        {
            return query;
        }

        return new TranslatableQueryProvider(source.Provider).CreateQuery<T>(source.Expression);
    }
}
using System.Collections;
namespace Translator;
/// <summary>
/// Queryable wrapper that carries an expression tree together with the
/// <see cref="TranslatableQueryProvider"/> responsible for executing it.
/// </summary>
/// <typeparam name="T">Element type produced by the query.</typeparam>
internal class TranslatableQuery<T> : IQueryable<T>, IOrderedQueryable<T>, IAsyncEnumerable<T>
{
    // Kept as the concrete provider type so ExecuteQuery / ExecuteQueryAsync are reachable.
    private readonly TranslatableQueryProvider _provider;

    /// <summary>Creates a query over <paramref name="expression"/> bound to <paramref name="provider"/>.</summary>
    /// <param name="provider">Provider used to execute the query.</param>
    /// <param name="expression">Expression tree this query represents.</param>
    public TranslatableQuery(TranslatableQueryProvider provider, Expression expression)
    {
        _provider = provider;
        Expression = expression;
    }

    /// <summary>The element type of the query, i.e. <typeparamref name="T"/>.</summary>
    public Type ElementType => typeof(T);

    /// <summary>The expression tree passed to the constructor.</summary>
    public Expression Expression { get; }

    /// <summary>The provider passed to the constructor.</summary>
    public IQueryProvider Provider => _provider;

    /// <summary>Executes the query through the provider and enumerates the results synchronously.</summary>
    public IEnumerator<T> GetEnumerator() =>
        _provider.ExecuteQuery<T>(Expression).GetEnumerator();

    /// <summary>Non-generic enumeration; forwards to <see cref="GetEnumerator"/>.</summary>
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Executes the query through the provider and enumerates the results asynchronously.</summary>
    /// <param name="ct">Token forwarded to the underlying asynchronous enumerator.</param>
    public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken ct) =>
        _provider.ExecuteQueryAsync<T>(Expression).GetAsyncEnumerator(ct);
}
using Microsoft.EntityFrameworkCore.Query;
namespace Translator;
/// <summary>
/// Query provider that runs every expression through <see cref="TranslatableVisitor"/> and then
/// hands the rewritten expression to the wrapped provider for execution.
/// </summary>
internal class TranslatableQueryProvider : IAsyncQueryProvider
{
    private readonly IQueryProvider _underlyingProvider;

    /// <summary>Creates a provider that delegates execution to <paramref name="underlyingQueryProvider"/>.</summary>
    /// <param name="underlyingQueryProvider">The provider that actually executes queries.</param>
    /// <exception cref="ArgumentNullException"><paramref name="underlyingQueryProvider"/> is <c>null</c>.</exception>
    public TranslatableQueryProvider(IQueryProvider underlyingQueryProvider)
    {
        _underlyingProvider = underlyingQueryProvider
            ?? throw new ArgumentNullException(nameof(underlyingQueryProvider));
    }

    /// <summary>Creates a translatable query of <typeparamref name="TElement"/> over <paramref name="expression"/>.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new TranslatableQuery<TElement>(this, expression);
    }

    /// <summary>
    /// Creates a translatable query over <paramref name="expression"/>, inferring the element type
    /// from the expression's static type.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The type of <paramref name="expression"/> is not a sequence type.
    /// </exception>
    public IQueryable CreateQuery(Expression expression)
    {
        Type elementType = GetSequenceElementType(expression.Type)
            ?? throw new ArgumentException(
                $"Expression of type '{expression.Type}' is not a sequence.", nameof(expression));

        return (IQueryable)Activator.CreateInstance(
            typeof(TranslatableQuery<>).MakeGenericType(elementType),
            new object[] { this, expression }
        )!;
    }

    /// <summary>Rewrites <paramref name="expression"/> and exposes the underlying query as an async sequence.</summary>
    public IAsyncEnumerable<T> ExecuteQueryAsync<T>(Expression expression)
    {
        return _underlyingProvider.CreateQuery<T>(Visit(expression)).AsAsyncEnumerable();
    }

    /// <summary>Rewrites <paramref name="expression"/> and executes it on the underlying async provider.</summary>
    /// <param name="expression">Expression to execute.</param>
    /// <param name="ct">Token forwarded to the underlying provider.</param>
    /// <exception cref="InvalidOperationException">
    /// The underlying provider does not implement <see cref="IAsyncQueryProvider"/>.
    /// </exception>
    public TResult ExecuteAsync<TResult>(Expression expression, CancellationToken ct = default)
    {
        if (_underlyingProvider is IAsyncQueryProvider asyncProvider)
        {
            return asyncProvider.ExecuteAsync<TResult>(Visit(expression), ct);
        }

        // InvalidOperationException derives from Exception, so existing catch blocks keep working.
        throw new InvalidOperationException("The underlying query provider is not async.");
    }

    /// <summary>Rewrites <paramref name="expression"/> and exposes the underlying query as a sequence.</summary>
    public IEnumerable<T> ExecuteQuery<T>(Expression expression)
    {
        return _underlyingProvider.CreateQuery<T>(Visit(expression)).AsEnumerable();
    }

    /// <summary>Rewrites <paramref name="expression"/> and executes it on the underlying provider.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        return _underlyingProvider.Execute<TResult>(Visit(expression));
    }

    /// <summary>Rewrites <paramref name="expression"/> and executes it on the underlying provider.</summary>
    public object? Execute(Expression expression)
    {
        return _underlyingProvider.Execute(Visit(expression));
    }

    /// <summary>Runs a fresh <see cref="TranslatableVisitor"/> over <paramref name="expr"/>.</summary>
    private static Expression Visit(Expression expr)
    {
        return new TranslatableVisitor().Visit(expr);
    }

    /// <summary>
    /// Resolves the element type of a sequence type, or <c>null</c> when the type is not a sequence.
    /// </summary>
    private static Type? GetSequenceElementType(Type sequenceType)
    {
        // Type.GetElementType only yields a value for array, pointer and by-ref types.
        if (sequenceType.HasElementType)
        {
            return sequenceType.GetElementType();
        }

        // Queryable expressions are typed IQueryable<T> / IOrderedQueryable<T> etc.; the element
        // type is the generic argument of the IEnumerable<T> they are or implement.
        Type? enumerableType = IsGenericEnumerable(sequenceType)
            ? sequenceType
            : Array.Find(sequenceType.GetInterfaces(), IsGenericEnumerable);

        return enumerableType?.GetGenericArguments()[0];

        static bool IsGenericEnumerable(Type candidate) =>
            candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IEnumerable<>);
    }
}
using System.Reflection;
namespace Translator;
/// <summary>
/// Expression visitor that rewrites member initializers of the form
/// <c>Prop = x.Dto == null ? a : b</c> (or <c>!=</c>) so that the null test is performed on
/// <c>x.Dto.Id</c> instead of on the DTO object itself.
/// </summary>
internal class TranslatableVisitor : ExpressionVisitor
{
    /// <summary>
    /// Replaces a matching DTO null-check assignment with its rewritten form, then continues
    /// the normal visit on whichever assignment results.
    /// </summary>
    protected override MemberAssignment VisitMemberAssignment(MemberAssignment node)
    {
        return base.VisitMemberAssignment(RewriteDtoNullCheck(node));
    }

    /// <summary>
    /// Returns a rewritten assignment when <paramref name="node"/> assigns a
    /// "[object].[DtoProperty] ==/!= null ? ... : ..." ternary to a DTO-typed property;
    /// otherwise returns <paramref name="node"/> unchanged.
    /// </summary>
    private static MemberAssignment RewriteDtoNullCheck(MemberAssignment node)
    {
        // Only properties whose type derives from BaseDto are candidates.
        if (node.Member is not PropertyInfo assignedProperty || !IsDtoType(assignedProperty.PropertyType))
        {
            return node;
        }

        // The assigned value must be a ternary whose test is a binary expression.
        if (node.Expression is not ConditionalExpression ternary || ternary.Test is not BinaryExpression test)
        {
            return node;
        }

        // Left operand: a property access whose property type is a DTO.
        bool leftIsDtoProperty =
            test.Left is MemberExpression { Member: PropertyInfo checkedProperty }
            && IsDtoType(checkedProperty.PropertyType);

        // Operator: == or !=.
        bool isEqualityOperator = test.NodeType is ExpressionType.Equal or ExpressionType.NotEqual;

        // Right operand: the constant null.
        bool rightIsNull = test.Right is ConstantExpression { Value: null };

        if (!(leftIsDtoProperty && isEqualityOperator && rightIsNull))
        {
            return node;
        }

        // Build "[object].[DtoProperty].Id != null". Per the original author's note, comparing
        // the DTO object itself against null would make Entity Framework fetch the whole object.
        var idAccess = Expression.Property(test.Left, typeof(BaseDto).GetProperty(nameof(BaseDto.Id))!);
        var idIsNotNull = Expression.NotEqual(
            Expression.Convert(idAccess, typeof(int?)),
            Expression.Constant(null));

        // The new test is always an inequality, so the branches keep their positions when the
        // original test was "!=" and swap when it was "==".
        bool originalWasInequality = test.NodeType == ExpressionType.NotEqual;
        Expression whenIdPresent = originalWasInequality ? ternary.IfTrue : ternary.IfFalse;
        Expression whenIdMissing = originalWasInequality ? ternary.IfFalse : ternary.IfTrue;

        // Same member as before, now bound to the rewritten ternary.
        return Expression.Bind(node.Member, Expression.Condition(idIsNotNull, whenIdPresent, whenIdMissing));
    }

    /// <summary>True when <paramref name="type"/> is a strict subclass of <c>BaseDto</c>.</summary>
    private static bool IsDtoType(Type type) => type.IsSubclassOf(typeof(BaseDto));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment