Skip to content

Instantly share code, notes, and snippets.

@leandromoh
Last active December 10, 2023 02:38
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leandromoh/92bec4eddbacb1b20f4a67cd5ac6a081 to your computer and use it in GitHub Desktop.
Save leandromoh/92bec4eddbacb1b20f4a67cd5ac6a081 to your computer and use it in GitHub Desktop.
visitor that modify expression tree for safe null propagation! This code is now maintained in https://github.com/leandromoh/NullPropagationVisitor
using System;
using System.Diagnostics;
using System.Linq.Expressions;
namespace ConsoleApp3
{
static class Program
{
private static string Foo(string s) => s;
static void Main(string[] _)
{
var visitor = new NullPropagationVisitor(recursive: true);
Test1();
Test2();
Test3();
void Test1()
{
Expression<Func<string, char?>> f = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];
var fBody = (Expression<Func<string, char?>>)visitor.Visit(f);
var fFunc = fBody.Compile();
Debug.Assert(fFunc(null) == null);
Debug.Assert(fFunc("bar") == '3');
Debug.Assert(fFunc("foo") == 'X');
}
void Test2()
{
Expression<Func<string, int>> y = s => s.Length;
var yBody = visitor.Visit(y.Body);
var yFunc = Expression.Lambda<Func<string, int?>>(
body: yBody,
parameters: y.Parameters)
.Compile();
Debug.Assert(yFunc(null) == null);
Debug.Assert(yFunc("bar") == 3);
}
void Test3()
{
Expression<Func<char?, string>> y = s => s.Value.ToString()[0].ToString();
var yBody = visitor.Visit(y.Body);
var yFunc = Expression.Lambda<Func<char?, string>>(
body: yBody,
parameters: y.Parameters)
.Compile();
Debug.Assert(yFunc(null) == null);
Debug.Assert(yFunc('A') == "A");
}
}
public class NullPropagationVisitor : ExpressionVisitor
{
private readonly bool _recursive;
public NullPropagationVisitor(bool recursive)
{
_recursive = recursive;
}
protected override Expression VisitUnary(UnaryExpression propertyAccess)
{
if (propertyAccess.Operand is MemberExpression mem)
return VisitMember(mem);
if (propertyAccess.Operand is MethodCallExpression met)
return VisitMethodCall(met);
if (propertyAccess.Operand is ConditionalExpression cond)
return Expression.Condition(
test: cond.Test,
ifTrue: MakeNullable(Visit(cond.IfTrue)),
ifFalse: MakeNullable(Visit(cond.IfFalse)));
return base.VisitUnary(propertyAccess);
}
protected override Expression VisitMember(MemberExpression propertyAccess)
{
return Common(propertyAccess.Expression, propertyAccess);
}
protected override Expression VisitMethodCall(MethodCallExpression propertyAccess)
{
if (propertyAccess.Object == null)
return base.VisitMethodCall(propertyAccess);
return Common(propertyAccess.Object, propertyAccess);
}
private BlockExpression Common(Expression instance, Expression propertyAccess)
{
var safe = _recursive ? base.Visit(instance) : instance;
var caller = Expression.Variable(safe.Type, "caller");
var assign = Expression.Assign(caller, safe);
var acess = MakeNullable(new ExpressionReplacer(instance,
IsNullableStruct(instance) ? caller : RemoveNullable(caller)).Visit(propertyAccess));
var ternary = Expression.Condition(
test: Expression.Equal(caller, Expression.Constant(null)),
ifTrue: Expression.Constant(null, acess.Type),
ifFalse: acess);
return Expression.Block(
type: acess.Type,
variables: new[]
{
caller,
},
expressions: new Expression[]
{
assign,
ternary,
});
}
private static Expression MakeNullable(Expression ex)
{
if (IsNullable(ex))
return ex;
return Expression.Convert(ex, typeof(Nullable<>).MakeGenericType(ex.Type));
}
private static bool IsNullable(Expression ex)
{
return !ex.Type.IsValueType || (Nullable.GetUnderlyingType(ex.Type) != null);
}
private static bool IsNullableStruct(Expression ex)
{
return ex.Type.IsValueType && (Nullable.GetUnderlyingType(ex.Type) != null);
}
private static Expression RemoveNullable(Expression ex)
{
if (IsNullableStruct(ex))
return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);
return ex;
}
private class ExpressionReplacer : ExpressionVisitor
{
private readonly Expression _oldEx;
private readonly Expression _newEx;
internal ExpressionReplacer(Expression oldEx, Expression newEx)
{
_oldEx = oldEx;
_newEx = newEx;
}
public override Expression Visit(Expression node)
{
if (node == _oldEx)
return _newEx;
return base.Visit(node);
}
}
}
}
}
@Liero
Copy link

Liero commented May 4, 2021

This throws exception, when the original LambdaExpression has non nullable return type. E.g.

Expression<Func<string, double>> lambda = str => str.Length;

Error: System.ArgumentException: Expression of type 'System.Nullable`1[System.Int32]' cannot be used for return type 'System.Int32'

EDIT:

Resolved using

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        return Expression.Lambda(Visit(node.Body), node.Parameters.ToArray());
    }

@leandromoh
Copy link
Author

leandromoh commented May 5, 2021

Just for note (not related to previous comment), as well the ?. operator, the visitor evaluates the left-hand operand only once. So

Expression<Func<string, char?>> f = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];

becomes something similar to

image

@leandromoh
Copy link
Author

@Liero can you share the code version that worked for you, please?

@Liero
Copy link

Liero commented May 6, 2021

@leandromoh:

I have hacked something that works in my case, but it's not universal.

I have added VisitLambda and MakeNullableType methods:

However, I think the problem lies deeper. For example the original visitor is removing Convert expressions and I have to add those manually VisitLambda.

public class NullPropagationVisitor : ExpressionVisitor
{
    private readonly bool _recursive;

    public NullPropagationVisitor(bool recursive)
    {
        _recursive = recursive;
    }

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        var body = Visit(node.Body);
        var expectedReturnType = MakeNullableType(node.ReturnType);
        if (body.Type != expectedReturnType)
        {
            body = Expression.Convert(body, expectedReturnType);
        }
        return Expression.Lambda(body, node.Parameters);
    }

    protected override Expression VisitUnary(UnaryExpression propertyAccess)
    {
        if (propertyAccess.Operand is MemberExpression mem)
            return VisitMember(mem);

        if (propertyAccess.Operand is MethodCallExpression met)
            return VisitMethodCall(met);

        if (propertyAccess.Operand is ConditionalExpression cond)
            return Expression.Condition(
                    test: cond.Test,
                    ifTrue: MakeNullable(Visit(cond.IfTrue)),
                    ifFalse: MakeNullable(Visit(cond.IfFalse)));

        return base.VisitUnary(propertyAccess);
    }

    protected override Expression VisitMember(MemberExpression propertyAccess)
    {
        if (propertyAccess.Expression == null)
        {
            throw new ArgumentException($"The parameter member {propertyAccess}.{propertyAccess.Expression} cannot be null", nameof(propertyAccess))
        }
        return Common(propertyAccess.Expression, propertyAccess);
    }

    protected override Expression VisitMethodCall(MethodCallExpression propertyAccess)
    {
        if (propertyAccess.Object == null)
            return base.VisitMethodCall(propertyAccess);

        return Common(propertyAccess.Object, propertyAccess);
    }

    private Expression Common(Expression instance, Expression propertyAccess)
    {
        
        var safe = _recursive ? base.Visit(instance) : instance;
        var caller = Expression.Variable(safe.Type, "caller");
        var assign = Expression.Assign(caller, safe);
        var acess = MakeNullable(new ExpressionReplacer(instance,
            IsNullableStruct(instance) ? caller : RemoveNullable(caller)).Visit(propertyAccess)!);
        var ternary = Expression.Condition(
                    test: Expression.Equal(caller, Expression.Constant(null)),
                    ifTrue: Expression.Constant(null, acess.Type),
                    ifFalse: acess);

        return Expression.Block(
            type: acess.Type,
            variables: new[]
            {
            caller,
            },
            expressions: new Expression[]
            {
            assign,
            ternary,
            });
    }

    public static Type MakeNullableType(Type type)
    {
        return type.IsValueType && Nullable.GetUnderlyingType(type) == null
            ? typeof(Nullable<>).MakeGenericType(type)
            : type;
    }

    public static Expression MakeNullable(Expression ex)
    {
        if (IsNullable(ex))
            return ex;

        return Expression.Convert(ex, typeof(Nullable<>).MakeGenericType(ex.Type));
    }

    private static bool IsNullable(Expression ex)
    {
        return !ex.Type.IsValueType || Nullable.GetUnderlyingType(ex.Type) != null;
    }

    private static bool IsNullableStruct(Expression ex)
    {
        return ex.Type.IsValueType && Nullable.GetUnderlyingType(ex.Type) != null;
    }

    private static Expression RemoveNullable(Expression ex)
    {
        if (IsNullableStruct(ex))
            return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);

        return ex;
    }

    private class ExpressionReplacer : ExpressionVisitor
    {
        private readonly Expression _oldEx;
        private readonly Expression _newEx;

        internal ExpressionReplacer(Expression oldEx, Expression newEx)
        {
            _oldEx = oldEx;
            _newEx = newEx;
        }

        public override Expression? Visit(Expression? node)
        {
            if (node == _oldEx)
                return _newEx;

            return base.Visit(node);
        }
    }
}

and the tests:

#nullable enable
#pragma warning disable CS8602 // Dereference of a possibly null reference.

[TestClass]
public class NullPropagationVisitorTests
{
    record Foo(string? Name, Foo? Child = null);

    [TestMethod]
    public void NullPropagationVisitor_AddsNullChecks()
    {

        Expression<Func<Foo, char>> expression = foo => foo.Child.Name.ToString()[0];


        //sut
        var visitor = new NullPropagationVisitor(true);

        //action
        var expressionWithNullChecks = (Expression<Func<Foo, char?>>)visitor.Visit(expression);
        var funcGetChildFirstChar = expressionWithNullChecks.Compile();

        Assert.IsNull(funcGetChildFirstChar(new Foo("parent")));
        Assert.IsNull(funcGetChildFirstChar(new Foo("parent", new Foo(null))));
        Assert.AreEqual('c', funcGetChildFirstChar(new Foo("parent", new Foo("child"))));
    }


    [TestMethod]
    public void NullPropagationVisitor_LambdasHasCorrectSignature()
    {
        Expression<Func<Foo, object>> expression = foo => foo.Child.Name.ToString()[0];
        Type expectedType = typeof(Func<Foo, object?>);
        
        //sut
        var visitor = new NullPropagationVisitor(true);

        //action
        LambdaExpression expressionWithNullChecks = (LambdaExpression)visitor.Visit(expression);

        Assert.AreEqual(expectedType, expressionWithNullChecks.Type);
    }
}

@leandromoh
Copy link
Author

@Liero thanks for sharing! I will take a look. Maybe it is time to create a full repository/nuget for this code and receive PR/improvements from community.

@leandromoh
Copy link
Author

@Liero the problem was deeply solved and shared in this repository as well in nuget

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment