Last active
October 30, 2022 14:33
-
-
Save LPeter1997/47886848e145fb61eb49c32344c03abf to your computer and use it in GitHub Desktop.
REEEEE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Concurrent; | |
using System.Collections.Generic; | |
using System.Diagnostics.CodeAnalysis; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Reflection; | |
using System.Runtime.CompilerServices; | |
/// <summary>
/// An <see cref="IEqualityComparer{T}"/> over <see cref="IAsyncStateMachine"/> instances
/// that forwards both operations to <see cref="AsmComparerCache"/>.
/// </summary>
internal sealed class AsmComparer : IEqualityComparer<IAsyncStateMachine>
{
    /// <summary>
    /// The single shared instance; the constructor is private, so this is the only way to obtain one.
    /// </summary>
    public static AsmComparer Instance { get; } = new AsmComparer();

    private AsmComparer() { }

    /// <summary>
    /// Determines whether the two state machines are equal according to <see cref="AsmComparerCache"/>.
    /// </summary>
    /// <param name="x">The first state machine, possibly <c>null</c>.</param>
    /// <param name="y">The second state machine, possibly <c>null</c>.</param>
    /// <returns>The result of <see cref="AsmComparerCache.Equals(IAsyncStateMachine?, IAsyncStateMachine?)"/>.</returns>
    public bool Equals(IAsyncStateMachine? x, IAsyncStateMachine? y)
    {
        return AsmComparerCache.Equals(x, y);
    }

    /// <summary>
    /// Computes a hash code for the state machine according to <see cref="AsmComparerCache"/>.
    /// </summary>
    /// <param name="obj">The state machine to hash.</param>
    /// <returns>The result of <see cref="AsmComparerCache.GetHashCode(IAsyncStateMachine)"/>.</returns>
    public int GetHashCode([DisallowNull] IAsyncStateMachine obj)
    {
        return AsmComparerCache.GetHashCode(obj);
    }
}
/// <summary>
/// Implements the actual comparisons for async state machines with compiled expression trees.
/// </summary>
/// <remarks>
/// One equality delegate and one hash delegate are built per concrete state machine type and
/// cached. Only the fields returned by <see cref="GetRelevantFields"/> take part in either operation.
/// </remarks>
internal sealed class AsmComparerCache
{
    private delegate bool AsmEqualsDelegate(IAsyncStateMachine x, IAsyncStateMachine y);

    private delegate int AsmGetHashCodeDelegate(IAsyncStateMachine obj);

    private static readonly ConcurrentDictionary<Type, AsmEqualsDelegate> equalsInstances = new();
    private static readonly ConcurrentDictionary<Type, AsmGetHashCodeDelegate> hashCodeInstances = new();

    /// <summary>
    /// Compares two state machines field by field.
    /// </summary>
    /// <param name="x">The first state machine, possibly <c>null</c>.</param>
    /// <param name="y">The second state machine, possibly <c>null</c>.</param>
    /// <returns>
    /// <c>true</c> when both references are the same (including both <c>null</c>), or when both
    /// objects have the same runtime type and all relevant fields are equal; otherwise <c>false</c>.
    /// </returns>
    public static bool Equals(IAsyncStateMachine? x, IAsyncStateMachine? y)
    {
        if (ReferenceEquals(x, y))
        {
            return true;
        }

        if (x is null || y is null)
        {
            return false;
        }

        var type = x.GetType();
        if (type != y.GetType())
        {
            return false;
        }

        return equalsInstances.GetOrAdd(type, CreateEqualsFunc)(x, y);
    }

    /// <summary>
    /// Computes a hash code from the runtime type and the relevant fields of a state machine.
    /// </summary>
    /// <param name="obj">The state machine to hash.</param>
    /// <returns>A hash code that is equal for state machines that <see cref="Equals(IAsyncStateMachine?, IAsyncStateMachine?)"/> treats as equal.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="obj"/> is <c>null</c>.</exception>
    public static int GetHashCode([DisallowNull] IAsyncStateMachine obj)
    {
        if (obj is null)
        {
            throw new ArgumentNullException(nameof(obj));
        }

        return hashCodeInstances.GetOrAdd(obj.GetType(), CreateHashCodeFunc)(obj);
    }

    /// <summary>
    /// Builds the equality delegate for one concrete state machine type.
    /// </summary>
    private static AsmEqualsDelegate CreateEqualsFunc(Type asmType)
    {
        var x = Expression.Parameter(typeof(IAsyncStateMachine), "x");
        var y = Expression.Parameter(typeof(IAsyncStateMachine), "y");
        var typedX = CastToStateMachine(x, asmType);
        var typedY = CastToStateMachine(y, asmType);

        // Start from 'true' so a type without relevant fields compares equal to itself.
        Expression body = Expression.Constant(true);
        foreach (var field in GetRelevantFields(asmType))
        {
            body = Expression.AndAlso(
                body,
                FieldEquals(Expression.Field(typedX, field), Expression.Field(typedY, field), field.FieldType));
        }

        return Expression.Lambda<AsmEqualsDelegate>(body, x, y).Compile();
    }

    /// <summary>
    /// Builds the hash delegate for one concrete state machine type.
    /// </summary>
    private static AsmGetHashCodeDelegate CreateHashCodeFunc(Type asmType)
    {
        var obj = Expression.Parameter(typeof(IAsyncStateMachine), "obj");
        var typedObj = CastToStateMachine(obj, asmType);

        // The delegate is cached per runtime type, so the type can be embedded as a constant.
        Expression hash = Expression.Call(
            typeof(HashCode),
            nameof(HashCode.Combine),
            new[] { typeof(Type) },
            Expression.Constant(asmType, typeof(Type)));

        // HashCode.Combine only has overloads for up to eight values, so the fields are folded
        // in one at a time instead of being passed to a single call.
        foreach (var field in GetRelevantFields(asmType))
        {
            hash = Expression.Call(
                typeof(HashCode),
                nameof(HashCode.Combine),
                new[] { typeof(int), field.FieldType },
                hash,
                Expression.Field(typedObj, field));
        }

        return Expression.Lambda<AsmGetHashCodeDelegate>(hash, obj).Compile();
    }

    /// <summary>
    /// Produces an expression that views the interface-typed parameter as the concrete state machine type.
    /// </summary>
    private static Expression CastToStateMachine(ParameterExpression parameter, Type asmType)
    {
        if (asmType.IsValueType)
        {
            // Unsafe.As<T>(object) is constrained to reference types; a boxed struct has to be unboxed.
            return Expression.Unbox(parameter, asmType);
        }

        // Callers have already established the exact runtime type, so the cast needs no check.
        return Expression.Call(typeof(Unsafe), nameof(Unsafe.As), new[] { asmType }, parameter);
    }

    /// <summary>
    /// Produces <c>EqualityComparer&lt;T&gt;.Default.Equals(left, right)</c> for the given field type.
    /// </summary>
    /// <remarks>
    /// Unlike the <c>==</c> operator this is defined for every field type, and it agrees with the
    /// per-field hashing done through <see cref="HashCode"/>.
    /// </remarks>
    private static Expression FieldEquals(Expression left, Expression right, Type fieldType)
    {
        var comparerType = typeof(EqualityComparer<>).MakeGenericType(fieldType);
        var defaultProperty = comparerType.GetProperty(nameof(EqualityComparer<object>.Default))!;
        var equalsMethod = comparerType.GetMethod(
            nameof(EqualityComparer<object>.Equals),
            BindingFlags.Public | BindingFlags.Instance,
            binder: null,
            types: new[] { fieldType, fieldType },
            modifiers: null)!;

        return Expression.Call(Expression.Property(null, defaultProperty), equalsMethod, left, right);
    }

    /// <summary>
    /// Selects the public fields whose names do not contain <c>'&lt;'</c>.
    /// </summary>
    /// <remarks>
    /// Presumably this skips compiler-generated bookkeeping fields, whose names contain angle
    /// brackets, and keeps the fields that mirror the async method's parameters.
    /// </remarks>
    private static IEnumerable<FieldInfo> GetRelevantFields(Type asmType) => asmType
        .GetFields()
        .Where(f => !f.Name.Contains('<'));
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment