@haacked
Last active March 29, 2024 14:16
Example of applying an EF Core global query filter on all entity types that implement an interface
/*
Copyright Phil Haack
Licensed under the MIT license - https://github.com/haacked/CodeHaacks/blob/main/LICENSE.
*/
using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Query;

public static class ModelBuilderExtensions
{
    // Cached handle to the private generic SetQueryFilter<TEntity, TEntityInterface> method below.
    static readonly MethodInfo SetQueryFilterMethod = typeof(ModelBuilderExtensions)
        .GetMethods(BindingFlags.NonPublic | BindingFlags.Static)
        .Single(t => t.IsGenericMethod && t.Name == nameof(SetQueryFilter));

    // Applies the filter to every root entity type whose CLR type implements TEntityInterface.
    public static void SetQueryFilterOnAllEntities<TEntityInterface>(
        this ModelBuilder builder,
        Expression<Func<TEntityInterface, bool>> filterExpression)
    {
        foreach (var type in builder.Model.GetEntityTypes()
            .Where(t => t.BaseType == null)
            .Select(t => t.ClrType)
            .Where(t => typeof(TEntityInterface).IsAssignableFrom(t)))
        {
            builder.SetEntityQueryFilter(
                type,
                filterExpression);
        }
    }

    // Closes SetQueryFilter over the runtime entity type and invokes it via reflection.
    static void SetEntityQueryFilter<TEntityInterface>(
        this ModelBuilder builder,
        Type entityType,
        Expression<Func<TEntityInterface, bool>> filterExpression)
    {
        SetQueryFilterMethod
            .MakeGenericMethod(entityType, typeof(TEntityInterface))
            .Invoke(null, new object[] { builder, filterExpression });
    }

    // Rewrites the interface-typed filter against the concrete entity type, then appends it.
    static void SetQueryFilter<TEntity, TEntityInterface>(
        this ModelBuilder builder,
        Expression<Func<TEntityInterface, bool>> filterExpression)
        where TEntityInterface : class
        where TEntity : class, TEntityInterface
    {
        var concreteExpression = filterExpression
            .Convert<TEntityInterface, TEntity>();
        builder.Entity<TEntity>()
            .AppendQueryFilter(concreteExpression);
    }

    // CREDIT: This comment by magiak on GitHub https://github.com/dotnet/efcore/issues/10275#issuecomment-785916356
    // Combines the new filter with any filter already set on the entity type using AndAlso.
    static void AppendQueryFilter<T>(this EntityTypeBuilder entityTypeBuilder, Expression<Func<T, bool>> expression)
        where T : class
    {
        var parameterType = Expression.Parameter(entityTypeBuilder.Metadata.ClrType);
        var expressionFilter = ReplacingExpressionVisitor.Replace(
            expression.Parameters.Single(), parameterType, expression.Body);

        if (entityTypeBuilder.Metadata.GetQueryFilter() != null)
        {
            var currentQueryFilter = entityTypeBuilder.Metadata.GetQueryFilter();
            var currentExpressionFilter = ReplacingExpressionVisitor.Replace(
                currentQueryFilter.Parameters.Single(), parameterType, currentQueryFilter.Body);
            expressionFilter = Expression.AndAlso(currentExpressionFilter, expressionFilter);
        }

        var lambdaExpression = Expression.Lambda(expressionFilter, parameterType);
        entityTypeBuilder.HasQueryFilter(lambdaExpression);
    }
}

public static class ExpressionExtensions
{
    // This magic is courtesy of this StackOverflow post.
    // https://stackoverflow.com/questions/38316519/replace-parameter-type-in-lambda-expression
    // I made some tweaks to adapt it to our needs - @haacked
    public static Expression<Func<TTarget, bool>> Convert<TSource, TTarget>(
        this Expression<Func<TSource, bool>> root)
    {
        var visitor = new ParameterTypeVisitor<TSource, TTarget>();
        return (Expression<Func<TTarget, bool>>)visitor.Visit(root);
    }

    class ParameterTypeVisitor<TSource, TTarget> : ExpressionVisitor
    {
        private ReadOnlyCollection<ParameterExpression> _parameters;

        protected override Expression VisitParameter(ParameterExpression node)
        {
            return _parameters?.FirstOrDefault(p => p.Name == node.Name)
                ?? (node.Type == typeof(TSource) ? Expression.Parameter(typeof(TTarget), node.Name) : node);
        }

        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            _parameters = VisitAndConvert(node.Parameters, "VisitLambda");
            return Expression.Lambda(Visit(node.Body), _parameters);
        }
    }
}
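
A minimal usage sketch for the extension above, assuming a soft-delete scenario. The names `ISoftDeletable`, `Customer`, and `AppDbContext` are hypothetical and not part of the gist; only `SetQueryFilterOnAllEntities` comes from the code above.

```csharp
using System;
using Microsoft.EntityFrameworkCore;

// Hypothetical marker interface, used only for illustration.
public interface ISoftDeletable
{
    bool IsDeleted { get; set; }
}

// Hypothetical entity that opts in to the filter by implementing the interface.
public class Customer : ISoftDeletable
{
    public int Id { get; set; }
    public string Name { get; set; } = "";
    public bool IsDeleted { get; set; }
}

public class AppDbContext : DbContext
{
    public AppDbContext(DbContextOptions<AppDbContext> options) : base(options) { }

    public DbSet<Customer> Customers => Set<Customer>();

    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        base.OnModelCreating(modelBuilder);

        // Call this after any other entity configuration: the extension only
        // iterates the entity types present in builder.Model at call time.
        modelBuilder.SetQueryFilterOnAllEntities<ISoftDeletable>(e => !e.IsDeleted);
    }
}
```

With this in place, queries against `Customers` should exclude rows where `IsDeleted` is true unless the query calls `IgnoreQueryFilters()`.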
@stzoran1

stzoran1 commented Aug 8, 2020

Thank you very much. I updated license header ;)

@peteralbanese

superhuman skills...thanks this saved my behind.

@haacked
Author

haacked commented Mar 23, 2021

I've updated the code for EF Core 5.0.2 and above. It no longer relies on internal interfaces. If you want the version for EF Core 3, you'll have to look at the previous revision for this gist.

@bbilginn

Hi, I have used this method before, but I realized it no longer works for me. I looked for the commit that broke it but could not find it, so I went back to using HasQueryFilter with inheritance. All members of the BaseEntityMap are affected. There are screenshots here:

[four screenshots, not preserved]

@sarowa36

sarowa36 commented Nov 29, 2022

This is simpler, and it works.

using System;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query;

public static class ModelBuilderExtensions
{
    public static void ApplyGlobalFilters<TInterface>(this ModelBuilder modelBuilder, Expression<Func<TInterface, bool>> expression)
    {
        var entities = modelBuilder.Model
            .GetEntityTypes()
            .Where(t => t.BaseType == null)
            .Select(t => t.ClrType)
            .Where(t => typeof(TInterface).IsAssignableFrom(t));
        foreach (var entity in entities)
        {
            // Rebind the interface-typed parameter to the concrete entity type.
            var newParam = Expression.Parameter(entity);
            var newbody = ReplacingExpressionVisitor.Replace(expression.Parameters.Single(), newParam, expression.Body);
            modelBuilder.Entity(entity).HasQueryFilter(Expression.Lambda(newbody, newParam));
        }
    }
    public static void ApplyGlobalInclude<TInterface, TProperty>(this ModelBuilder modelBuilder, Expression<Func<TInterface, TProperty>> expression) where TProperty : class
    {
        // Navigation() expects the property name (e.g. "Owner" for e => e.Owner),
        // not the name of the property's type, so read it from the member access.
        var member = expression.Body as MemberExpression
            ?? throw new ArgumentException("Expected a simple property access such as e => e.Owner.", nameof(expression));
        var entities = modelBuilder.Model
            .GetEntityTypes()
            .Where(t => t.BaseType == null)
            .Select(t => t.ClrType)
            .Where(t => typeof(TInterface).IsAssignableFrom(t));
        foreach (var entity in entities)
        {
            modelBuilder.Entity(entity).Navigation(member.Member.Name).AutoInclude();
        }
    }
}
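
Worth noting about this variant: it calls `HasQueryFilter` directly, and in the EF Core versions discussed in this thread a later call on the same entity type replaces the earlier filter instead of adding to it. The gist's `AppendQueryFilter` handles that case by AND-ing the new filter with the existing one. The core of that technique is rebinding both lambdas to one shared parameter, which can be sketched without EF Core. `FilterDemo`, `ParameterSwapper`, and `And` are hypothetical names used only for this illustration.

```csharp
using System;
using System.Linq.Expressions;

public static class FilterDemo
{
    // Replaces every occurrence of one parameter with another expression.
    sealed class ParameterSwapper : ExpressionVisitor
    {
        readonly ParameterExpression _from;
        readonly Expression _to;

        public ParameterSwapper(ParameterExpression from, Expression to)
        {
            _from = from;
            _to = to;
        }

        protected override Expression VisitParameter(ParameterExpression node)
            => node == _from ? _to : base.VisitParameter(node);
    }

    // Builds x => first(x) && second(x) over a single shared parameter.
    public static Expression<Func<T, bool>> And<T>(
        Expression<Func<T, bool>> first,
        Expression<Func<T, bool>> second)
    {
        var parameter = Expression.Parameter(typeof(T), "x");
        var left = new ParameterSwapper(first.Parameters[0], parameter).Visit(first.Body);
        var right = new ParameterSwapper(second.Parameters[0], parameter).Visit(second.Body);
        return Expression.Lambda<Func<T, bool>>(Expression.AndAlso(left, right), parameter);
    }

    public static void Main()
    {
        Expression<Func<int, bool>> positive = n => n > 0;
        Expression<Func<int, bool>> even = n => n % 2 == 0;

        var both = And(positive, even).Compile();
        Console.WriteLine(both(4));   // True
        Console.WriteLine(both(-4));  // False
        Console.WriteLine(both(3));   // False
    }
}
```

Both bodies must end up referencing the same parameter instance; combining them with `AndAlso` without rebinding would leave two unrelated parameters in one lambda, which fails when the expression is compiled or translated.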

@Grauenwolf

Use three backticks ``` instead of one ` to get code formatting.

@sarowa36

sarowa36 commented Dec 4, 2022

@Grauenwolf thank you
