Skip to content

Instantly share code, notes, and snippets.

@matthid
Created February 21, 2022 23:54
Show Gist options
  • Save matthid/6751b17e632a60936356c8fa72c8f4ff to your computer and use it in GitHub Desktop.
Save matthid/6751b17e632a60936356c8fa72c8f4ff to your computer and use it in GitHub Desktop.
Simple wrappers that make lists and dictionaries comparable by content, ignoring element order.
using System.Collections;
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
namespace UnorderedCollectionComparison;
/// <summary>
/// A single comparer that provides both ordering (<see cref="IComparer{T}"/>) and
/// equality/hashing (<see cref="IEqualityComparer{T}"/>) for <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">
/// The compared type. It is contravariant, so a comparer for a base type can stand in for a derived type.
/// </typeparam>
public interface IFullComparer<in T> : IEqualityComparer<T>, IComparer<T> { }
/// <summary>
/// Factory helpers for building <see cref="IFullComparer{T}"/> instances.
/// </summary>
public static class FullComparer
{
    /// <summary>
    /// Combines an equality comparer and an ordering comparer into one <see cref="IFullComparer{T}"/>.
    /// </summary>
    /// <typeparam name="T">The compared type.</typeparam>
    /// <param name="equalityComparer">Supplies <c>Equals</c> and <c>GetHashCode</c>.</param>
    /// <param name="comparer">
    /// Supplies <c>Compare</c>. This comparer should agree with <paramref name="equalityComparer"/>:
    /// <c>Compare</c> should return 0 exactly when <c>Equals</c> returns true.
    /// </param>
    /// <returns>A comparer that forwards each member to the matching argument.</returns>
    /// <exception cref="ArgumentNullException">Either argument is <see langword="null"/>.</exception>
    public static IFullComparer<T> Combine<T>(IEqualityComparer<T> equalityComparer, IComparer<T> comparer)
    {
        // Fail here rather than with a NullReferenceException on the first comparison.
        ArgumentNullException.ThrowIfNull(equalityComparer);
        ArgumentNullException.ThrowIfNull(comparer);
        return new CombinedFullComparer<T>(equalityComparer, comparer);
    }

    /// <summary>Adapter that forwards each member to the comparer responsible for it.</summary>
    private sealed class CombinedFullComparer<T> : IFullComparer<T>
    {
        private readonly IEqualityComparer<T> _equalityComparer;
        private readonly IComparer<T> _comparer;

        public CombinedFullComparer(IEqualityComparer<T> equalityComparer, IComparer<T> comparer)
        {
            _equalityComparer = equalityComparer;
            _comparer = comparer;
        }

        public int Compare(T? x, T? y) => _comparer.Compare(x, y);

        public bool Equals(T? x, T? y) => _equalityComparer.Equals(x, y);

        public int GetHashCode(T obj) => _equalityComparer.GetHashCode(obj);
    }
}
/// <summary>
/// Holds the default <see cref="IFullComparer{T}"/> for <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">The compared type.</typeparam>
public static class FullComparer<T>
{
    /// <summary>
    /// The default comparer. It pairs <see cref="EqualityComparer{T}.Default"/> with <see cref="Comparer{T}.Default"/>.
    /// It is created once per closed generic type and then shared.
    /// </summary>
    public static IFullComparer<T> Default { get; } =
        FullComparer.Combine<T>(EqualityComparer<T>.Default, Comparer<T>.Default);
}
/// <summary>
/// Compares <see cref="IReadOnlyCollection{T}"/> instances as unordered multisets.
/// Two collections are equal when they hold the same elements the same number of times, in any order.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
/// <remarks>
/// <para>
/// Both collections are sorted with the element comparer and then compared position by position.
/// The element comparer's <c>Compare</c> must therefore be consistent with its <c>Equals</c>.
/// </para>
/// <para>
/// The sorted copy of each collection is cached per collection instance. The cache is a
/// <see cref="ConditionalWeakTable{TKey, TValue}"/>, so it does not keep collections alive.
/// Collections are assumed not to be mutated after they have been compared. A change in <c>Count</c>
/// triggers a re-sort, but an in-place edit that keeps the size is not detected.
/// </para>
/// </remarks>
public class UnorderedCollectionComparer<T> : IFullComparer<IReadOnlyCollection<T>>
{
    private readonly IFullComparer<T> _comparer;

    // Sorted snapshot of every collection seen so far, keyed by reference identity.
    private readonly ConditionalWeakTable<IReadOnlyCollection<T>, IReadOnlyList<T>> _orderedCache = new();

    /// <summary>Creates a comparer that uses <paramref name="fullComparer"/> for the elements.</summary>
    /// <param name="fullComparer">Element comparer, used for both sorting and equality.</param>
    /// <exception cref="ArgumentNullException"><paramref name="fullComparer"/> is <see langword="null"/>.</exception>
    public UnorderedCollectionComparer(IFullComparer<T> fullComparer)
    {
        ArgumentNullException.ThrowIfNull(fullComparer);
        _comparer = fullComparer;
    }

    /// <summary>Returns the elements of <paramref name="unordered"/> sorted by the element comparer, using the cache.</summary>
    private IReadOnlyList<T> GetOrdered(IReadOnlyCollection<T> unordered)
    {
        // A Count mismatch means the collection changed after it was cached. Re-sort instead of
        // returning a snapshot of the wrong length, which would index out of range in Equals/Compare.
        if (_orderedCache.TryGetValue(unordered, out var cached) && cached.Count == unordered.Count)
        {
            return cached;
        }

        var ordered = unordered.OrderBy(t => t, _comparer).ToList();
        // Unlike Add, AddOrUpdate does not throw when another thread cached the same key first.
        _orderedCache.AddOrUpdate(unordered, ordered);
        return ordered;
    }

    /// <summary>
    /// Returns true when both collections contain the same elements with the same multiplicities,
    /// in any order. Two nulls are equal; null is never equal to a non-null collection.
    /// </summary>
    public bool Equals(IReadOnlyCollection<T>? x, IReadOnlyCollection<T>? y)
    {
        if (ReferenceEquals(x, y)) return true;
        if (x is null || y is null) return false;
        if (x.Count != y.Count) return false;
        // Cheap O(n) rejection before paying for the O(n log n) sort.
        if (GetHashCode(x) != GetHashCode(y)) return false;

        var orderedX = GetOrdered(x);
        var orderedY = GetOrdered(y);
        for (int i = 0; i < orderedX.Count; i++)
        {
            if (!_comparer.Equals(orderedX[i], orderedY[i]))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Order-independent hash of the element count and the elements.
    /// Null elements contribute nothing to the element sum.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="obj"/> is <see langword="null"/>.</exception>
    public int GetHashCode(IReadOnlyCollection<T> obj)
    {
        ArgumentNullException.ThrowIfNull(obj);

        var code = new HashCode();
        code.Add(obj.Count);
        // Summing the element hashes makes the result independent of enumeration order (the JDK approach).
        // See https://stackoverflow.com/questions/30734848/order-independent-hash-algorithm
        // and https://codereview.stackexchange.com/questions/32024/calculating-gethashcode-efficiently-with-unordered-list
        var elements = 0;
        unchecked
        {
            foreach (var item in obj)
            {
                if (item is not null)
                    elements += _comparer.GetHashCode(item);
            }
        }

        code.Add(elements);
        return code.ToHashCode();
    }

    /// <summary>
    /// Orders collections in three steps:
    /// 1. null sorts first;
    /// 2. then collections are ordered by <c>Count</c>;
    /// 3. then the first differing element of the two sorted sequences decides.
    /// </summary>
    public int Compare(IReadOnlyCollection<T>? x, IReadOnlyCollection<T>? y)
    {
        if (ReferenceEquals(x, y)) return 0;
        if (y is null) return 1;
        if (x is null) return -1;

        var countComparison = x.Count.CompareTo(y.Count);
        if (countComparison != 0) return countComparison;

        var orderedX = GetOrdered(x);
        var orderedY = GetOrdered(y);
        for (int i = 0; i < orderedX.Count; i++)
        {
            var elementComparison = _comparer.Compare(orderedX[i], orderedY[i]);
            if (elementComparison != 0) return elementComparison;
        }

        return 0;
    }
}
/// <summary>
/// Full comparer for <see cref="KeyValuePair{TKey, TValue}"/>.
/// Pairs are ordered by key first and then by value. They are equal when both the key and the value are equal.
/// </summary>
/// <typeparam name="TKey">The key type.</typeparam>
/// <typeparam name="TValue">The value type.</typeparam>
public class KeyValueComparer<TKey, TValue> : IFullComparer<KeyValuePair<TKey, TValue>>
{
    private readonly IFullComparer<TKey> _keyComparer;
    private readonly IFullComparer<TValue> _valueComparer;

    /// <summary>Creates a pair comparer from separate key and value comparers.</summary>
    /// <param name="keyComparer">Comparer applied to <see cref="KeyValuePair{TKey, TValue}.Key"/>.</param>
    /// <param name="valueComparer">Comparer applied to <see cref="KeyValuePair{TKey, TValue}.Value"/>.</param>
    /// <exception cref="ArgumentNullException">Either comparer is <see langword="null"/>.</exception>
    public KeyValueComparer(IFullComparer<TKey> keyComparer, IFullComparer<TValue> valueComparer)
    {
        ArgumentNullException.ThrowIfNull(keyComparer);
        ArgumentNullException.ThrowIfNull(valueComparer);
        _keyComparer = keyComparer;
        _valueComparer = valueComparer;
    }

    /// <summary>True when both the keys and the values are equal under their respective comparers.</summary>
    public bool Equals(KeyValuePair<TKey, TValue> x, KeyValuePair<TKey, TValue> y) =>
        _keyComparer.Equals(x.Key, y.Key) && _valueComparer.Equals(x.Value, y.Value);

    /// <summary>Combines the key and value hashes. Null keys and values are handled by <see cref="HashCode.Add{T}(T, IEqualityComparer{T})"/>.</summary>
    public int GetHashCode(KeyValuePair<TKey, TValue> obj)
    {
        var code = new HashCode();
        code.Add(obj.Key, _keyComparer);
        code.Add(obj.Value, _valueComparer);
        return code.ToHashCode();
    }

    /// <summary>Compares the keys first; the values break ties.</summary>
    public int Compare(KeyValuePair<TKey, TValue> x, KeyValuePair<TKey, TValue> y)
    {
        var keyComparison = _keyComparer.Compare(x.Key, y.Key);
        return keyComparison != 0 ? keyComparison : _valueComparer.Compare(x.Value, y.Value);
    }
}
/// <summary>
/// Extension methods that wrap dictionaries, collections and lists so that
/// <see cref="object.Equals(object?)"/> and <see cref="object.GetHashCode"/> compare their contents
/// instead of their references.
/// </summary>
/// <remarks>
/// <para>
/// Every wrapper compares its contents as an unordered multiset. This includes the list wrappers,
/// so <c>[1, 2]</c> and <c>[2, 1]</c> are equal.
/// </para>
/// <para>
/// Wrappers reference the source instead of copying it.
/// <c>Equals</c> uses the comparer of the instance it is called on, so comparing wrappers created
/// with different item comparers can be asymmetric.
/// </para>
/// </remarks>
public static class CollectionComparion
{
    /// <summary>Wraps a read-only dictionary so that it compares by its key/value pairs, ignoring order.</summary>
    /// <param name="sourceDictionary">The dictionary to wrap. It is referenced, not copied.</param>
    /// <param name="itemComparer">
    /// Comparer for the key/value pairs, or <see langword="null"/> to use the default key and value comparers.
    /// </param>
    /// <returns>
    /// <paramref name="sourceDictionary"/> itself when it is already a wrapper that uses the default comparer
    /// and <paramref name="itemComparer"/> is <see langword="null"/>; otherwise a new wrapper.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sourceDictionary"/> is <see langword="null"/>.</exception>
    public static IReadOnlyDictionary<TKey, TValue> MakeComparable<TKey, TValue>(
        this IReadOnlyDictionary<TKey, TValue> sourceDictionary, IFullComparer<KeyValuePair<TKey, TValue>>? itemComparer = null)
    {
        ArgumentNullException.ThrowIfNull(sourceDictionary);
        if (itemComparer is null
            && sourceDictionary is ComparableDictionary<TKey, TValue> comp
            && ReferenceEquals(comp._fullComparer, ComparableDictionary<TKey, TValue>.DefaultComparer))
        {
            return comp;
        }

        return new ComparableDictionary<TKey, TValue>(sourceDictionary, itemComparer);
    }

    /// <summary>Read-only dictionary wrapper with content-based, order-independent equality.</summary>
    private class ComparableDictionary<TKey, TValue> : IReadOnlyDictionary<TKey, TValue>
    {
        // Eagerly initialised static readonly fields give exactly one default instance per closed type.
        // MakeComparable's ReferenceEquals shortcut relies on that; a racy lazy `??=` could create two.
        private static readonly IFullComparer<KeyValuePair<TKey, TValue>> s_defaultItemComparer =
            new KeyValueComparer<TKey, TValue>(FullComparer<TKey>.Default, FullComparer<TValue>.Default);

        private static readonly IFullComparer<IReadOnlyCollection<KeyValuePair<TKey, TValue>>> s_defaultComparer =
            new UnorderedCollectionComparer<KeyValuePair<TKey, TValue>>(s_defaultItemComparer);

        private readonly IReadOnlyDictionary<TKey, TValue> _sourceDictionary;
        internal readonly IFullComparer<IReadOnlyCollection<KeyValuePair<TKey, TValue>>> _fullComparer;

        /// <summary>Shared comparer used when no item comparer is supplied.</summary>
        internal static IFullComparer<IReadOnlyCollection<KeyValuePair<TKey, TValue>>> DefaultComparer => s_defaultComparer;

        public ComparableDictionary(IReadOnlyDictionary<TKey, TValue> sourceDictionary, IFullComparer<KeyValuePair<TKey, TValue>>? itemComparer = null)
            : this(sourceDictionary, itemComparer is null ? DefaultComparer : new UnorderedCollectionComparer<KeyValuePair<TKey, TValue>>(itemComparer))
        {
        }

        protected ComparableDictionary(IReadOnlyDictionary<TKey, TValue> sourceDictionary, IFullComparer<IReadOnlyCollection<KeyValuePair<TKey, TValue>>> fullComparer)
        {
            _sourceDictionary = sourceDictionary;
            _fullComparer = fullComparer;
        }

        public int Count => _sourceDictionary.Count;

        public TValue this[TKey key] => _sourceDictionary[key];

        public IEnumerable<TKey> Keys => _sourceDictionary.Keys;

        public IEnumerable<TValue> Values => _sourceDictionary.Values;

        public bool ContainsKey(TKey key) => _sourceDictionary.ContainsKey(key);

        public bool TryGetValue(TKey key, out TValue value) => _sourceDictionary.TryGetValue(key, out value);

        public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() => _sourceDictionary.GetEnumerator();

        IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_sourceDictionary).GetEnumerator();

        protected bool Equals(ComparableDictionary<TKey, TValue> other) =>
            _fullComparer.Equals(_sourceDictionary, other._sourceDictionary);

        // Object.Equals must not throw. Any object that is not a ComparableDictionary is simply unequal.
        public override bool Equals(object? obj) =>
            ReferenceEquals(this, obj) || (obj is ComparableDictionary<TKey, TValue> other && Equals(other));

        public override int GetHashCode() => _fullComparer.GetHashCode(_sourceDictionary);
    }

    /// <summary>Wraps an immutable dictionary so that it compares by its key/value pairs, ignoring order.</summary>
    /// <param name="sourceDictionary">The dictionary to wrap. It is referenced, not copied.</param>
    /// <param name="itemComparer">
    /// Comparer for the key/value pairs, or <see langword="null"/> to use the default key and value comparers.
    /// </param>
    /// <returns>
    /// <paramref name="sourceDictionary"/> itself when it is already a wrapper that uses the default comparer
    /// and <paramref name="itemComparer"/> is <see langword="null"/>; otherwise a new wrapper.
    /// Dictionaries returned by the wrapper's mutation methods are wrapped with the same comparer.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sourceDictionary"/> is <see langword="null"/>.</exception>
    public static IImmutableDictionary<TKey, TValue> MakeComparable<TKey, TValue>(
        this IImmutableDictionary<TKey, TValue> sourceDictionary, IFullComparer<KeyValuePair<TKey, TValue>>? itemComparer = null)
    {
        ArgumentNullException.ThrowIfNull(sourceDictionary);
        if (itemComparer is null
            && sourceDictionary is ComparableImmutableDictionary<TKey, TValue> comp
            && ReferenceEquals(comp._fullComparer, ComparableImmutableDictionary<TKey, TValue>.DefaultComparer))
        {
            return comp;
        }

        return new ComparableImmutableDictionary<TKey, TValue>(sourceDictionary, itemComparer);
    }

    /// <summary>Immutable dictionary wrapper. Each "mutation" returns a new wrapper that keeps the same comparer.</summary>
    private class ComparableImmutableDictionary<TKey, TValue> : ComparableDictionary<TKey, TValue>, IImmutableDictionary<TKey, TValue>
    {
        private readonly IImmutableDictionary<TKey, TValue> _sourceDictionary;

        public ComparableImmutableDictionary(IImmutableDictionary<TKey, TValue> sourceDictionary, IFullComparer<KeyValuePair<TKey, TValue>>? itemComparer = null)
            : this(sourceDictionary, itemComparer is null ? DefaultComparer : new UnorderedCollectionComparer<KeyValuePair<TKey, TValue>>(itemComparer))
        {
        }

        protected ComparableImmutableDictionary(IImmutableDictionary<TKey, TValue> sourceDictionary, IFullComparer<IReadOnlyCollection<KeyValuePair<TKey, TValue>>> fullComparer)
            : base(sourceDictionary, fullComparer)
        {
            _sourceDictionary = sourceDictionary;
        }

        /// <summary>Wraps a derived dictionary with this instance's comparer.</summary>
        private IImmutableDictionary<TKey, TValue> Wrap(IImmutableDictionary<TKey, TValue> newDict) =>
            new ComparableImmutableDictionary<TKey, TValue>(newDict, _fullComparer);

        public IImmutableDictionary<TKey, TValue> Add(TKey key, TValue value) => Wrap(_sourceDictionary.Add(key, value));

        public IImmutableDictionary<TKey, TValue> AddRange(IEnumerable<KeyValuePair<TKey, TValue>> pairs) => Wrap(_sourceDictionary.AddRange(pairs));

        public IImmutableDictionary<TKey, TValue> Clear() => Wrap(_sourceDictionary.Clear());

        public bool Contains(KeyValuePair<TKey, TValue> pair) => _sourceDictionary.Contains(pair);

        public IImmutableDictionary<TKey, TValue> Remove(TKey key) => Wrap(_sourceDictionary.Remove(key));

        public IImmutableDictionary<TKey, TValue> RemoveRange(IEnumerable<TKey> keys) => Wrap(_sourceDictionary.RemoveRange(keys));

        public IImmutableDictionary<TKey, TValue> SetItem(TKey key, TValue value) => Wrap(_sourceDictionary.SetItem(key, value));

        public IImmutableDictionary<TKey, TValue> SetItems(IEnumerable<KeyValuePair<TKey, TValue>> items) => Wrap(_sourceDictionary.SetItems(items));

        public bool TryGetKey(TKey equalKey, out TKey actualKey) => _sourceDictionary.TryGetKey(equalKey, out actualKey);
    }

    /// <summary>Wraps a read-only collection so that it compares by its elements, ignoring order.</summary>
    /// <param name="sourceCollection">The collection to wrap. It is referenced, not copied.</param>
    /// <param name="itemComparer">Element comparer, or <see langword="null"/> to use <see cref="FullComparer{T}.Default"/>.</param>
    /// <returns>
    /// <paramref name="sourceCollection"/> itself when it is already a wrapper that uses the default comparer
    /// and <paramref name="itemComparer"/> is <see langword="null"/>; otherwise a new wrapper.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sourceCollection"/> is <see langword="null"/>.</exception>
    public static IReadOnlyCollection<T> MakeComparable<T>(
        this IReadOnlyCollection<T> sourceCollection,
        IFullComparer<T>? itemComparer = null)
    {
        ArgumentNullException.ThrowIfNull(sourceCollection);
        if (itemComparer is null
            && sourceCollection is ComparableCollection<T> comp
            && ReferenceEquals(comp._fullComparer, ComparableCollection<T>.DefaultComparer))
        {
            return comp;
        }

        // Forward itemComparer, otherwise a caller-supplied comparer would be silently ignored.
        return new ComparableCollection<T>(sourceCollection, itemComparer);
    }

    /// <summary>
    /// Wraps a read-only list so that it compares by its elements, ignoring order.
    /// Indexing still follows the source list's order.
    /// </summary>
    /// <param name="sourceCollection">The list to wrap. It is referenced, not copied.</param>
    /// <param name="itemComparer">Element comparer, or <see langword="null"/> to use <see cref="FullComparer{T}.Default"/>.</param>
    /// <returns>
    /// <paramref name="sourceCollection"/> itself when it is already a wrapper that uses the default comparer
    /// and <paramref name="itemComparer"/> is <see langword="null"/>; otherwise a new wrapper.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sourceCollection"/> is <see langword="null"/>.</exception>
    public static IReadOnlyList<T> MakeComparable<T>(
        this IReadOnlyList<T> sourceCollection,
        IFullComparer<T>? itemComparer = null)
    {
        ArgumentNullException.ThrowIfNull(sourceCollection);
        if (itemComparer is null
            && sourceCollection is ComparableList<T> comp
            && ReferenceEquals(comp._fullComparer, ComparableList<T>.DefaultComparer))
        {
            return comp;
        }

        return new ComparableList<T>(sourceCollection, itemComparer);
    }

    /// <summary>Wraps an immutable list so that it compares by its elements, ignoring order.</summary>
    /// <param name="sourceCollection">The list to wrap. It is referenced, not copied.</param>
    /// <param name="itemComparer">Element comparer, or <see langword="null"/> to use <see cref="FullComparer{T}.Default"/>.</param>
    /// <returns>
    /// <paramref name="sourceCollection"/> itself when it is already a wrapper that uses the default comparer
    /// and <paramref name="itemComparer"/> is <see langword="null"/>; otherwise a new wrapper.
    /// Lists returned by the wrapper's mutation methods are wrapped with the same comparer.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sourceCollection"/> is <see langword="null"/>.</exception>
    public static IImmutableList<T> MakeComparable<T>(
        this IImmutableList<T> sourceCollection,
        IFullComparer<T>? itemComparer = null)
    {
        ArgumentNullException.ThrowIfNull(sourceCollection);
        if (itemComparer is null
            && sourceCollection is ComparableImmutableList<T> comp
            && ReferenceEquals(comp._fullComparer, ComparableImmutableList<T>.DefaultComparer))
        {
            return comp;
        }

        return new ComparableImmutableList<T>(sourceCollection, itemComparer);
    }

    /// <summary>Read-only collection wrapper with content-based, order-independent equality.</summary>
    private class ComparableCollection<T> : IReadOnlyCollection<T>
    {
        // Eagerly initialised, so there is exactly one default instance per closed type.
        private static readonly IFullComparer<IReadOnlyCollection<T>> s_defaultComparer =
            new UnorderedCollectionComparer<T>(FullComparer<T>.Default);

        private readonly IReadOnlyCollection<T> _sourceCollection;
        internal readonly IFullComparer<IReadOnlyCollection<T>> _fullComparer;

        /// <summary>Shared comparer used when no item comparer is supplied.</summary>
        internal static IFullComparer<IReadOnlyCollection<T>> DefaultComparer => s_defaultComparer;

        public ComparableCollection(IReadOnlyCollection<T> sourceCollection, IFullComparer<T>? itemComparer = null)
            : this(sourceCollection, itemComparer is null ? DefaultComparer : new UnorderedCollectionComparer<T>(itemComparer))
        {
        }

        protected ComparableCollection(IReadOnlyCollection<T> sourceCollection, IFullComparer<IReadOnlyCollection<T>> fullComparer)
        {
            _sourceCollection = sourceCollection;
            _fullComparer = fullComparer;
        }

        public int Count => _sourceCollection.Count;

        public IEnumerator<T> GetEnumerator() => _sourceCollection.GetEnumerator();

        IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_sourceCollection).GetEnumerator();

        protected bool Equals(ComparableCollection<T> other) =>
            _fullComparer.Equals(_sourceCollection, other._sourceCollection);

        // Object.Equals must not throw. Any object that is not a ComparableCollection is simply unequal.
        public override bool Equals(object? obj) =>
            ReferenceEquals(this, obj) || (obj is ComparableCollection<T> other && Equals(other));

        public override int GetHashCode() => _fullComparer.GetHashCode(_sourceCollection);
    }

    /// <summary>
    /// List wrapper. Indexing follows the source order, while equality ignores order
    /// because the comparison is inherited from <see cref="ComparableCollection{T}"/>.
    /// </summary>
    private class ComparableList<T> : ComparableCollection<T>, IReadOnlyList<T>
    {
        private readonly IReadOnlyList<T> _sourceList;

        public ComparableList(IReadOnlyList<T> sourceList, IFullComparer<T>? itemComparer = null)
            : base(sourceList, itemComparer)
        {
            _sourceList = sourceList;
        }

        protected ComparableList(IReadOnlyList<T> sourceList, IFullComparer<IReadOnlyCollection<T>> fullComparer)
            : base(sourceList, fullComparer)
        {
            _sourceList = sourceList;
        }

        public T this[int index] => _sourceList[index];
    }

    /// <summary>Immutable list wrapper. Each "mutation" returns a new wrapper that keeps the same comparer.</summary>
    private class ComparableImmutableList<T> : ComparableList<T>, IImmutableList<T>
    {
        private readonly IImmutableList<T> _sourceList;

        public ComparableImmutableList(IImmutableList<T> sourceList, IFullComparer<T>? itemComparer = null)
            : base(sourceList, itemComparer)
        {
            _sourceList = sourceList;
        }

        protected ComparableImmutableList(IImmutableList<T> sourceList, IFullComparer<IReadOnlyCollection<T>> fullComparer)
            : base(sourceList, fullComparer)
        {
            _sourceList = sourceList;
        }

        /// <summary>Wraps a derived list with this instance's comparer.</summary>
        private IImmutableList<T> Wrap(IImmutableList<T> newList) =>
            new ComparableImmutableList<T>(newList, _fullComparer);

        public IImmutableList<T> Add(T value) => Wrap(_sourceList.Add(value));

        public IImmutableList<T> AddRange(IEnumerable<T> items) => Wrap(_sourceList.AddRange(items));

        public IImmutableList<T> Clear() => Wrap(_sourceList.Clear());

        public int IndexOf(T item, int index, int count, IEqualityComparer<T>? equalityComparer) =>
            _sourceList.IndexOf(item, index, count, equalityComparer);

        public IImmutableList<T> Insert(int index, T element) => Wrap(_sourceList.Insert(index, element));

        public IImmutableList<T> InsertRange(int index, IEnumerable<T> items) => Wrap(_sourceList.InsertRange(index, items));

        public int LastIndexOf(T item, int index, int count, IEqualityComparer<T>? equalityComparer) =>
            _sourceList.LastIndexOf(item, index, count, equalityComparer);

        public IImmutableList<T> Remove(T value, IEqualityComparer<T>? equalityComparer) =>
            Wrap(_sourceList.Remove(value, equalityComparer));

        public IImmutableList<T> RemoveAll(Predicate<T> match) => Wrap(_sourceList.RemoveAll(match));

        public IImmutableList<T> RemoveAt(int index) => Wrap(_sourceList.RemoveAt(index));

        public IImmutableList<T> RemoveRange(IEnumerable<T> items, IEqualityComparer<T>? equalityComparer) =>
            Wrap(_sourceList.RemoveRange(items, equalityComparer));

        public IImmutableList<T> RemoveRange(int index, int count) => Wrap(_sourceList.RemoveRange(index, count));

        public IImmutableList<T> Replace(T oldValue, T newValue, IEqualityComparer<T>? equalityComparer) =>
            Wrap(_sourceList.Replace(oldValue, newValue, equalityComparer));

        public IImmutableList<T> SetItem(int index, T value) => Wrap(_sourceList.SetItem(index, value));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment