Skip to content

Instantly share code, notes, and snippets.

@LPeter1997
Last active October 30, 2022 14:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save LPeter1997/47886848e145fb61eb49c32344c03abf to your computer and use it in GitHub Desktop.
Save LPeter1997/47886848e145fb61eb49c32344c03abf to your computer and use it in GitHub Desktop.
REEEEE
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
/// <summary>
/// Value-equality comparer for <see cref="IAsyncStateMachine"/> instances.
/// Every operation is forwarded to <see cref="AsmComparerCache"/>.
/// </summary>
internal sealed class AsmComparer : IEqualityComparer<IAsyncStateMachine>
{
    /// <summary>
    /// The shared instance of this comparer; the type holds no instance state.
    /// </summary>
    public static AsmComparer Instance { get; } = new AsmComparer();

    // Construction is restricted to the singleton above.
    private AsmComparer() { }

    /// <summary>
    /// Determines whether two state machines are value-equal.
    /// </summary>
    /// <param name="x">The first state machine, possibly null.</param>
    /// <param name="y">The second state machine, possibly null.</param>
    /// <returns>The result of <c>AsmComparerCache.Equals(x, y)</c>.</returns>
    public bool Equals(IAsyncStateMachine? x, IAsyncStateMachine? y)
    {
        return AsmComparerCache.Equals(x, y);
    }

    /// <summary>
    /// Computes a hash code for a state machine.
    /// </summary>
    /// <param name="obj">The state machine to hash; must not be null.</param>
    /// <returns>The result of <c>AsmComparerCache.GetHashCode(obj)</c>.</returns>
    public int GetHashCode([DisallowNull] IAsyncStateMachine obj)
    {
        return AsmComparerCache.GetHashCode(obj);
    }
}
/// <summary>
/// Implements the actual comparisons for async state machines with compiled expression trees.
/// </summary>
/// <remarks>
/// Two state machines are equal when they have the same runtime type and every relevant field
/// (see <see cref="GetRelevantFields"/>) is equal according to
/// <see cref="EqualityComparer{T}.Default"/>. One equality delegate and one hash delegate are
/// compiled per state machine type and cached in static dictionaries.
/// </remarks>
internal sealed class AsmComparerCache
{
    private delegate bool AsmEqualsDelegate(IAsyncStateMachine x, IAsyncStateMachine y);
    private delegate int AsmGetHashCodeDelegate(IAsyncStateMachine obj);

    private static readonly ConcurrentDictionary<Type, AsmEqualsDelegate> equalsInstances = new();
    private static readonly ConcurrentDictionary<Type, AsmGetHashCodeDelegate> hashCodeInstances = new();

    /// <summary>
    /// Compares two state machines for value-equality.
    /// </summary>
    /// <param name="x">The first state machine, possibly null.</param>
    /// <param name="y">The second state machine, possibly null.</param>
    /// <returns>
    /// True if both are the same reference (including both null), or if both have the same
    /// runtime type and all relevant fields are equal; false otherwise.
    /// </returns>
    public static bool Equals(IAsyncStateMachine? x, IAsyncStateMachine? y)
    {
        if (ReferenceEquals(x, y)) return true;
        if (x is null || y is null) return false;

        var t1 = x.GetType();
        var t2 = y.GetType();
        if (t1 != t2) return false;

        var equals = equalsInstances.GetOrAdd(t1, CreateEqualsFunc);
        return equals(x, y);
    }

    /// <summary>
    /// Computes a hash code consistent with <see cref="Equals(IAsyncStateMachine?, IAsyncStateMachine?)"/>.
    /// It mixes in the runtime type and the hash of every relevant field.
    /// </summary>
    /// <param name="obj">The state machine to hash; must not be null.</param>
    /// <returns>The combined hash code.</returns>
    public static int GetHashCode([DisallowNull] IAsyncStateMachine obj)
    {
        var type = obj.GetType();
        var hashCode = hashCodeInstances.GetOrAdd(type, CreateHashCodeFunc);
        return hashCode(obj);
    }

    /// <summary>
    /// Compiles <c>(x, y) => true &amp;&amp; Eq(x.f1, y.f1) &amp;&amp; ...</c> for the relevant fields of
    /// <paramref name="asmType"/>. The delegate must only be called with instances of exactly that type.
    /// </summary>
    private static AsmEqualsDelegate CreateEqualsFunc(Type asmType)
    {
        var param1 = Expression.Parameter(typeof(IAsyncStateMachine), "x");
        var param2 = Expression.Parameter(typeof(IAsyncStateMachine), "y");

        // Downcast each argument once into a typed local. For value-type state machines this
        // unboxes a single copy instead of one copy per field access.
        var left = Expression.Variable(asmType, "left");
        var right = Expression.Variable(asmType, "right");

        // Seeding with `true` makes a state machine with no relevant fields compare equal
        // to any other instance of the same type.
        var comparisonsConjuncted = GetRelevantFields(asmType)
            .Select(f => FieldEquals(Expression.Field(left, f), Expression.Field(right, f)))
            .Prepend(Expression.Constant(true))
            .Aggregate(Expression.AndAlso);

        var body = Expression.Block(
            new[] { left, right },
            Expression.Assign(left, Downcast(param1, asmType)),
            Expression.Assign(right, Downcast(param2, asmType)),
            comparisonsConjuncted);

        return Expression.Lambda<AsmEqualsDelegate>(body, param1, param2).Compile();
    }

    /// <summary>
    /// Compiles a hash function over the type and relevant fields of <paramref name="asmType"/>.
    /// The delegate must only be called with instances of exactly that type.
    /// </summary>
    private static AsmGetHashCodeDelegate CreateHashCodeFunc(Type asmType)
    {
        var param = Expression.Parameter(typeof(IAsyncStateMachine), "obj");
        var typed = Expression.Variable(asmType, "typed");

        // The delegate is cached per runtime type, so the type can be a constant instead of
        // a GetType() call at runtime.
        Expression hash = Expression.Call(
            type: typeof(HashCode),
            methodName: nameof(HashCode.Combine),
            typeArguments: new[] { typeof(Type) },
            arguments: Expression.Constant(asmType, typeof(Type)));

        // HashCode.Combine only has overloads for up to 8 values, so fold the fields in one at
        // a time (hash = Combine(hash, field)) to support any number of fields.
        foreach (var field in GetRelevantFields(asmType))
        {
            hash = Expression.Call(
                type: typeof(HashCode),
                methodName: nameof(HashCode.Combine),
                typeArguments: new[] { typeof(int), field.FieldType },
                arguments: new Expression[] { hash, Expression.Field(typed, field) });
        }

        var body = Expression.Block(
            new[] { typed },
            Expression.Assign(typed, Downcast(param, asmType)),
            hash);

        return Expression.Lambda<AsmGetHashCodeDelegate>(body, param).Compile();
    }

    /// <summary>
    /// Builds <c>EqualityComparer&lt;T&gt;.Default.Equals(left, right)</c>, where T is the field type.
    /// Unlike <see cref="Expression.Equal(Expression, Expression)"/>, this works for structs
    /// without an equality operator. It also agrees with the <c>GetHashCode()</c>-based hashing
    /// done by <see cref="HashCode.Combine{T1, T2}(T1, T2)"/>.
    /// </summary>
    private static Expression FieldEquals(Expression left, Expression right)
    {
        var comparerType = typeof(EqualityComparer<>).MakeGenericType(left.Type);
        var defaultComparer = Expression.Property(null, comparerType, nameof(EqualityComparer<object>.Default));
        var equalsMethod = comparerType.GetMethod(
            nameof(EqualityComparer<object>.Equals),
            new[] { left.Type, left.Type })!;
        return Expression.Call(defaultComparer, equalsMethod, left, right);
    }

    /// <summary>
    /// Reinterprets an <see cref="IAsyncStateMachine"/>-typed expression as <paramref name="asmType"/>.
    /// <c>Unsafe.As&lt;T&gt;(object)</c> skips the type check but is constrained to reference
    /// types. The compiler may emit value-type state machines, and those must be unboxed with a
    /// real conversion.
    /// </summary>
    private static Expression Downcast(Expression obj, Type asmType) => asmType.IsValueType
        ? (Expression)Expression.Convert(obj, asmType)
        : Expression.Call(
            type: typeof(Unsafe),
            methodName: nameof(Unsafe.As),
            typeArguments: new[] { asmType },
            arguments: obj);

    /// <summary>
    /// Returns the public instance fields of <paramref name="asmType"/> whose names contain no '&lt;'.
    /// NOTE: this relies on the C# compiler's naming scheme. Bookkeeping fields (state, builder,
    /// awaiters, hoisted locals) use unspeakable names containing '&lt;', so what remains are
    /// the async method's captured parameters.
    /// </summary>
    private static IEnumerable<FieldInfo> GetRelevantFields(Type asmType) => asmType
        .GetFields(BindingFlags.Public | BindingFlags.Instance)
        .Where(f => !f.Name.Contains('<'));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment