Last active
December 10, 2023 02:38
-
-
Save leandromoh/92bec4eddbacb1b20f4a67cd5ac6a081 to your computer and use it in GitHub Desktop.
Visitor that modifies an expression tree for safe null propagation. This code is now maintained at https://github.com/leandromoh/NullPropagationVisitor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Diagnostics; | |
using System.Linq.Expressions; | |
namespace ConsoleApp3
{
    static class Program
    {
        private static string Foo(string s) => s;

        /// <summary>
        /// Smoke tests for <see cref="NullPropagationVisitor"/>. The checks use
        /// <see cref="Debug.Assert(bool)"/>, so they only fire in builds where DEBUG is defined.
        /// </summary>
        static void Main(string[] _)
        {
            var visitor = new NullPropagationVisitor(recursive: true);

            Test1();
            Test2();
            Test3();
            Test4();
            Test5();

            void Test1()
            {
                Expression<Func<string, char?>> f = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];
                var fBody = (Expression<Func<string, char?>>)visitor.Visit(f);
                var fFunc = fBody.Compile();
                Debug.Assert(fFunc(null) == null);
                Debug.Assert(fFunc("bar") == '3');
                Debug.Assert(fFunc("foo") == 'X');
            }

            void Test2()
            {
                Expression<Func<string, int>> y = s => s.Length;
                var yBody = visitor.Visit(y.Body);
                var yFunc = Expression.Lambda<Func<string, int?>>(
                    body: yBody,
                    parameters: y.Parameters)
                    .Compile();
                Debug.Assert(yFunc(null) == null);
                Debug.Assert(yFunc("bar") == 3);
            }

            void Test3()
            {
                Expression<Func<char?, string>> y = s => s.Value.ToString()[0].ToString();
                var yBody = visitor.Visit(y.Body);
                var yFunc = Expression.Lambda<Func<char?, string>>(
                    body: yBody,
                    parameters: y.Parameters)
                    .Compile();
                Debug.Assert(yFunc(null) == null);
                Debug.Assert(yFunc('A') == "A");
            }

            void Test4()
            {
                // A unary operator (here: logical not) around a rewritten access must be kept.
                Expression<Func<string, bool>> y = s => !s.EndsWith("x", StringComparison.Ordinal);
                var yBody = visitor.Visit(y.Body);
                var yFunc = Expression.Lambda<Func<string, bool?>>(
                    body: yBody,
                    parameters: y.Parameters)
                    .Compile();
                Debug.Assert(yFunc(null) == null);
                Debug.Assert(yFunc("ab") == true);
                Debug.Assert(yFunc("ax") == false);
            }

            void Test5()
            {
                // Static members and non-nullable struct receivers have nothing to null-check.
                Expression<Func<int?>> empty = () => string.Empty.Length;
                var emptyFunc = ((Expression<Func<int?>>)visitor.Visit(empty)).Compile();
                Debug.Assert(emptyFunc() == 0);

                Expression<Func<DateTime, int>> year = d => d.Year;
                var yearFunc = ((Expression<Func<DateTime, int>>)visitor.Visit(year)).Compile();
                Debug.Assert(yearFunc(new DateTime(2020, 1, 1)) == 2020);
            }
        }

        /// <summary>
        /// Rewrites member accesses and instance method calls so that a null receiver yields
        /// null instead of throwing, similar to the C# <c>?.</c> operator. A guarded access whose
        /// result is a non-nullable value type is lifted to <see cref="Nullable{T}"/>.
        /// </summary>
        public class NullPropagationVisitor : ExpressionVisitor
        {
            private readonly bool _recursive;

            /// <param name="recursive">
            /// When true, the receiver of each access is itself rewritten, so the whole chain is
            /// null-safe. When false, only the immediate receiver of the visited access is checked;
            /// nested receivers are evaluated unchanged.
            /// </param>
            public NullPropagationVisitor(bool recursive)
            {
                _recursive = recursive;
            }

            /// <summary>
            /// Visits the operand. If the rewrite lifted the operand to a nullable type, the original
            /// operator is re-applied to the lifted operand rather than being dropped.
            /// </summary>
            protected override Expression VisitUnary(UnaryExpression node)
            {
                // A conditional under a unary (typically the compiler's Convert-to-T? of a lambda body)
                // gets both branches lifted to a common nullable type. Its test is left as-is.
                var operand = node.Operand is ConditionalExpression cond
                    ? Expression.Condition(
                        test: cond.Test,
                        ifTrue: MakeNullable(Visit(cond.IfTrue)),
                        ifFalse: MakeNullable(Visit(cond.IfFalse)))
                    : Visit(node.Operand);

                if (operand.Type == node.Operand.Type)
                    return node.Update(operand);

                return LiftUnary(node, operand);
            }

            protected override Expression VisitMember(MemberExpression node)
            {
                // Static members have no receiver that could be null (and Common needs one).
                if (node.Expression is null)
                    return base.VisitMember(node);

                return Common(node.Expression, node);
            }

            protected override Expression VisitMethodCall(MethodCallExpression node)
            {
                // Static (including extension) methods have no receiver that could be null.
                if (node.Object is null)
                    return base.VisitMethodCall(node);

                return Common(node.Object, node);
            }

            /// <summary>
            /// Rebuilds <paramref name="node"/> over an operand whose type was lifted from T to T?,
            /// lifting the result type as well so that a null operand yields null.
            /// </summary>
            private static Expression LiftUnary(UnaryExpression node, Expression operand)
            {
                var liftedType = MakeNullableType(node.Type);

                // A plain conversion to T? is already satisfied by an operand lifted to T?.
                bool isConversion = node.NodeType == ExpressionType.Convert
                    || node.NodeType == ExpressionType.ConvertChecked;
                if (isConversion && node.Method is null && operand.Type == liftedType)
                    return operand;

                return Expression.MakeUnary(node.NodeType, operand, liftedType, node.Method);
            }

            /// <summary>
            /// Guards <paramref name="access"/> (a member access or instance call on
            /// <paramref name="instance"/>). The receiver is evaluated once into a local variable,
            /// and the whole access yields null when that variable is null.
            /// </summary>
            private Expression Common(Expression instance, Expression access)
            {
                var safe = _recursive ? Visit(instance) : instance;

                // A non-nullable value type can never be null. Comparing it with a null constant
                // would also make Expression.Equal throw, so just rebuild the access on it.
                if (!IsNullable(safe))
                    return new ExpressionReplacer(instance, safe).Visit(access);

                var caller = Expression.Variable(safe.Type, "caller");
                var assign = Expression.Assign(caller, safe);

                // A receiver that was lifted from T to T? is unwrapped back to T. This only runs
                // on the non-null branch. Receivers that were nullable originally keep their type.
                var replacement = IsNullableStruct(instance) ? caller : RemoveNullable(caller);
                var guardedAccess = MakeNullable(new ExpressionReplacer(instance, replacement).Visit(access));

                var ternary = Expression.Condition(
                    test: Expression.Equal(caller, Expression.Constant(null)),
                    ifTrue: Expression.Constant(null, guardedAccess.Type),
                    ifFalse: guardedAccess);

                return Expression.Block(
                    type: guardedAccess.Type,
                    variables: new[] { caller },
                    expressions: new Expression[] { assign, ternary });
            }

            /// <summary>Converts a non-nullable value-typed expression to T?; other expressions are returned as-is.</summary>
            private static Expression MakeNullable(Expression ex)
            {
                var nullableType = MakeNullableType(ex.Type);
                return nullableType == ex.Type ? ex : Expression.Convert(ex, nullableType);
            }

            /// <summary>Maps a non-nullable value type T to T?; reference and nullable types are returned unchanged.</summary>
            private static Type MakeNullableType(Type type)
            {
                return type.IsValueType && Nullable.GetUnderlyingType(type) is null
                    ? typeof(Nullable<>).MakeGenericType(type)
                    : type;
            }

            /// <summary>True for reference types and <see cref="Nullable{T}"/> types.</summary>
            private static bool IsNullable(Expression ex)
            {
                return !ex.Type.IsValueType || Nullable.GetUnderlyingType(ex.Type) != null;
            }

            /// <summary>True only for <see cref="Nullable{T}"/> types.</summary>
            private static bool IsNullableStruct(Expression ex)
            {
                return ex.Type.IsValueType && Nullable.GetUnderlyingType(ex.Type) != null;
            }

            /// <summary>Unwraps a T? expression to T via Convert; other expressions are returned as-is.</summary>
            private static Expression RemoveNullable(Expression ex)
            {
                var underlying = Nullable.GetUnderlyingType(ex.Type);
                return underlying is null ? ex : Expression.Convert(ex, underlying);
            }

            /// <summary>Replaces every occurrence (by reference) of one expression node with another.</summary>
            private class ExpressionReplacer : ExpressionVisitor
            {
                private readonly Expression _oldEx;
                private readonly Expression _newEx;

                internal ExpressionReplacer(Expression oldEx, Expression newEx)
                {
                    _oldEx = oldEx;
                    _newEx = newEx;
                }

                public override Expression Visit(Expression node)
                {
                    if (node == _oldEx)
                        return _newEx;
                    return base.Visit(node);
                }
            }
        }
    }
}
@Liero thanks for sharing! I will take a look. Maybe it is time to create a full repository/NuGet package for this code and receive PRs/improvements from the community.
@Liero the problem was thoroughly solved and shared in this repository as well as on NuGet.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@leandromoh:
I have hacked something that works in my case, but it's not universal.
I have added `VisitLambda` and `MakeNullableType` methods:
However, I think the problem lies deeper. For example, the original visitor removes Convert expressions, and I have to add them back manually in `VisitLambda`.
and the tests: