Last active
December 10, 2023 02:38
-
-
Save leandromoh/92bec4eddbacb1b20f4a67cd5ac6a081 to your computer and use it in GitHub Desktop.
Visitor that modifies an expression tree for safe null propagation. This code is now maintained at https://github.com/leandromoh/NullPropagationVisitor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Diagnostics; | |
using System.Linq.Expressions; | |
namespace ConsoleApp3
{
    /// <summary>
    /// Console harness that exercises <see cref="NullPropagationVisitor"/> with a few
    /// <see cref="Debug.Assert(bool)"/>-based smoke tests.
    /// </summary>
    static class Program
    {
        private static string Foo(string s) => s;

        static void Main(string[] _)
        {
            var visitor = new NullPropagationVisitor(recursive: true);
            Test1();
            Test2();
            Test3();

            // Whole lambda with a conditional body: only the branch that dereferences is guarded.
            void Test1()
            {
                Expression<Func<string, char?>> f = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];
                var fBody = (Expression<Func<string, char?>>)visitor.Visit(f);
                var fFunc = fBody.Compile();
                Debug.Assert(fFunc(null) == null);
                Debug.Assert(fFunc("bar") == '3');
                Debug.Assert(fFunc("foo") == 'X');
            }

            // Visiting only the body: the caller rebuilds the lambda with a nullable return type.
            void Test2()
            {
                Expression<Func<string, int>> y = s => s.Length;
                var yBody = visitor.Visit(y.Body);
                var yFunc = Expression.Lambda<Func<string, int?>>(
                    body: yBody,
                    parameters: y.Parameters)
                    .Compile();
                Debug.Assert(yFunc(null) == null);
                Debug.Assert(yFunc("bar") == 3);
            }

            // A Nullable<T> receiver: members of the nullable itself (Value) are guarded too.
            void Test3()
            {
                Expression<Func<char?, string>> y = s => s.Value.ToString()[0].ToString();
                var yBody = visitor.Visit(y.Body);
                var yFunc = Expression.Lambda<Func<char?, string>>(
                    body: yBody,
                    parameters: y.Parameters)
                    .Compile();
                Debug.Assert(yFunc(null) == null);
                Debug.Assert(yFunc('A') == "A");
            }
        }

        /// <summary>
        /// Rewrites an expression tree so that member accesses and instance method calls
        /// yield <c>null</c> instead of throwing when their receiver is <c>null</c>
        /// (similar to the C# <c>?.</c> operator). Value-typed results are lifted to
        /// <see cref="Nullable{T}"/> so that <c>null</c> can be represented.
        /// </summary>
        public class NullPropagationVisitor : ExpressionVisitor
        {
            private readonly bool _recursive;

            /// <param name="recursive">
            /// <c>true</c> to also rewrite the receiver of every access, guarding a whole chain
            /// such as <c>a.B.C</c>; <c>false</c> to null-check only the immediate receiver of
            /// the visited access and evaluate anything below it unchanged.
            /// </param>
            public NullPropagationVisitor(bool recursive)
            {
                _recursive = recursive;
            }

            /// <summary>
            /// Rewrites the lambda body and keeps the lambda well-typed. The declared return type
            /// is lifted to its nullable counterpart, and the body is converted to it if needed.
            /// When that changes the return type, a new lambda is built whose delegate type is
            /// inferred from the body (e.g. <c>Func&lt;T, int&gt;</c> becomes <c>Func&lt;T, int?&gt;</c>).
            /// Otherwise the original delegate type, name and tail-call flag are preserved.
            /// </summary>
            protected override Expression VisitLambda<T>(Expression<T> node)
            {
                var body = Visit(node.Body);
                var returnType = MakeNullableType(node.ReturnType);

                // Action-like lambdas may legally have a non-void body whose value is discarded.
                if (body.Type != returnType && returnType != typeof(void))
                    body = Expression.Convert(body, returnType);

                return returnType == node.ReturnType
                    ? node.Update(body, node.Parameters)
                    : Expression.Lambda(body, node.Name, node.TailCall, node.Parameters);
            }

            /// <summary>
            /// Handles each unary node according to its kind:
            /// <list type="bullet">
            /// <item><c>array.Length</c> is null-propagated like a member access.</item>
            /// <item>Conversions are re-applied on top of the rewritten operand.</item>
            /// <item>Every other unary operator (<c>!</c>, <c>-</c>, <c>as</c>, ...) is kept and
            /// applied to its rewritten, possibly lifted, operand.</item>
            /// </list>
            /// </summary>
            protected override Expression VisitUnary(UnaryExpression propertyAccess)
            {
                switch (propertyAccess.NodeType)
                {
                    case ExpressionType.ArrayLength:
                        // The operand is the array itself, i.e. the receiver that may be null.
                        return Common(propertyAccess.Operand, propertyAccess);
                    case ExpressionType.Convert:
                    case ExpressionType.ConvertChecked:
                        return VisitConversion(propertyAccess);
                    default:
                        return base.VisitUnary(propertyAccess);
                }
            }

            /// <summary>
            /// Rewrites the operand of a conversion and converts the result to the nullable
            /// counterpart of the original target type (skipped when it already has that type).
            /// Operands other than member accesses, method calls and conditionals are visited
            /// by the base implementation.
            /// </summary>
            private Expression VisitConversion(UnaryExpression conversion)
            {
                Expression operand;
                switch (conversion.Operand)
                {
                    case MemberExpression mem:
                        operand = VisitMember(mem);
                        break;
                    case MethodCallExpression met:
                        operand = VisitMethodCall(met);
                        break;
                    case ConditionalExpression cond:
                        // Both branches are lifted so they share a nullable type; the test is left untouched.
                        operand = Expression.Condition(
                            test: cond.Test,
                            ifTrue: MakeNullable(Visit(cond.IfTrue)),
                            ifFalse: MakeNullable(Visit(cond.IfFalse)));
                        break;
                    default:
                        return base.VisitUnary(conversion);
                }

                var targetType = MakeNullableType(conversion.Type);
                if (operand.Type == targetType)
                    return operand;
                return conversion.NodeType == ExpressionType.ConvertChecked
                    ? Expression.ConvertChecked(operand, targetType)
                    : Expression.Convert(operand, targetType);
            }

            /// <summary>Guards an instance member access; static members have no receiver to check.</summary>
            protected override Expression VisitMember(MemberExpression propertyAccess)
            {
                if (propertyAccess.Expression == null)
                    return base.VisitMember(propertyAccess);
                return Common(propertyAccess.Expression, propertyAccess);
            }

            /// <summary>Guards an instance method call; static calls are visited normally.</summary>
            protected override Expression VisitMethodCall(MethodCallExpression propertyAccess)
            {
                if (propertyAccess.Object == null)
                    return base.VisitMethodCall(propertyAccess);
                return Common(propertyAccess.Object, propertyAccess);
            }

            /// <summary>
            /// Evaluates <paramref name="instance"/> once into a local and performs
            /// <paramref name="propertyAccess"/> on that local only when it is not null.
            /// Otherwise the block yields <c>null</c> (or does nothing, for void calls).
            /// </summary>
            private BlockExpression Common(Expression instance, Expression propertyAccess)
            {
                var safe = _recursive ? base.Visit(instance) : instance;
                var caller = Expression.Variable(safe.Type, "caller");
                var assign = Expression.Assign(caller, safe);

                // Members of Nullable<T> itself (Value, HasValue, ...) are invoked on the nullable
                // local. Any other receiver that the recursive rewrite lifted to T? is unwrapped
                // back to T; that unwrap only runs after the null test below.
                var receiver = IsNullableStruct(instance) ? caller : RemoveNullable(caller);
                var access = MakeNullable(new ExpressionReplacer(instance, receiver).Visit(propertyAccess));

                Expression guarded;
                if (!IsNullable(caller))
                {
                    // A non-nullable value type can never be null (and Expression.Equal would
                    // throw for it), so the access needs no guard.
                    guarded = access;
                }
                else if (access.Type == typeof(void))
                {
                    // A void call has no result to replace with null; it is simply skipped.
                    guarded = Expression.IfThen(
                        Expression.NotEqual(caller, Expression.Constant(null)),
                        access);
                }
                else
                {
                    guarded = Expression.Condition(
                        test: Expression.Equal(caller, Expression.Constant(null)),
                        ifTrue: Expression.Constant(null, access.Type),
                        ifFalse: access);
                }

                return Expression.Block(
                    type: access.Type,
                    variables: new[] { caller },
                    expressions: new Expression[] { assign, guarded });
            }

            /// <summary>Returns <c>T?</c> for a non-nullable value type <c>T</c>; other types (including void) unchanged.</summary>
            private static Type MakeNullableType(Type type)
            {
                return type.IsValueType && type != typeof(void) && Nullable.GetUnderlyingType(type) == null
                    ? typeof(Nullable<>).MakeGenericType(type)
                    : type;
            }

            /// <summary>Wraps <paramref name="ex"/> in a conversion to its nullable type when it is a non-nullable value type.</summary>
            private static Expression MakeNullable(Expression ex)
            {
                var nullableType = MakeNullableType(ex.Type);
                return nullableType == ex.Type ? ex : Expression.Convert(ex, nullableType);
            }

            /// <summary>True for reference types and <see cref="Nullable{T}"/>.</summary>
            private static bool IsNullable(Expression ex)
            {
                return !ex.Type.IsValueType || (Nullable.GetUnderlyingType(ex.Type) != null);
            }

            /// <summary>True only for <see cref="Nullable{T}"/>.</summary>
            private static bool IsNullableStruct(Expression ex)
            {
                return ex.Type.IsValueType && (Nullable.GetUnderlyingType(ex.Type) != null);
            }

            /// <summary>Converts a <c>T?</c> expression to <c>T</c>; other expressions are returned unchanged.</summary>
            private static Expression RemoveNullable(Expression ex)
            {
                if (IsNullableStruct(ex))
                    return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);
                return ex;
            }

            /// <summary>Replaces every occurrence (by reference) of one expression node with another.</summary>
            private class ExpressionReplacer : ExpressionVisitor
            {
                private readonly Expression _oldEx;
                private readonly Expression _newEx;

                internal ExpressionReplacer(Expression oldEx, Expression newEx)
                {
                    _oldEx = oldEx;
                    _newEx = newEx;
                }

                public override Expression Visit(Expression node)
                {
                    if (node == _oldEx)
                        return _newEx;
                    return base.Visit(node);
                }
            }
        }
    }
}
@Liero can you share the code version that worked for you, please?
I have hacked something that works in my case, but it's not universal.
I have added the `VisitLambda` and `MakeNullableType` methods shown below.
However, I think the problem lies deeper. For example, the original visitor removes `Convert` expressions, and I have to add them back manually in `VisitLambda`.
/// <summary>
/// Rewrites an expression tree so that member accesses and instance method calls
/// yield <c>null</c> instead of throwing when their receiver is <c>null</c>
/// (similar to the C# <c>?.</c> operator). Value-typed results are lifted to
/// <see cref="Nullable{T}"/> so that <c>null</c> can be represented.
/// </summary>
public class NullPropagationVisitor : ExpressionVisitor
{
    private readonly bool _recursive;

    /// <param name="recursive">
    /// <c>true</c> to also rewrite the receiver of every access, guarding a whole chain
    /// such as <c>a.B.C</c>; <c>false</c> to null-check only the immediate receiver of
    /// the visited access and evaluate anything below it unchanged.
    /// </param>
    public NullPropagationVisitor(bool recursive)
    {
        _recursive = recursive;
    }

    /// <summary>
    /// Rewrites the lambda body and keeps the lambda well-typed. The declared return type
    /// is lifted with <see cref="MakeNullableType"/>, and the body is converted to it if
    /// needed. When that changes the return type, a new lambda is built whose delegate type
    /// is inferred from the body (e.g. <c>Func&lt;T, int&gt;</c> becomes <c>Func&lt;T, int?&gt;</c>).
    /// Otherwise the original delegate type, name and tail-call flag are preserved.
    /// </summary>
    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        var body = Visit(node.Body);
        var returnType = MakeNullableType(node.ReturnType);

        // Action-like lambdas may legally have a non-void body whose value is discarded.
        if (body.Type != returnType && returnType != typeof(void))
        {
            body = Expression.Convert(body, returnType);
        }

        return returnType == node.ReturnType
            ? node.Update(body, node.Parameters)
            : Expression.Lambda(body, node.Name, node.TailCall, node.Parameters);
    }

    /// <summary>
    /// Handles each unary node according to its kind:
    /// <list type="bullet">
    /// <item><c>array.Length</c> is null-propagated like a member access.</item>
    /// <item>Conversions are re-applied on top of the rewritten operand.</item>
    /// <item>Every other unary operator (<c>!</c>, <c>-</c>, <c>as</c>, ...) is kept and applied
    /// to its rewritten, possibly lifted, operand (previously such operators were dropped).</item>
    /// </list>
    /// </summary>
    protected override Expression VisitUnary(UnaryExpression propertyAccess)
    {
        switch (propertyAccess.NodeType)
        {
            case ExpressionType.ArrayLength:
                // The operand is the array itself, i.e. the receiver that may be null.
                return Common(propertyAccess.Operand, propertyAccess);
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
                return VisitConversion(propertyAccess);
            default:
                return base.VisitUnary(propertyAccess);
        }
    }

    /// <summary>
    /// Rewrites the operand of a conversion and converts the result to the nullable
    /// counterpart of the original target type (skipped when it already has that type).
    /// Operands other than member accesses, method calls and conditionals are visited
    /// by the base implementation.
    /// </summary>
    private Expression VisitConversion(UnaryExpression conversion)
    {
        Expression operand;
        switch (conversion.Operand)
        {
            case MemberExpression mem:
                operand = VisitMember(mem);
                break;
            case MethodCallExpression met:
                operand = VisitMethodCall(met);
                break;
            case ConditionalExpression cond:
                // Both branches are lifted so they share a nullable type; the test is left untouched.
                operand = Expression.Condition(
                    test: cond.Test,
                    ifTrue: MakeNullable(Visit(cond.IfTrue)),
                    ifFalse: MakeNullable(Visit(cond.IfFalse)));
                break;
            default:
                return base.VisitUnary(conversion);
        }

        var targetType = MakeNullableType(conversion.Type);
        if (operand.Type == targetType)
        {
            return operand;
        }
        return conversion.NodeType == ExpressionType.ConvertChecked
            ? Expression.ConvertChecked(operand, targetType)
            : Expression.Convert(operand, targetType);
    }

    /// <summary>
    /// Guards an instance member access. Static members have no receiver that could be
    /// null, so they are visited normally (consistent with <see cref="VisitMethodCall"/>).
    /// </summary>
    protected override Expression VisitMember(MemberExpression propertyAccess)
    {
        if (propertyAccess.Expression is null)
        {
            return base.VisitMember(propertyAccess);
        }
        return Common(propertyAccess.Expression, propertyAccess);
    }

    /// <summary>Guards an instance method call; static calls are visited normally.</summary>
    protected override Expression VisitMethodCall(MethodCallExpression propertyAccess)
    {
        if (propertyAccess.Object == null)
            return base.VisitMethodCall(propertyAccess);
        return Common(propertyAccess.Object, propertyAccess);
    }

    /// <summary>
    /// Evaluates <paramref name="instance"/> once into a local and performs
    /// <paramref name="propertyAccess"/> on that local only when it is not null.
    /// Otherwise the block yields <c>null</c> (or does nothing, for void calls).
    /// </summary>
    private Expression Common(Expression instance, Expression propertyAccess)
    {
        var safe = _recursive ? base.Visit(instance) : instance;
        var caller = Expression.Variable(safe.Type, "caller");
        var assign = Expression.Assign(caller, safe);

        // Members of Nullable<T> itself (Value, HasValue, ...) are invoked on the nullable
        // local. Any other receiver that the recursive rewrite lifted to T? is unwrapped
        // back to T; that unwrap only runs after the null test below.
        var receiver = IsNullableStruct(instance) ? caller : RemoveNullable(caller);
        var access = MakeNullable(new ExpressionReplacer(instance, receiver).Visit(propertyAccess)!);

        Expression guarded;
        if (!IsNullable(caller))
        {
            // A non-nullable value type can never be null (and Expression.Equal would
            // throw for it), so the access needs no guard.
            guarded = access;
        }
        else if (access.Type == typeof(void))
        {
            // A void call has no result to replace with null; it is simply skipped.
            guarded = Expression.IfThen(
                Expression.NotEqual(caller, Expression.Constant(null)),
                access);
        }
        else
        {
            guarded = Expression.Condition(
                test: Expression.Equal(caller, Expression.Constant(null)),
                ifTrue: Expression.Constant(null, access.Type),
                ifFalse: access);
        }

        return Expression.Block(
            type: access.Type,
            variables: new[] { caller },
            expressions: new Expression[] { assign, guarded });
    }

    /// <summary>
    /// Returns <c>T?</c> for a non-nullable value type <c>T</c>; reference types,
    /// nullable types and <see cref="Void"/> are returned unchanged.
    /// </summary>
    public static Type MakeNullableType(Type type)
    {
        return type.IsValueType && type != typeof(void) && Nullable.GetUnderlyingType(type) == null
            ? typeof(Nullable<>).MakeGenericType(type)
            : type;
    }

    /// <summary>
    /// Wraps <paramref name="ex"/> in a conversion to its nullable type when it is a
    /// non-nullable value type; otherwise returns it unchanged.
    /// </summary>
    public static Expression MakeNullable(Expression ex)
    {
        var nullableType = MakeNullableType(ex.Type);
        return nullableType == ex.Type ? ex : Expression.Convert(ex, nullableType);
    }

    /// <summary>True for reference types and <see cref="Nullable{T}"/>.</summary>
    private static bool IsNullable(Expression ex)
    {
        return !ex.Type.IsValueType || Nullable.GetUnderlyingType(ex.Type) != null;
    }

    /// <summary>True only for <see cref="Nullable{T}"/>.</summary>
    private static bool IsNullableStruct(Expression ex)
    {
        return ex.Type.IsValueType && Nullable.GetUnderlyingType(ex.Type) != null;
    }

    /// <summary>Converts a <c>T?</c> expression to <c>T</c>; other expressions are returned unchanged.</summary>
    private static Expression RemoveNullable(Expression ex)
    {
        if (IsNullableStruct(ex))
            return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);
        return ex;
    }

    /// <summary>Replaces every occurrence (by reference) of one expression node with another.</summary>
    private class ExpressionReplacer : ExpressionVisitor
    {
        private readonly Expression _oldEx;
        private readonly Expression _newEx;

        internal ExpressionReplacer(Expression oldEx, Expression newEx)
        {
            _oldEx = oldEx;
            _newEx = newEx;
        }

        public override Expression? Visit(Expression? node)
        {
            if (node == _oldEx)
                return _newEx;
            return base.Visit(node);
        }
    }
}
And the tests:
#nullable enable
#pragma warning disable CS8602 // Dereference of a possibly null reference.
/// <summary>
/// Tests that <c>NullPropagationVisitor</c> guards a member/method chain against null
/// receivers and produces a lambda of the expected delegate type.
/// </summary>
[TestClass]
public class NullPropagationVisitorTests
{
    // Self-referencing model: Name and Child may each be null at any depth.
    record Foo(string? Name, Foo? Child = null);

    [TestMethod]
    public void NullPropagationVisitor_AddsNullChecks()
    {
        // Arrange
        Expression<Func<Foo, char>> firstCharOfChildName = foo => foo.Child.Name.ToString()[0];
        var sut = new NullPropagationVisitor(true);

        // Act
        var rewritten = (Expression<Func<Foo, char?>>)sut.Visit(firstCharOfChildName);
        Func<Foo, char?> compiled = rewritten.Compile();

        // Assert
        Assert.IsNull(compiled(new Foo("parent")));                 // no child at all
        Assert.IsNull(compiled(new Foo("parent", new Foo(null))));  // child without a name
        Assert.AreEqual('c', compiled(new Foo("parent", new Foo("child"))));
    }

    [TestMethod]
    public void NullPropagationVisitor_LambdasHasCorrectSignature()
    {
        // Arrange: the declared return type is object, so the rewritten lambda should keep it.
        Expression<Func<Foo, object>> boxedFirstChar = foo => foo.Child.Name.ToString()[0];
        var sut = new NullPropagationVisitor(true);

        // Act
        var rewritten = (LambdaExpression)sut.Visit(boxedFirstChar);

        // Assert
        Assert.AreEqual(typeof(Func<Foo, object?>), rewritten.Type);
    }
}
@Liero thanks for sharing! I will take a look. Maybe it is time to create a full repository/nuget for this code and receive PR/improvements from community.
@Liero the problem has been thoroughly solved and published in this repository as well as on NuGet.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This throws an exception when the original LambdaExpression has a non-nullable return type, e.g.:
EDIT:
Resolved using