Skip to content

Instantly share code, notes, and snippets.

@leandromoh
Last active December 10, 2023 02:38
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leandromoh/92bec4eddbacb1b20f4a67cd5ac6a081 to your computer and use it in GitHub Desktop.
Save leandromoh/92bec4eddbacb1b20f4a67cd5ac6a081 to your computer and use it in GitHub Desktop.
A visitor that modifies an expression tree to add safe null propagation. This code is now maintained at https://github.com/leandromoh/NullPropagationVisitor
using System;
using System.Diagnostics;
using System.Linq.Expressions;
namespace ConsoleApp3
{
static class Program
{
/// <summary>
/// Identity helper: hands back the string it was given. The demo in this class
/// uses it to put a static method call at the start of a member-access chain.
/// </summary>
/// <param name="s">Any string, including <c>null</c>.</param>
/// <returns>The same reference that was passed in.</returns>
private static string Foo(string s)
{
    return s;
}
/// <summary>
/// Demo entry point: rewrites three sample expression trees with a recursive
/// <c>NullPropagationVisitor</c>, compiles them, and checks the results with
/// <see cref="Debug.Assert(bool)"/>.
/// </summary>
static void Main(string[] _)
{
    var visitor = new NullPropagationVisitor(recursive: true);

    CheckConditionalChain();
    CheckReferenceReceiver();
    CheckNullableStructReceiver();

    // Visits the whole lambda; its declared return type is already nullable.
    void CheckConditionalChain()
    {
        Expression<Func<string, char?>> source = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];

        var rewritten = (Expression<Func<string, char?>>)visitor.Visit(source);
        var compiled = rewritten.Compile();

        Debug.Assert(compiled(null) == null);
        Debug.Assert(compiled("bar") == '3');
        Debug.Assert(compiled("foo") == 'X');
    }

    // Visits only the body, then wraps it in a lambda whose return type is int?
    // because the rewritten body is nullable.
    void CheckReferenceReceiver()
    {
        Expression<Func<string, int>> source = s => s.Length;

        var rewrittenBody = visitor.Visit(source.Body);
        var rewritten = Expression.Lambda<Func<string, int?>>(rewrittenBody, source.Parameters);
        var compiled = rewritten.Compile();

        Debug.Assert(compiled(null) == null);
        Debug.Assert(compiled("bar") == 3);
    }

    // The chain starts from a nullable struct parameter.
    void CheckNullableStructReceiver()
    {
        Expression<Func<char?, string>> source = s => s.Value.ToString()[0].ToString();

        var rewrittenBody = visitor.Visit(source.Body);
        var rewritten = Expression.Lambda<Func<char?, string>>(rewrittenBody, source.Parameters);
        var compiled = rewritten.Compile();

        Debug.Assert(compiled(null) == null);
        Debug.Assert(compiled('A') == "A");
    }
}
/// <summary>
/// Rewrites member accesses and instance method calls so that they evaluate to
/// <c>null</c> when their receiver is <c>null</c>, instead of throwing.
/// </summary>
/// <remarks>
/// <para>
/// A rewritten access to a non-nullable value-type member has the corresponding
/// <see cref="Nullable{T}"/> type, so the enclosing lambda usually needs a nullable
/// return type.
/// </para>
/// <para>
/// Only receivers are guarded. Arguments of instance method calls and the test of a
/// conditional found under a unary node are left exactly as they were.
/// </para>
/// </remarks>
public class NullPropagationVisitor : ExpressionVisitor
{
    private readonly bool _recursive;

    /// <summary>Creates the visitor.</summary>
    /// <param name="recursive">
    /// When <c>true</c>, the receiver of every access is itself rewritten, so a whole
    /// chain such as <c>a.B.C()</c> is guarded. When <c>false</c>, only the outermost
    /// access is guarded.
    /// </param>
    public NullPropagationVisitor(bool recursive)
    {
        _recursive = recursive;
    }

    /// <summary>
    /// Rewrites a unary node whose operand is a member access, a method call or a
    /// conditional; every other unary node goes through the default visitor.
    /// </summary>
    /// <param name="propertyAccess">The unary node being visited.</param>
    /// <returns>The null-safe replacement expression.</returns>
    protected override Expression VisitUnary(UnaryExpression propertyAccess)
    {
        var operand = VisitGuardedOperand(propertyAccess.Operand);
        if (operand == null)
        {
            return base.VisitUnary(propertyAccess);
        }

        var isConversion = propertyAccess.NodeType == ExpressionType.Convert
            || propertyAccess.NodeType == ExpressionType.ConvertChecked;
        if (!isConversion)
        {
            // Operators such as !, - or "as" must survive the rewrite. The node is
            // rebuilt by hand because the operand's type may have become nullable.
            return Expression.MakeUnary(
                propertyAccess.NodeType, operand, propertyAccess.Type, propertyAccess.Method);
        }

        // A guarded operand may be null, so a value-type target is lifted to Nullable<T>.
        var target = IsNullable(operand)
            ? LiftToNullable(propertyAccess.Type)
            : propertyAccess.Type;

        if (IsRedundantConversion(operand.Type, target))
        {
            // Typically the compiler-inserted T -> T? lift, which MakeNullable already did.
            return operand;
        }

        // Any other conversion (boxing, numeric, downcast, user-defined) is re-applied
        // so the resulting expression keeps the type the surrounding tree expects.
        return Expression.MakeUnary(propertyAccess.NodeType, operand, target, propertyAccess.Method);
    }

    /// <summary>Guards an instance field or property access against a null receiver.</summary>
    /// <param name="propertyAccess">The member access being visited.</param>
    /// <returns>
    /// The guarded access. Static members have no receiver and go through the default visitor.
    /// </returns>
    protected override Expression VisitMember(MemberExpression propertyAccess)
    {
        // Static members have a null Expression; there is no receiver to test.
        if (propertyAccess.Expression == null)
        {
            return base.VisitMember(propertyAccess);
        }

        return Common(propertyAccess.Expression, propertyAccess);
    }

    /// <summary>Guards an instance method call against a null receiver.</summary>
    /// <param name="propertyAccess">The method call being visited.</param>
    /// <returns>
    /// The guarded call. Static calls have no receiver and go through the default visitor.
    /// </returns>
    protected override Expression VisitMethodCall(MethodCallExpression propertyAccess)
    {
        if (propertyAccess.Object == null)
        {
            return base.VisitMethodCall(propertyAccess);
        }

        return Common(propertyAccess.Object, propertyAccess);
    }

    /// <summary>
    /// Applies the operand special-casing used by <see cref="VisitUnary"/>.
    /// </summary>
    /// <returns>
    /// The rewritten operand, or <c>null</c> when the operand is not a member access,
    /// method call or conditional.
    /// </returns>
    private Expression VisitGuardedOperand(Expression operand)
    {
        if (operand is MemberExpression mem)
        {
            return VisitMember(mem);
        }

        if (operand is MethodCallExpression met)
        {
            return VisitMethodCall(met);
        }

        if (operand is ConditionalExpression cond)
        {
            // Both branches are made nullable so they still agree on one type.
            return Expression.Condition(
                test: cond.Test,
                ifTrue: MakeNullable(Visit(cond.IfTrue)),
                ifFalse: MakeNullable(Visit(cond.IfFalse)));
        }

        return null;
    }

    /// <summary>
    /// Builds <c>{ caller = instance; caller == null ? null : access(caller) }</c>.
    /// </summary>
    /// <param name="instance">The receiver node inside <paramref name="propertyAccess"/>.</param>
    /// <param name="propertyAccess">The member access or method call to guard.</param>
    /// <returns>The guarded expression; nullable unless the access returns void.</returns>
    private Expression Common(Expression instance, Expression propertyAccess)
    {
        var safe = _recursive ? base.Visit(instance) : instance;

        if (safe.Type.IsValueType && !IsNullableStruct(safe))
        {
            // A plain struct receiver can never be null, and Expression.Equal rejects
            // comparing it with a null constant, so no guard is emitted.
            var direct = new ExpressionReplacer(instance, safe).Visit(propertyAccess);
            return MakeNullable(direct);
        }

        // The receiver is stored in a variable so it is evaluated only once.
        var caller = Expression.Variable(safe.Type, "caller");
        var assign = Expression.Assign(caller, safe);

        // If recursion turned a struct receiver into Nullable<T>, unwrap it again
        // before handing it to the original member or method.
        var receiver = IsNullableStruct(instance) ? caller : RemoveNullable(caller);
        var access = MakeNullable(new ExpressionReplacer(instance, receiver).Visit(propertyAccess));

        var isNull = Expression.Equal(caller, Expression.Constant(null));

        // A void call has no value to replace with null; it is simply skipped.
        Expression guarded = access.Type == typeof(void)
            ? Expression.Condition(isNull, Expression.Empty(), access, typeof(void))
            : Expression.Condition(isNull, Expression.Constant(null, access.Type), access);

        return Expression.Block(access.Type, new[] { caller }, assign, guarded);
    }

    /// <summary>
    /// Converts a non-nullable value-typed expression to <see cref="Nullable{T}"/>;
    /// reference-typed, already-nullable and void expressions are returned unchanged.
    /// </summary>
    private static Expression MakeNullable(Expression ex)
    {
        // typeof(void) reports IsValueType == true, but Nullable<void> cannot be built.
        if (ex.Type == typeof(void) || IsNullable(ex))
        {
            return ex;
        }

        return Expression.Convert(ex, typeof(Nullable<>).MakeGenericType(ex.Type));
    }

    /// <summary>True when the expression's type is a reference type or <see cref="Nullable{T}"/>.</summary>
    private static bool IsNullable(Expression ex)
    {
        return !ex.Type.IsValueType || (Nullable.GetUnderlyingType(ex.Type) != null);
    }

    /// <summary>True when the expression's type is <see cref="Nullable{T}"/>.</summary>
    private static bool IsNullableStruct(Expression ex)
    {
        return ex.Type.IsValueType && (Nullable.GetUnderlyingType(ex.Type) != null);
    }

    /// <summary>Unwraps a <see cref="Nullable{T}"/> expression to <c>T</c>; others are returned unchanged.</summary>
    private static Expression RemoveNullable(Expression ex)
    {
        if (IsNullableStruct(ex))
        {
            return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);
        }

        return ex;
    }

    /// <summary>Maps a non-nullable value type <c>T</c> to <c>T?</c>; other types are returned unchanged.</summary>
    private static Type LiftToNullable(Type type)
    {
        var canLift = type.IsValueType
            && type != typeof(void)
            && Nullable.GetUnderlyingType(type) == null;

        return canLift ? typeof(Nullable<>).MakeGenericType(type) : type;
    }

    /// <summary>
    /// True when converting <paramref name="from"/> to <paramref name="to"/> adds nothing:
    /// the types are identical, or it is a reference upcast.
    /// </summary>
    private static bool IsRedundantConversion(Type from, Type to)
    {
        if (from == to)
        {
            return true;
        }

        return !from.IsValueType && !to.IsValueType && to.IsAssignableFrom(from);
    }

    /// <summary>Replaces every occurrence of one specific node (compared by reference) with another.</summary>
    private class ExpressionReplacer : ExpressionVisitor
    {
        private readonly Expression _oldEx;
        private readonly Expression _newEx;

        internal ExpressionReplacer(Expression oldEx, Expression newEx)
        {
            _oldEx = oldEx;
            _newEx = newEx;
        }

        public override Expression Visit(Expression node)
        {
            if (node == _oldEx)
            {
                return _newEx;
            }

            return base.Visit(node);
        }
    }
}
}
}
@leandromoh
Copy link
Author

@Liero thanks for sharing! I will take a look. Maybe it is time to create a full repository/nuget for this code and receive PR/improvements from community.

@leandromoh
Copy link
Author

@Liero the problem has been thoroughly solved, and the fix is available in this repository as well as on NuGet

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment