Skip to content

Instantly share code, notes, and snippets.

@bboyle1234
Last active March 23, 2016 05:35
Show Gist options
  • Save bboyle1234/0cadf132bbc9f00db4ff to your computer and use it in GitHub Desktop.
Save bboyle1234/0cadf132bbc9f00db4ff to your computer and use it in GitHub Desktop.
MemberwiseEqualityObject
using System;
using System.Linq;
using System.Collections.Generic;
namespace ApexInvesting.Platform {
/// <summary>
/// General-purpose extension methods for <see cref="IEnumerable{T}"/>.
/// </summary>
public static class EnumerableExtensions {
    /// <summary>
    /// Copies the elements of a sequence into a new <see cref="HashSet{T}"/>,
    /// using the default equality comparer for <typeparamref name="T"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): newer frameworks (.NET Framework 4.7.2+, .NET Core 2.0+) ship
    /// <c>System.Linq.Enumerable.ToHashSet</c>; extension-call sites that import both
    /// namespaces will be ambiguous.
    /// </remarks>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="value">Source sequence.</param>
    /// <returns>A new set containing the distinct elements of <paramref name="value"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
    public static HashSet<T> ToHashSet<T>(this IEnumerable<T> value) {
        if (null == value) {
            throw new ArgumentNullException("value");
        }
        return new HashSet<T>(value);
    }

    /// <summary>
    /// Invokes <paramref name="action"/> on each item of <paramref name="target"/>, in enumeration order.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="target">Sequence to enumerate.</param>
    /// <param name="action">Action applied to every item.</param>
    /// <exception cref="ArgumentNullException"><paramref name="target"/> or <paramref name="action"/> is null.</exception>
    public static void Each<T>(this IEnumerable<T> target, Action<T> action) {
        if (null == target) {
            throw new ArgumentNullException("target");
        }
        if (null == action) {
            throw new ArgumentNullException("action");
        }
        foreach (var item in target) {
            action(item);
        }
    }

    /// <summary>
    /// Null-tolerant <see cref="Enumerable.SequenceEqual{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>.
    /// Two null sequences are equal, and a null and a non-null sequence are not.
    /// Otherwise the elements are compared pairwise, in order.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="target">First sequence (may be null).</param>
    /// <param name="other">Second sequence (may be null).</param>
    /// <returns>True when both are null or both contain equal elements in the same order.</returns>
    public static bool SequenceEqualWithNullChecking<T>(this IEnumerable<T> target, IEnumerable<T> other) {
        if (null == target)
            return null == other;
        if (null == other)
            return false;
        return target.SequenceEqual(other);
    }

    /// <summary>
    /// Computes an order-sensitive hash code over the elements of a sequence.
    /// It is designed to pair with <see cref="SequenceEqualWithNullChecking{T}"/>: sequences
    /// that compare equal hash equal, provided <typeparamref name="T"/>'s Equals and
    /// GetHashCode are consistent with each other.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="target">Sequence to hash (may be null).</param>
    /// <returns>
    /// 0 for a null sequence. Otherwise, a hash seeded with 17 and folded as
    /// <c>hash * 31 + elementHash</c>, where null elements contribute 0.
    /// </returns>
    public static int GetEnumeratedHashCode<T>(this IEnumerable<T> target) {
        if (null == target)
            return 0;
        var result = 17;
        foreach (var element in target) {
            // Multiply-then-add keeps the hash order-sensitive (matching SequenceEqual).
            // It never collapses to 0 because of one element's hash value.
            // Null elements still contribute a position, so [null, x] and [x] differ.
            // unchecked: wrap-around is intended, even in /checked builds.
            unchecked {
                result = result * 31 + (null == element ? 0 : element.GetHashCode());
            }
        }
        return result;
    }
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace ApexInvesting.Platform.Utilities {
/// <summary>
/// Marker attribute. A public field or property carrying it is skipped when
/// <c>MemberwiseEqualityObject</c> builds its memberwise Equals and GetHashCode.
/// </summary>
public class MemberwiseEqualityIgnoreAttribute : System.Attribute { }
}
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
namespace ApexInvesting.Platform.Utilities {
/// <summary>
/// Base class that provides a dynamic, memberwise Equals/GetHashCode implementation.
/// Rather than reflecting on every call, the comparison and hashing methods are built
/// once per concrete type as compiled expression trees.
/// </summary>
/// <remarks>
/// The compared members of a type are:
/// its public instance fields, plus its public, readable, non-indexed instance properties.
/// Any member marked with <see cref="MemberwiseEqualityIgnoreAttribute"/> is excluded.
/// A member whose type is a reference type implementing <see cref="IEnumerable{T}"/>
/// (other than <see cref="string"/>) is compared element by element.
/// Every other member is compared with <see cref="EqualityComparer{T}.Default"/>.
/// </remarks>
/// <seealso href="http://www.brad-smith.info/blog/archives/385"/>
public class MemberwiseEqualityObject {
    /// <summary>
    /// Compiled functions per concrete type. A published dictionary is never mutated.
    /// New entries are added to a copy under <see cref="_functionsLock"/>, and the copy
    /// then replaces this reference, so lookups need no lock.
    /// </summary>
    static volatile Dictionary<Type, MemberwiseFunctions> _functions = new Dictionary<Type, MemberwiseFunctions>();

    /// <summary>
    /// Serializes the creation of new <see cref="_functions"/> entries.
    /// </summary>
    static readonly object _functionsLock = new object();

    /// <summary>
    /// Binding flags used to discover comparable members. Static members (including
    /// constants) are excluded because they cannot be read through an instance expression.
    /// </summary>
    const BindingFlags PublicInstance = BindingFlags.Public | BindingFlags.Instance;

    /// <summary>
    /// Used to hold delegates for the compiled methods.
    /// </summary>
    private class MemberwiseFunctions {
        /// <summary>
        /// Delegate for the Equals method.
        /// </summary>
        public Func<object, object, bool> EqualsFunc;
        /// <summary>
        /// Delegate for the GetHashCode method.
        /// </summary>
        public Func<object, int> GetHashCodeFunc;
    }

    /// <summary>
    /// Returns the compiled functions for <paramref name="type"/>, building and caching
    /// them on first use.
    /// </summary>
    /// <param name="type">The concrete runtime type of the instance.</param>
    /// <exception cref="MemberwiseEqualityObjectException">
    /// <paramref name="type"/> implements <see cref="IEnumerable"/>.
    /// </exception>
    static MemberwiseFunctions GetFunctions(Type type) {
        MemberwiseFunctions funcs;
        if (_functions.TryGetValue(type, out funcs))
            return funcs;
        lock (_functionsLock) {
            // Another thread may have published this type while we waited for the lock.
            if (_functions.TryGetValue(type, out funcs))
                return funcs;
            if (typeof(IEnumerable).IsAssignableFrom(type))
                throw new MemberwiseEqualityObjectException("MemberwiseEquality checking does not work on objects that inherit from IEnumerable");
            funcs = new MemberwiseFunctions {
                EqualsFunc = MakeEqualsMethod(type),
                GetHashCodeFunc = MakeGetHashCodeMethod(type),
            };
            var updated = new Dictionary<Type, MemberwiseFunctions>(_functions);
            updated[type] = funcs;
            _functions = updated;
            return funcs;
        }
    }

    /// <summary>
    /// Creates the GetHashCode() method.
    /// </summary>
    /// <param name="type">Concrete type whose members are hashed.</param>
    /// <returns>A delegate taking an instance of <paramref name="type"/> (as object).</returns>
    static Func<object, int> MakeGetHashCodeMethod(Type type) {
        ParameterExpression pThis = Expression.Parameter(typeof(object), "x");
        UnaryExpression pCastThis = Expression.Convert(pThis, type);
        // Seed with 17 and fold each member in as hash * 31 + memberHash.
        // Unlike XOR, this is order-sensitive and equal member values do not cancel out.
        // A type with no comparable members simply hashes to the seed.
        // Expression.Multiply/Add are unchecked, so overflow wraps as intended.
        Expression result = Expression.Constant(17);
        foreach (MemberInfo member in GetComparableMembers(type)) {
            Expression memberHash = MakeGetHashCodeExpression(Expression.MakeMemberAccess(pCastThis, member));
            result = Expression.Add(Expression.Multiply(result, Expression.Constant(31)), memberHash);
        }
        return Expression.Lambda<Func<object, int>>(result, pThis).Compile();
    }

    /// <summary>
    /// Creates the Equals() method.
    /// </summary>
    /// <param name="type">Concrete type whose members are compared.</param>
    /// <returns>A delegate taking (this, other); "this" must be an instance of <paramref name="type"/>.</returns>
    static Func<object, object, bool> MakeEqualsMethod(Type type) {
        ParameterExpression pThis = Expression.Parameter(typeof(object), "x");
        ParameterExpression pThat = Expression.Parameter(typeof(object), "y");
        // cast to the subclass type
        UnaryExpression pCastThis = Expression.Convert(pThis, type);
        UnaryExpression pCastThat = Expression.Convert(pThat, type);
        // compound AND expression using short-circuit evaluation
        Expression membersEqual = null;
        foreach (MemberInfo member in GetComparableMembers(type)) {
            Expression memberEquals = MakeEqualsExpression(
                Expression.MakeMemberAccess(pCastThis, member),
                Expression.MakeMemberAccess(pCastThat, member));
            membersEqual = null == membersEqual
                ? memberEquals
                : Expression.AndAlso(membersEqual, memberEquals);
        }
        // A type with no comparable members: all instances of it are equal.
        if (null == membersEqual)
            membersEqual = Expression.Constant(true);
        // Members are compared only when "y" has exactly the same runtime type as "x".
        // A type-is check would let a base instance equal a derived one but not vice
        // versa, breaking the symmetry Equals requires.
        // TypeEqual is false for null, and "x" is always of "type", so the fallback is false.
        Expression body = Expression.Condition(
            Expression.TypeEqual(pThat, type),
            membersEqual,
            Expression.Constant(false));
        // compile method
        return Expression.Lambda<Func<object, object, bool>>(body, pThis, pThat).Compile();
    }

    /// <summary>
    /// Builds an expression comparing two reads of the same member.
    /// Sequences are compared element by element (null-tolerant).
    /// Everything else is compared via EqualityComparer&lt;T&gt;.Default.Equals, which handles
    /// null references and value types that lack an == operator.
    /// </summary>
    static Expression MakeEqualsExpression(Expression thisMember, Expression thatMember) {
        Type memberType = thisMember.Type;
        Type elementType = GetSequenceElementType(memberType);
        if (null != elementType) {
            MethodInfo sequenceEquals = typeof(EnumerableExtensions)
                .GetMethod("SequenceEqualWithNullChecking")
                .MakeGenericMethod(elementType);
            return Expression.Call(sequenceEquals, thisMember, thatMember);
        }
        MemberExpression comparer = DefaultComparer(memberType);
        MethodInfo equals = comparer.Type.GetMethod("Equals", PublicInstance, null, new[] { memberType, memberType }, null);
        return Expression.Call(comparer, equals, thisMember, thatMember);
    }

    /// <summary>
    /// Builds an expression hashing one member read. It must agree with
    /// <see cref="MakeEqualsExpression"/>: sequences are hashed element-wise,
    /// and everything else via EqualityComparer&lt;T&gt;.Default.GetHashCode.
    /// </summary>
    static Expression MakeGetHashCodeExpression(Expression member) {
        Type memberType = member.Type;
        Type elementType = GetSequenceElementType(memberType);
        if (null != elementType) {
            MethodInfo sequenceHash = typeof(EnumerableExtensions)
                .GetMethod("GetEnumeratedHashCode")
                .MakeGenericMethod(elementType);
            return Expression.Call(sequenceHash, member);
        }
        MemberExpression comparer = DefaultComparer(memberType);
        MethodInfo getHashCode = comparer.Type.GetMethod("GetHashCode", PublicInstance, null, new[] { memberType }, null);
        return Expression.Call(comparer, getHashCode, member);
    }

    /// <summary>
    /// Returns an expression that reads EqualityComparer&lt;memberType&gt;.Default.
    /// </summary>
    static MemberExpression DefaultComparer(Type memberType) {
        Type comparerType = typeof(EqualityComparer<>).MakeGenericType(memberType);
        return Expression.Property(null, comparerType.GetProperty("Default"));
    }

    /// <summary>
    /// Returns T when <paramref name="type"/> is a reference type other than string that
    /// is or implements IEnumerable&lt;T&gt;. Otherwise returns null, in which case the
    /// member is compared as a single value.
    /// Non-generic collections and multi-dimensional arrays therefore fall back to
    /// their own Equals instead of failing.
    /// </summary>
    static Type GetSequenceElementType(Type type) {
        if (type.IsValueType || type == typeof(string))
            return null;
        if (IsGenericEnumerable(type))
            return type.GetGenericArguments()[0];
        foreach (Type implemented in type.GetInterfaces()) {
            if (IsGenericEnumerable(implemented))
                return implemented.GetGenericArguments()[0];
        }
        return null;
    }

    /// <summary>
    /// True when <paramref name="type"/> is a constructed IEnumerable&lt;T&gt;.
    /// </summary>
    static bool IsGenericEnumerable(Type type) {
        return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>);
    }

    /// <summary>
    /// Enumerates the members that take part in equality.
    /// Public instance fields come first, then public, readable, non-indexed instance
    /// properties. Members marked with MemberwiseEqualityIgnoreAttribute are skipped.
    /// </summary>
    static IEnumerable<MemberInfo> GetComparableMembers(Type type) {
        foreach (FieldInfo field in type.GetFields(PublicInstance)) {
            if (!IsIgnored(field))
                yield return field;
        }
        foreach (PropertyInfo property in type.GetProperties(PublicInstance)) {
            // Indexers and write-only properties cannot be read without arguments.
            if (!property.CanRead || property.GetIndexParameters().Length > 0)
                continue;
            if (!IsIgnored(property))
                yield return property;
        }
    }

    /// <summary>
    /// True when <paramref name="member"/> carries MemberwiseEqualityIgnoreAttribute.
    /// Attribute.IsDefined honours inherit=true for overridden properties,
    /// unlike PropertyInfo.GetCustomAttributes.
    /// </summary>
    static bool IsIgnored(MemberInfo member) {
        return Attribute.IsDefined(member, typeof(MemberwiseEqualityIgnoreAttribute), true);
    }

    /// <summary>
    /// Compiles and caches the Equals/GetHashCode functions on the first construction
    /// of each concrete subclass.
    /// </summary>
    /// <exception cref="MemberwiseEqualityObjectException">
    /// The concrete type implements <see cref="IEnumerable"/>.
    /// </exception>
    public MemberwiseEqualityObject() {
        GetFunctions(GetType());
    }

    /// <summary>
    /// Returns the member-wise hash code for this instance.
    /// </summary>
    /// <returns>A hash combining the hashes of all comparable members.</returns>
    public override int GetHashCode() {
        return GetFunctions(GetType()).GetHashCodeFunc(this);
    }

    /// <summary>
    /// Determines whether two instances are equal, using a member-wise comparison.
    /// </summary>
    /// <param name="obj">The object to compare with; may be null.</param>
    /// <returns>
    /// True when <paramref name="obj"/> has exactly this instance's runtime type and
    /// every comparable member is equal.
    /// </returns>
    public override bool Equals(object obj) {
        return GetFunctions(GetType()).EqualsFunc(this, obj);
    }
}
}
using ApexInvesting.Platform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace ApexInvesting.Platform.Test {
/// <summary>
/// Tests for MemberwiseEqualityObject, using the Apple and Banana fixtures declared below.
/// </summary>
[TestClass]
public class MemberwiseEqualityObjectTests {
    [TestMethod]
    public void MemberwiseEqualityObject_01() {
        Apple a, b;
        // checking general equality
        Reset(out a, out b);
        ExpectEqual(a, b);
        // checking ignore attribute on fields
        a.IgnoredField = 23;
        ExpectEqual(a, b);
        // checking ignore attribute on properties
        a.IgnoredProperty = 55;
        ExpectEqual(a, b);
        // checking that enumerable fields work
        b.EnumerableField = new[] { 1, 2, 3 };
        ExpectNotEqual(a, b);
        // checking that enumerable properties work
        Reset(out a, out b);
        b.EnumerableProperty = new[] { 1 };
        ExpectNotEqual(a, b);
        // checking IEnumerable<MemberwiseObject> also works.
        Reset(out a, out b);
        a.EnumerableMemberwiseProperty.Add(new Banana { B = 5, C = 6 });
        ExpectNotEqual(a, b);
        // checking string property stuff
        Reset(out a, out b);
        a.StringField = "a";
        ExpectNotEqual(a, b);
    }

    /// <summary>
    /// Covers members that MemberwiseEqualityObject_01 never changes.
    /// Also covers comparison against null and against an unrelated type.
    /// </summary>
    [TestMethod]
    public void MemberwiseEqualityObject_02() {
        Apple a, b;
        // checking string properties participate
        Reset(out a, out b);
        a.StringProperty = "a";
        ExpectNotEqual(a, b);
        // checking List<MemberwiseObject> fields are compared element by element
        Reset(out a, out b);
        a.EnumerableMemberwiseField.Add(new Banana { B = 5, C = 6 });
        ExpectNotEqual(a, b);
        // checking equal elements held in distinct list instances compare equal
        Reset(out a, out b);
        a.EnumerableMemberwiseField[0].B = 2;
        b.EnumerableMemberwiseField[0].B = 2;
        ExpectEqual(a, b);
        // checking null and foreign-type arguments
        var apple = CreateDefaultApple();
        Assert.IsFalse(apple.Equals(null));
        Assert.IsFalse(apple.Equals(new Banana { B = 1, C = 1 }));
    }

    /// <summary>
    /// Replaces both arguments with fresh, identical default apples.
    /// </summary>
    void Reset(out Apple a, out Apple b) {
        a = CreateDefaultApple();
        b = CreateDefaultApple();
    }

    /// <summary>
    /// Asserts the two objects are equal in both directions and share a hash code.
    /// </summary>
    void ExpectEqual(MemberwiseEqualityObject a, MemberwiseEqualityObject b) {
        Assert.IsTrue(a.Equals(b));
        Assert.IsTrue(b.Equals(a));
        Assert.AreEqual(a.GetHashCode(), b.GetHashCode());
    }

    /// <summary>
    /// Asserts the two objects are unequal in both directions.
    /// It also asserts their hash codes differ. Unequal objects may collide in general,
    /// but the inputs used in these tests are expected to hash differently.
    /// </summary>
    void ExpectNotEqual(MemberwiseEqualityObject a, MemberwiseEqualityObject b) {
        Assert.IsFalse(a.Equals(b));
        Assert.IsFalse(b.Equals(a));
        Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode());
    }

    /// <summary>
    /// Builds the baseline Apple that every test case mutates.
    /// </summary>
    Apple CreateDefaultApple() {
        return new Apple {
            IgnoredField = 1,
            IgnoredProperty = 1,
            EnumerableField = new[] { 1, 2 },
            EnumerableProperty = null,
            EnumerableMemberwiseField = new List<Banana> {
                new Banana { B = 1, C = 1 },
                new Banana { B = 1, C = 1 },
            },
            EnumerableMemberwiseProperty = new List<Banana> {
                new Banana { B = 1, C = 1 },
                new Banana { B = 1, C = 1 },
            },
            StringField = "abcd",
            StringProperty = "abcd",
        };
    }

    /// <summary>
    /// Fixture with one member of each kind the base class handles.
    /// </summary>
    class Apple : MemberwiseEqualityObject {
        [MemberwiseEqualityIgnore]
        public int IgnoredField;
        [MemberwiseEqualityIgnore]
        public int IgnoredProperty { get; set; }
        public int[] EnumerableField;
        public int[] EnumerableProperty { get; set; }
        public List<Banana> EnumerableMemberwiseField;
        // NOTE: despite its name this is a field, not a property.
        public List<Banana> EnumerableMemberwiseProperty;
        public string StringField;
        public string StringProperty { get; set; }
    }

    /// <summary>
    /// Small memberwise-comparable element type (one field, one property).
    /// </summary>
    class Banana : MemberwiseEqualityObject {
        public int B;
        public int C { get; set; }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment