Skip to content

Instantly share code, notes, and snippets.

Created January 16, 2016 21:36
Show Gist options
  • Save madelson/9673c51c0f3fbfd6e79b to your computer and use it in GitHub Desktop.
Save madelson/9673c51c0f3fbfd6e79b to your computer and use it in GitHub Desktop.
Demonstrates the output of rewriting the MedallionCollections source for use in an inline NuGet package
// PACKAGE MedallionCollections.Inline 1.0.1
// The code in this file was AUTO-GENERATED by installing the MedallionCollections.Inline NuGet package.
// To update, run Update-Package MedallionCollections.Inline in the NuGet package manager console.
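// You can customize this file's behavior without editing it by defining the preprocessor symbols
// referenced here (MedallionCollections_PUBLIC, MedallionCollections_USE_LOCAL_NAMESPACE,
// MedallionCollections_DISABLE_EXTENSIONS) in your project properties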
using System.Collections.Generic;
using System.Collections;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System;
#if MedallionCollections_USE_LOCAL_NAMESPACE
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Medallion.Tools.InlineNuGet", ""), global::System.Diagnostics.DebuggerNonUserCodeAttribute]
#if MedallionCollections_PUBLIC
static partial class CollectionHelper
#region ---- Partition ----
/// <summary>
/// Splits the given <paramref name="source"/> sequence into a series of <see cref="List{T}"/>s
/// of length <paramref name="partitionSize"/>. Note that the final partition may be shorter than
/// <paramref name="partitionSize"/>.
/// </summary>
/// <param name="source">The sequence to split. It is not enumerated until the result is enumerated.</param>
/// <param name="partitionSize">The maximum number of elements per partition; must be at least 1.</param>
/// <returns>A lazily-evaluated sequence of partitions; each partition is a newly allocated list.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="partitionSize"/> is less than 1.</exception>
public static IEnumerable<List<T>> Partition<T>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<T> source, int partitionSize)
{
    // validate here rather than in the iterator: iterator bodies are deferred, so argument
    // errors would otherwise only surface on first enumeration
    if (source == null) { throw new ArgumentNullException("source"); }
    if (partitionSize < 1) { throw new ArgumentOutOfRangeException(paramName: "partitionSize", message: string.Format("Value must be positive (got {0})", partitionSize)); }

    return PartitionIterator(source, partitionSize);
}
/// <summary>
/// Deferred implementation of <c>Partition</c>: yields consecutive chunks of at most
/// <paramref name="partitionSize"/> elements. Arguments are assumed to be validated by the caller.
/// </summary>
private static IEnumerable<List<T>> PartitionIterator<T>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<T> source, int partitionSize)
{
    // we like initializing our lists with capacity to avoid resizes. However, we don't want to trigger
    // OutOfMemory if the partition size is huge
    var initialCapacity = Math.Min(partitionSize, 1024);

    using (var enumerator = source.GetEnumerator())
    {
        while (enumerator.MoveNext())
        {
            // MoveNext() just succeeded, so every partition starts with at least one element
            var partition = new List<T>(capacity: initialCapacity) { enumerator.Current };
            for (var i = 1; i < partitionSize && enumerator.MoveNext(); ++i)
            {
                partition.Add(enumerator.Current);
            }

            yield return partition;
        }
    }
}
#region ---- Append ----
/// <summary>
/// As <see cref="Enumerable.Concat{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>, but with better
/// performance for repeated calls: chained <c>Append</c>/<c>Prepend</c> results are flattened into a tree
/// that is enumerated iteratively with an explicit stack.
/// </summary>
/// <param name="first">The elements to yield first.</param>
/// <param name="second">The elements to yield after those of <paramref name="first"/>.</param>
/// <returns>A lazily-evaluated sequence; neither argument is enumerated until the result is.</returns>
/// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
public static IEnumerable<TElement> Append<TElement>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TElement> first, IEnumerable<TElement> second)
{
    if (first == null) { throw new ArgumentNullException("first"); }
    if (second == null) { throw new ArgumentNullException("second"); }

    return new AppendEnumerable<TElement>(first, second);
}
/// <summary>
/// As <see cref="Enumerable.Concat{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>, but appends only one element.
/// Like the sequence overload, repeated calls are enumerated iteratively rather than recursively.
/// </summary>
/// <param name="sequence">The elements to yield first.</param>
/// <param name="next">The single element to yield last (may be null).</param>
/// <returns>A lazily-evaluated sequence of <paramref name="sequence"/> followed by <paramref name="next"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="sequence"/> is null.</exception>
public static IEnumerable<TElement> Append<TElement>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TElement> sequence, TElement next)
{
    if (sequence == null) { throw new ArgumentNullException("sequence"); }

    return new AppendOneEnumerable<TElement>(sequence, next);
}
/// <summary>
/// As <see cref="Append{TElement}(IEnumerable{TElement}, IEnumerable{TElement})"/>, but prepends the elements
/// instead: the result yields <paramref name="first"/> and then <paramref name="second"/>.
/// </summary>
/// <param name="second">The sequence being prepended to; its elements are yielded last.</param>
/// <param name="first">The elements to yield before those of <paramref name="second"/>.</param>
/// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
public static IEnumerable<TElement> Prepend<TElement>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TElement> second, IEnumerable<TElement> first)
{
    if (first == null) { throw new ArgumentNullException("first"); }
    if (second == null) { throw new ArgumentNullException("second"); }

    return new AppendEnumerable<TElement>(first, second);
}
/// <summary>
/// As <see cref="Append{TElement}(IEnumerable{TElement}, TElement)"/>, but prepends an element instead:
/// the result yields <paramref name="previous"/> and then the elements of <paramref name="sequence"/>.
/// </summary>
/// <param name="sequence">The sequence being prepended to.</param>
/// <param name="previous">The single element to yield first (may be null).</param>
/// <exception cref="ArgumentNullException"><paramref name="sequence"/> is null.</exception>
public static IEnumerable<TElement> Prepend<TElement>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TElement> sequence, TElement previous)
{
    if (sequence == null) { throw new ArgumentNullException("sequence"); }

    return new PrependOneEnumerable<TElement>(previous, sequence);
}
/// <summary>
/// A node in the tree built by <c>Append</c>/<c>Prepend</c>. Each side of the node is either a sequence
/// (<see cref="PreviousElements"/> / <see cref="NextElements"/>, which may itself be another node) or, when
/// that sequence property is null, a single element (<see cref="PreviousElement"/> / <see cref="NextElement"/>).
/// </summary>
private interface IAppendEnumerable<out TElement>
{
    IEnumerable<TElement> PreviousElements { get; }
    TElement PreviousElement { get; }
    IEnumerable<TElement> NextElements { get; }
    TElement NextElement { get; }
}
/// <summary>
/// Shared enumeration logic for the append/prepend node types. Enumeration yields a node's previous side and then
/// its next side; nested nodes are walked iteratively with an explicit stack instead of recursion.
/// </summary>
private abstract class AppendEnumerableBase<TElement> : IAppendEnumerable<TElement>, IEnumerable<TElement>
{
    public abstract TElement NextElement { get; }
    public abstract IEnumerable<TElement> NextElements { get; }
    public abstract TElement PreviousElement { get; }
    public abstract IEnumerable<TElement> PreviousElements { get; }

    IEnumerator IEnumerable.GetEnumerator()
    {
        // AsEnumerable() types this instance as IEnumerable<TElement>, routing to the generic implementation below
        return this.AsEnumerable().GetEnumerator();
    }

    IEnumerator<TElement> IEnumerable<TElement>.GetEnumerator()
    {
        // we special case the basic case so that it doesn't even need to create the stack
        if (!(this.PreviousElements is IAppendEnumerable<TElement>)
            && !(this.NextElements is IAppendEnumerable<TElement>))
        {
            if (this.PreviousElements != null)
            {
                foreach (var element in this.PreviousElements)
                {
                    yield return element;
                }
            }
            else
            {
                yield return this.PreviousElement;
            }

            if (this.NextElements != null)
            {
                foreach (var element in this.NextElements)
                {
                    yield return element;
                }
            }
            else
            {
                yield return this.NextElement;
            }

            yield break;
        }

        // the algorithm here keeps 2 pieces of state:
        // (1) the current node in the append enumerable binary tree
        // (2) a stack of nodes we have to come back to to process the right subtree (nexts)
        // the steps are as follows, starting with current as the root of the tree:
        // (1) if the left subtree is a leaf, yield it
        // (2) otherwise, push current on the stack and set current = the left subtree
        // (3) if the right subtree is a leaf, yield it
        // (4) otherwise, set current = right subtree
        // (5) if both subtrees were leaves, set current = stack.Pop(), or exit if stack is empty
        // current == null means "resume with the right subtree of the node on top of the stack"
        IAppendEnumerable<TElement> currentAppendEnumerable = this;
        var enumerableStack = new Stack<IAppendEnumerable<TElement>>();
        while (true)
        {
            if (currentAppendEnumerable != null)
            {
                var previous = currentAppendEnumerable.PreviousElements;
                if (previous == null)
                {
                    yield return currentAppendEnumerable.PreviousElement;
                }
                else
                {
                    var previousAppendEnumerable = previous as IAppendEnumerable<TElement>;
                    if (previousAppendEnumerable != null)
                    {
                        // step (2): remember this node so we can process its right side once the left is done
                        enumerableStack.Push(currentAppendEnumerable);
                        currentAppendEnumerable = previousAppendEnumerable;
                        continue;
                    }

                    foreach (var previousElement in previous)
                    {
                        yield return previousElement;
                    }
                }
            }

            if (currentAppendEnumerable == null)
            {
                if (enumerableStack.Count == 0)
                {
                    yield break;
                }

                currentAppendEnumerable = enumerableStack.Pop();
            }

            var next = currentAppendEnumerable.NextElements;
            if (next == null)
            {
                yield return currentAppendEnumerable.NextElement;
            }
            else
            {
                var nextAppendEnumerable = next as IAppendEnumerable<TElement>;
                if (nextAppendEnumerable != null)
                {
                    // step (4): nothing remains for the current node, so it is not pushed
                    currentAppendEnumerable = nextAppendEnumerable;
                    continue;
                }

                foreach (var nextElement in next)
                {
                    yield return nextElement;
                }
            }

            // both sides of the current node are done
            currentAppendEnumerable = null;
        }
    }
}
/// <summary>
/// Node whose two sides are both sequences: yields all of <c>previous</c>, then all of <c>next</c>.
/// The single-element accessors are not supported because both sides are sequences.
/// </summary>
private sealed class AppendEnumerable<TElement> : AppendEnumerableBase<TElement>
{
    private readonly IEnumerable<TElement> previous, next;

    public AppendEnumerable(IEnumerable<TElement> previous, IEnumerable<TElement> next)
    {
        this.previous = previous;
        this.next = next;
    }

    public override TElement NextElement { get { throw new InvalidOperationException(); } }
    public override IEnumerable<TElement> NextElements { get { return this.next; } }
    public override TElement PreviousElement { get { throw new InvalidOperationException(); } }
    public override IEnumerable<TElement> PreviousElements { get { return this.previous; } }
}
/// <summary>
/// Node for appending a single element: yields all of <c>previous</c>, then the single <c>next</c> element.
/// <see cref="NextElements"/> is null to signal that the next side is a single element.
/// </summary>
private sealed class AppendOneEnumerable<TElement> : AppendEnumerableBase<TElement>
{
    private readonly IEnumerable<TElement> previous;
    private readonly TElement next;

    public AppendOneEnumerable(IEnumerable<TElement> previous, TElement next)
    {
        this.previous = previous;
        this.next = next;
    }

    public override TElement NextElement { get { return this.next; } }
    public override IEnumerable<TElement> NextElements { get { return null; } }
    public override TElement PreviousElement { get { throw new InvalidOperationException(); } }
    public override IEnumerable<TElement> PreviousElements { get { return this.previous; } }
}
/// <summary>
/// Node for prepending a single element: yields the single <c>previous</c> element, then all of <c>next</c>.
/// <see cref="PreviousElements"/> is null to signal that the previous side is a single element.
/// </summary>
private sealed class PrependOneEnumerable<TElement> : AppendEnumerableBase<TElement>
{
    private readonly TElement previous;
    private readonly IEnumerable<TElement> next;

    public PrependOneEnumerable(TElement previous, IEnumerable<TElement> next)
    {
        this.previous = previous;
        this.next = next;
    }

    public override TElement NextElement { get { throw new InvalidOperationException(); } }
    public override IEnumerable<TElement> NextElements { get { return this.next; } }
    public override TElement PreviousElement { get { return this.previous; } }
    public override IEnumerable<TElement> PreviousElements { get { return null; } }
}
#region ---- MaxBy / MinBy ----
/// <summary>
/// As <see cref="Enumerable.Max{TSource, TResult}(IEnumerable{TSource}, Func{TSource, TResult})"/>, but returns the
/// maximum item from the original sequence instead of the value projected by <paramref name="keySelector"/>. The
/// optional <paramref name="comparer"/> allows key comparisons to be specified.
/// </summary>
/// <remarks>
/// When several items share the maximum key, the first one is returned. Like Max, items whose key is null are
/// ignored (and null keys are never passed to <paramref name="comparer"/>); if every key is null, the first item
/// is returned.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
/// <exception cref="InvalidOperationException">
/// <paramref name="source"/> is empty and <typeparamref name="TSource"/> cannot be null.
/// </exception>
public static TSource MaxBy<TSource, TKey>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer = null)
{
    if (source == null) { throw new ArgumentNullException("source"); }
    if (keySelector == null) { throw new ArgumentNullException("keySelector"); }

    var cmp = comparer ?? Comparer<TKey>.Default;
    using (var enumerator = source.GetEnumerator())
    {
        if (!enumerator.MoveNext())
        {
            // just like native Min/Max, the empty sequence returns null for nullable types
            // and throws hard for non-nullable types
            if (default(TSource) == null) { return default(TSource); }
            throw new InvalidOperationException("Sequence contains no elements");
        }

        // null exclusion is about the KEYS, so it must depend on TKey rather than TSource
        // (e.g. an int source projected to int? keys must still ignore null keys)
        var keyIsNullable = default(TKey) == null;

        var bestValue = enumerator.Current;
        var bestKey = keySelector(bestValue);
        while (enumerator.MoveNext())
        {
            var value = enumerator.Current;
            var key = keySelector(value);
            if (keyIsNullable
                // like Min/Max, nulls are excluded from the comparison
                ? (key != null && (bestKey == null || cmp.Compare(key, bestKey) > 0))
                : cmp.Compare(key, bestKey) > 0)
            {
                bestValue = value;
                bestKey = key;
            }
        }

        return bestValue;
    }
}
/// <summary>
/// As <see cref="Enumerable.Min{TSource, TResult}(IEnumerable{TSource}, Func{TSource, TResult})"/>, but returns the
/// minimum item from the original sequence instead of the value projected by <paramref name="keySelector"/>. The
/// optional <paramref name="comparer"/> allows key comparisons to be specified.
/// </summary>
/// <remarks>
/// When several items share the minimum key, the first one is returned. Like Min, items whose key is null are
/// ignored (and null keys are never passed to <paramref name="comparer"/>); if every key is null, the first item
/// is returned.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
/// <exception cref="InvalidOperationException">
/// <paramref name="source"/> is empty and <typeparamref name="TSource"/> cannot be null.
/// </exception>
public static TSource MinBy<TSource, TKey>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey> comparer = null)
{
    if (source == null) { throw new ArgumentNullException("source"); }
    if (keySelector == null) { throw new ArgumentNullException("keySelector"); }

    var cmp = comparer ?? Comparer<TKey>.Default;
    using (var enumerator = source.GetEnumerator())
    {
        if (!enumerator.MoveNext())
        {
            // just like native Min/Max, the empty sequence returns null for nullable types
            // and throws hard for non-nullable types
            if (default(TSource) == null) { return default(TSource); }
            throw new InvalidOperationException("Sequence contains no elements");
        }

        // null exclusion is about the KEYS, so it must depend on TKey rather than TSource
        // (otherwise a null key would "win" a Min over an int source, since null sorts lowest)
        var keyIsNullable = default(TKey) == null;

        var bestValue = enumerator.Current;
        var bestKey = keySelector(bestValue);
        while (enumerator.MoveNext())
        {
            var value = enumerator.Current;
            var key = keySelector(value);
            if (keyIsNullable
                // like Min/Max, nulls are excluded from the comparison
                ? (key != null && (bestKey == null || cmp.Compare(key, bestKey) < 0))
                : cmp.Compare(key, bestKey) < 0)
            {
                bestValue = value;
                bestKey = key;
            }
        }

        return bestValue;
    }
}
#region ---- CollectionEquals ----
/// <summary>
/// Determines whether <paramref name="this"/> and <paramref name="that"/> are equal in the sense of having the exact same
/// elements. Unlike <see cref="Enumerable.SequenceEqual{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>,
/// this method disregards order. Unlike <see cref="ISet{T}.SetEquals(IEnumerable{T})"/>, this method does not disregard duplicates.
/// An optional <paramref name="comparer"/> allows the equality semantics for the elements to be specified.
/// </summary>
/// <param name="this">The first sequence; enumerated at most once.</param>
/// <param name="that">The second sequence; enumerated at most once.</param>
/// <param name="comparer">Element equality; defaults to <see cref="EqualityComparer{T}.Default"/>.</param>
/// <returns>True if every element occurs the same number of times in both sequences.</returns>
/// <exception cref="ArgumentNullException"><paramref name="this"/> or <paramref name="that"/> is null.</exception>
public static bool CollectionEquals<TElement>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IEnumerable<TElement> @this, IEnumerable<TElement> that, IEqualityComparer<TElement> comparer = null)
{
    if (@this == null) { throw new ArgumentNullException("this"); }
    if (that == null) { throw new ArgumentNullException("that"); }

    // FastCount optimization: If both of the collections are materialized and have counts,
    // we can exit very quickly if those counts differ
    int thisCount, thatCount;
    var hasThisCount = TryFastCount(@this, out thisCount);
    bool hasThatCount;
    if (hasThisCount)
    {
        hasThatCount = TryFastCount(that, out thatCount);
        if (hasThatCount)
        {
            if (thisCount != thatCount)
            {
                return false;
            }
            if (thisCount == 0)
            {
                return true;
            }
        }
    }
    else
    {
        hasThatCount = false;
    }

    var cmp = comparer ?? EqualityComparer<TElement>.Default;

    // number of elements consumed from EACH enumerator by the prefix scan below,
    // including the first mismatching pair
    var itemsEnumerated = 0;

    // SequenceEqual optimization: we reduce/avoid hashing when
    // the collections have common prefixes, at the cost of only one
    // extra Equals() call in the case where the prefixes are not common
    using (var thisEnumerator = @this.GetEnumerator())
    using (var thatEnumerator = that.GetEnumerator())
    {
        while (true)
        {
            var thisFinished = !thisEnumerator.MoveNext();
            var thatFinished = !thatEnumerator.MoveNext();

            if (thisFinished)
            {
                // either this shorter than that, or the two were sequence-equal
                return thatFinished;
            }
            if (thatFinished)
            {
                // that shorter than this
                return false;
            }

            // keep track of this so that we can factor it into count-based
            // logic below. This must happen before the Equals() check so that
            // the mismatching pair is counted too
            ++itemsEnumerated;

            if (!cmp.Equals(thisEnumerator.Current, thatEnumerator.Current))
            {
                break; // prefixes were not equal
            }
        }

        // both enumerators are now positioned on a mismatching pair, which the build and
        // probe phases below include (hence the do/while loops)

        // now, build a dictionary of item => count out of one collection and then
        // probe it with the other collection to look for mismatches

        // Build/Probe Choice optimization: if we know the count of one collection, we should
        // use the other collection to build the dictionary. That way we can bail immediately if
        // we see too few or too many items
        CountingSet<TElement> elementCounts;
        IEnumerator<TElement> probeSide;
        if (hasThisCount)
        {
            // we know this's count => use that as the build side
            probeSide = thisEnumerator;
            // elements of this still to come AFTER its current element
            var remaining = thisCount - itemsEnumerated;
            if (hasThatCount)
            {
                // if we have both counts, that means they must be equal or we would have already
                // exited. However, in this case, we know exactly the capacity needed for the dictionary
                // so we can avoid resizing: the current element plus the remaining ones
                elementCounts = new CountingSet<TElement>(capacity: remaining + 1, comparer: cmp);
                do
                {
                    elementCounts.Increment(thatEnumerator.Current);
                }
                while (thatEnumerator.MoveNext());
            }
            else
            {
                elementCounts = TryBuildElementCountsWithKnownCount(thatEnumerator, remaining, cmp);
            }
        }
        else if (TryFastCount(that, out thatCount))
        {
            // we know that's count => use this as the build side
            probeSide = thatEnumerator;
            var remaining = thatCount - itemsEnumerated;
            elementCounts = TryBuildElementCountsWithKnownCount(thisEnumerator, remaining, cmp);
        }
        else
        {
            // when we don't know either count, just use that as the build side arbitrarily
            probeSide = thisEnumerator;
            elementCounts = new CountingSet<TElement>(cmp);
            do
            {
                elementCounts.Increment(thatEnumerator.Current);
            }
            while (thatEnumerator.MoveNext());
        }

        // check whether we failed to construct a dictionary. This happens when we know
        // one of the counts and we detect, during construction, that the counts are unequal
        // (or that only the already-unequal current pair remains)
        if (elementCounts == null)
        {
            return false;
        }

        // probe the dictionary with the probe side enumerator, starting from its current element
        do
        {
            if (!elementCounts.TryDecrement(probeSide.Current))
            {
                // element in probe not in build => not equal
                return false;
            }
        }
        while (probeSide.MoveNext());

        // we are equal only if the loop above completely cleared out the dictionary
        return elementCounts.IsEmpty;
    }
}
/// <summary>
/// Constructs a count dictionary, staying mindful of the known number of elements
/// so that we bail early (returning null) if we detect a count mismatch.
/// </summary>
/// <param name="elements">An enumerator already positioned on a valid element; that element is counted.</param>
/// <param name="remaining">The exact number of elements expected to follow the current one.</param>
/// <param name="comparer">Equality semantics for the counting set.</param>
/// <returns>
/// The populated set, or null if <paramref name="elements"/> yields more or fewer than <paramref name="remaining"/>
/// further elements, or if <paramref name="remaining"/> is not positive.
/// </returns>
private static CountingSet<TKey> TryBuildElementCountsWithKnownCount<TKey>(
    IEnumerator<TKey> elements,
    int remaining,
    IEqualityComparer<TKey> comparer)
{
    // callers reach here after the current elements of both sides compared unequal; if nothing is
    // expected after them, the sequences cannot be equal, so don't build the dictionary at all.
    // A negative value (a collection whose Count under-reports its contents) is a mismatch too
    if (remaining <= 0)
    {
        return null;
    }

    const int MaxInitialElementCountsCapacity = 1024;
    var elementCounts = new CountingSet<TKey>(capacity: Math.Min(remaining, MaxInitialElementCountsCapacity), comparer: comparer);
    elementCounts.Increment(elements.Current);
    while (elements.MoveNext())
    {
        if (--remaining < 0)
        {
            // too many elements
            return null;
        }

        elementCounts.Increment(elements.Current);
    }

    if (remaining > 0)
    {
        // too few elements
        return null;
    }

    return elementCounts;
}
/// <summary>
/// Key Lookup Reduction optimization: this custom data structure halves the number of <see cref="IEqualityComparer{T}.GetHashCode(T)"/>
/// and <see cref="IEqualityComparer{T}.Equals(T, T)"/> operations by building in the increment/decrement operations of a counting dictionary.
/// This also solves <see cref="Dictionary{TKey, TValue}"/>'s issues with null keys.
/// </summary>
/// <remarks>
/// Open-addressing hash table with linear probing. Buckets are never released: a bucket whose count drops to
/// zero stays claimed so that probe chains passing through it remain intact.
/// </remarks>
private sealed class CountingSet<T>
{
    // picked based on observing unit test performance
    private const double MaxLoad = .62;

    private readonly IEqualityComparer<T> comparer;
    private Bucket[] buckets;

    /// <summary>
    /// Number of claimed buckets (non-zero hash code). Buckets are never released, so this drives resizing
    /// </summary>
    private int occupiedBucketCount;

    /// <summary>
    /// Number of buckets whose count is currently positive. This drives <see cref="IsEmpty"/>
    /// </summary>
    private int populatedBucketCount;

    /// <summary>
    /// When <see cref="occupiedBucketCount"/> reaches this count, we need to resize
    /// </summary>
    private int nextResizeCount;

    public CountingSet(IEqualityComparer<T> comparer, int capacity = 0)
    {
        this.comparer = comparer;
        // we pick the initial length by assuming our current table is one short of the desired
        // capacity and then using our standard logic of picking the next valid table size
        this.buckets = new Bucket[GetNextTableSize((int)(capacity / MaxLoad) - 1)];
        this.nextResizeCount = this.CalculateNextResizeCount();
    }

    /// <summary>True when no item currently has a positive count.</summary>
    public bool IsEmpty { get { return this.populatedBucketCount == 0; } }

    /// <summary>Adds one to the count for <paramref name="item"/> (which may be null).</summary>
    public void Increment(T item)
    {
        int bucketIndex;
        uint hashCode;
        if (this.TryFindBucket(item, out bucketIndex, out hashCode))
        {
            // if a bucket already existed, just update its count. A bucket whose count had
            // previously been decremented to zero becomes populated again
            if (this.buckets[bucketIndex].Count++ == 0)
            {
                ++this.populatedBucketCount;
            }
            return;
        }

        // otherwise, claim a new bucket
        this.buckets[bucketIndex].HashCode = hashCode;
        this.buckets[bucketIndex].Value = item;
        this.buckets[bucketIndex].Count = 1;
        ++this.occupiedBucketCount;
        ++this.populatedBucketCount;

        // resize the table if we've grown too full
        if (this.occupiedBucketCount == this.nextResizeCount)
        {
            this.Resize();
        }
    }

    /// <summary>
    /// Subtracts one from the count for <paramref name="item"/>. Returns false, changing nothing, when the item
    /// was never incremented or its count is already zero.
    /// </summary>
    public bool TryDecrement(T item)
    {
        int bucketIndex;
        uint ignored;
        if (!this.TryFindBucket(item, out bucketIndex, out ignored)
            || this.buckets[bucketIndex].Count <= 0)
        {
            return false;
        }

        if (--this.buckets[bucketIndex].Count == 0)
        {
            // we deliberately leave the bucket claimed (HashCode != 0): marking it unoccupied would
            // break the probe chains TryFindBucket walks for items stored further along
            --this.populatedBucketCount;
        }

        return true;
    }

    /// <summary>Moves every claimed bucket into a table of the next larger size.</summary>
    private void Resize()
    {
        var newBuckets = new Bucket[GetNextTableSize(this.buckets.Length)];

        // rehash
        for (var i = 0; i < this.buckets.Length; ++i)
        {
            var oldBucket = this.buckets[i];
            if (oldBucket.HashCode != 0)
            {
                // linear probing from the bucket's preferred slot in the new table
                var newBucketIndex = (int)(oldBucket.HashCode % newBuckets.Length);
                while (newBuckets[newBucketIndex].HashCode != 0)
                {
                    newBucketIndex = (newBucketIndex + 1) % newBuckets.Length;
                }

                newBuckets[newBucketIndex] = oldBucket;
            }
        }

        this.buckets = newBuckets;
        this.nextResizeCount = this.CalculateNextResizeCount();
    }

    /// <summary>
    /// Probes for <paramref name="item"/>. Returns true with <paramref name="index"/> at the matching bucket, or false
    /// with <paramref name="index"/> at the first unoccupied bucket on the probe path (where the item belongs).
    /// <paramref name="hashCode"/> always receives the item's adjusted (never zero) hash code.
    /// </summary>
    private bool TryFindBucket(T item, out int index, out uint hashCode)
    {
        // we convert the raw hash code to a uint to get correctly-signed mod operations
        // and get rid of the zero value so that we can use 0 to mean "unoccupied"
        var rawHashCode = this.comparer.GetHashCode(item);
        hashCode = rawHashCode == 0 ? uint.MaxValue : unchecked((uint)rawHashCode);

        var bucketIndex = (int)(hashCode % this.buckets.Length);
        while (true) // guaranteed to terminate because the load factor keeps some buckets unoccupied
        {
            var bucket = this.buckets[bucketIndex];
            if (bucket.HashCode == 0)
            {
                // found unoccupied bucket
                index = bucketIndex;
                return false;
            }
            if (bucket.HashCode == hashCode && this.comparer.Equals(bucket.Value, item))
            {
                // found matching bucket
                index = bucketIndex;
                return true;
            }

            // otherwise march on to the next adjacent bucket
            bucketIndex = (bucketIndex + 1) % this.buckets.Length;
        }
    }

    private int CalculateNextResizeCount()
    {
        return (int)(MaxLoad * this.buckets.Length) + 1;
    }

    private static readonly int[] HashTableSizes = new[]
    {
        // hash table primes, each roughly double the previous
        23, 53, 97, 193, 389, 769, 1543, 3079, 6151, 12289,
        24593, 49157, 98317, 196613, 393241, 786433, 1572869,
        3145739, 6291469, 12582917, 25165843, 50331653, 100663319,
        201326611, 402653189, 805306457, 1610612741,
        // the final values are (1) a prime roughly half way between the previous value and int.MaxValue,
        // (2) the prime closest to, but not above, int.MaxValue, and (3) int.MaxValue itself, the maximum size
        1879048201, 2147483629, int.MaxValue
    };

    /// <summary>Returns the smallest supported table size strictly greater than <paramref name="currentSize"/>.</summary>
    private static int GetNextTableSize(int currentSize)
    {
        for (var i = 0; i < HashTableSizes.Length; ++i)
        {
            var nextSize = HashTableSizes[i];
            if (nextSize > currentSize) { return nextSize; }
        }

        throw new InvalidOperationException("Hash table cannot expand further");
    }

    [DebuggerDisplay("{Value}, {Count}, {HashCode}")]
    private struct Bucket
    {
        // note: 0 (default) means the bucket is unoccupied
        internal uint HashCode;
        internal T Value;
        internal int Count;
    }
}
#region ---- GetOrAdd ----
/// <summary>
/// If <paramref name="key"/> exists in <paramref name="dictionary"/>, returns the associated value. Otherwise,
/// generates a new value by applying <paramref name="valueFactory"/> to the given <paramref name="key"/>. The
/// new value is stored in <paramref name="dictionary"/> and returned.
/// </summary>
/// <remarks>
/// Not atomic: the lookup and the insertion are separate calls on <paramref name="dictionary"/>, and
/// <paramref name="valueFactory"/> is only invoked when the key is absent.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="dictionary"/> or <paramref name="valueFactory"/> is null.</exception>
public static TValue GetOrAdd<TKey, TValue>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IDictionary<TKey, TValue> dictionary, TKey key, Func<TKey, TValue> valueFactory)
{
    if (dictionary == null) { throw new ArgumentNullException("dictionary"); }
    if (valueFactory == null) { throw new ArgumentNullException("valueFactory"); }

    TValue existing;
    if (dictionary.TryGetValue(key, out existing))
    {
        return existing;
    }

    var value = valueFactory(key);
    dictionary.Add(key, value);
    return value;
}
/// <summary>
/// Gets the number of elements in <paramref name="this"/> without enumerating it, when the sequence exposes a
/// count through <see cref="ICollection{T}"/> or <see cref="IReadOnlyCollection{T}"/>.
/// </summary>
/// <returns>True with the count if one is available; otherwise false with <paramref name="count"/> set to -1.</returns>
private static bool TryFastCount<T>(IEnumerable<T> @this, out int count)
{
    var collection = @this as ICollection<T>;
    if (collection != null)
    {
        count = collection.Count;
        return true;
    }

    var readOnlyCollection = @this as IReadOnlyCollection<T>;
    if (readOnlyCollection != null)
    {
        count = readOnlyCollection.Count;
        return true;
    }

    count = -1;
    return false;
}
#if MedallionCollections_USE_LOCAL_NAMESPACE
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Medallion.Tools.InlineNuGet", ""), global::System.Diagnostics.DebuggerNonUserCodeAttribute]
#if MedallionCollections_PUBLIC
static partial class Comparers
#region ---- Key Comparer ----
/// <summary>
/// Creates a <see cref="Comparer{T}"/> which compares values of type <typeparamref name="T"/> by
/// projecting them to type <typeparamref name="TKey"/> using the given <paramref name="keySelector"/>.
/// The optional <paramref name="keyComparer"/> determines how keys are compared; when omitted, the default
/// comparer for <typeparamref name="TKey"/> is used.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is null.</exception>
public static Comparer<T> Create<T, TKey>(Func<T, TKey> keySelector, IComparer<TKey> keyComparer = null)
{
    if (keySelector == null) { throw new ArgumentNullException("keySelector"); }

    return new KeyComparer<T, TKey>(keySelector, keyComparer ?? Comparer<TKey>.Default);
}
/// <summary>
/// Compares values by the keys projected from them. Null values sort before all non-null values and are never
/// passed to the key selector. Two instances are equal when their selector and key comparer are equal.
/// </summary>
private sealed class KeyComparer<T, TKey> : Comparer<T>
{
    private readonly Func<T, TKey> keySelector;
    private readonly IComparer<TKey> keyComparer;

    public KeyComparer(Func<T, TKey> keySelector, IComparer<TKey> keyComparer)
    {
        this.keySelector = keySelector;
        this.keyComparer = keyComparer;
    }

    public override int Compare(T x, T y)
    {
        // from Comparer<T>.Compare(object, object)
        if (x == null)
        {
            return y == null ? 0 : -1;
        }
        if (y == null)
        {
            return 1;
        }

        return this.keyComparer.Compare(this.keySelector(x), this.keySelector(y));
    }

    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }

        var that = obj as KeyComparer<T, TKey>;
        return that != null
            && that.keySelector.Equals(this.keySelector)
            && that.keyComparer.Equals(this.keyComparer);
    }

    public override int GetHashCode()
    {
        return unchecked((3 * this.keySelector.GetHashCode()) + this.keyComparer.GetHashCode());
    }
}
#region ---- Reverse ----
/// <summary>
/// Gets an <see cref="IComparer{T}"/> which represents the reverse of
/// the order implied by <see cref="Comparer{T}.Default"/>. The same cached instance is returned on every call.
/// </summary>
public static IComparer<T> Reverse<T>()
{
    return ReverseComparer<T>.Default;
}
/// <summary>
/// Gets an <see cref="IComparer{T}"/> which represents the reverse of
/// the order implied by the given <paramref name="comparer"/>.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="comparer"/> is null.</exception>
public static IComparer<T> Reverse<T>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IComparer<T> comparer)
{
    if (comparer == null) { throw new ArgumentNullException("comparer"); }

    // reuse the cached instance when reversing the default comparer
    return comparer == Comparer<T>.Default
        ? Reverse<T>()
        : new ReverseComparer<T>(comparer);
}
/// <summary>
/// Orders values in the opposite order of the wrapped comparer.
/// </summary>
/// <remarks>
/// We implement the interfaces directly instead of deriving from <see cref="Comparer{T}"/> because that doesn't
/// let us override the comparison of nulls in the Compare(object, object) method; here nulls are not special-cased
/// and are handed to the wrapped comparer with the arguments swapped.
/// </remarks>
private sealed class ReverseComparer<T> : IComparer<T>, IComparer
{
    public static readonly ReverseComparer<T> Default = new ReverseComparer<T>(Comparer<T>.Default);

    private readonly IComparer<T> comparer;

    public ReverseComparer(IComparer<T> comparer)
    {
        this.comparer = comparer;
    }

    public int Compare(T x, T y)
    {
        return this.comparer.Compare(y, x);
    }

    int IComparer.Compare(object x, object y)
    {
        return this.Compare((T)x, (T)y);
    }

    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }

        var that = obj as ReverseComparer<T>;
        return that != null && that.comparer.Equals(this.comparer);
    }

    public override int GetHashCode()
    {
        // derived only from the wrapped comparer (which is all Equals() looks at) so that equal instances,
        // including one built explicitly around Comparer<T>.Default, always hash alike
        return unchecked((3 * typeof(ReverseComparer<T>).GetHashCode()) + this.comparer.GetHashCode());
    }
}
#region ---- ThenBy ----
/// <summary>
/// Gets a <see cref="Comparer{T}"/> which compares using <paramref name="first"/>
/// and breaks ties with <paramref name="second"/>.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
public static Comparer<T> ThenBy<T>(
#if !MedallionCollections_DISABLE_EXTENSIONS
    this
#endif
    IComparer<T> first, IComparer<T> second)
{
    if (first == null) { throw new ArgumentNullException("first"); }
    if (second == null) { throw new ArgumentNullException("second"); }

    return new ThenByComparer<T>(first, second);
}
/// <summary>
/// Compares with <c>first</c>, consulting <c>second</c> only when <c>first</c> reports a tie.
/// Two instances are equal when both component comparers are equal.
/// </summary>
private sealed class ThenByComparer<T> : Comparer<T>
{
    private readonly IComparer<T> first, second;

    public ThenByComparer(IComparer<T> first, IComparer<T> second)
    {
        this.first = first;
        this.second = second;
    }

    public override int Compare(T x, T y)
    {
        var firstComparison = this.first.Compare(x, y);
        return firstComparison != 0 ? firstComparison : this.second.Compare(x, y);
    }

    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }

        var that = obj as ThenByComparer<T>;
        return that != null
            && that.first.Equals(this.first)
            && that.second.Equals(this.second);
    }

    public override int GetHashCode()
    {
        return unchecked((3 * this.first.GetHashCode()) + this.second.GetHashCode());
    }
}
#region ---- Sequence Comparer ----
/// <summary>
/// Gets a <see cref="Comparer{T}"/> which sorts sequences lexicographically. The optional
/// <paramref name="elementComparer"/> can be used to override comparisons of individual elements.
/// </summary>
/// <remarks>
/// A null sequence sorts before any non-null sequence, and a sequence sorts before any longer sequence of which
/// it is a prefix. When <paramref name="elementComparer"/> is omitted or is <see cref="Comparer{T}.Default"/>,
/// a shared cached instance is returned.
/// </remarks>
public static Comparer<IEnumerable<T>> GetSequenceComparer<T>(IComparer<T> elementComparer = null)
{
    return elementComparer == null || elementComparer == Comparer<T>.Default
        ? SequenceComparer<T>.DefaultInstance
        : new SequenceComparer<T>(elementComparer);
}
/// <summary>
/// Lexicographic comparer for sequences: compares element by element with the element comparer, and a
/// sequence that runs out first sorts first. Null sequences sort before non-null ones.
/// </summary>
private sealed class SequenceComparer<T> : Comparer<IEnumerable<T>>
{
    // eagerly initialized so that exactly one default instance ever exists; the previous lazy
    // "?? (x = new ...)" initialization could create several when first used from multiple threads
    private static readonly Comparer<IEnumerable<T>> defaultInstance = new SequenceComparer<T>(Comparer<T>.Default);

    /// <summary>The shared instance that compares elements with <see cref="Comparer{T}.Default"/>.</summary>
    public static Comparer<IEnumerable<T>> DefaultInstance
    {
        get { return defaultInstance; }
    }

    private readonly IComparer<T> elementComparer;

    public SequenceComparer(IComparer<T> elementComparer)
    {
        this.elementComparer = elementComparer;
    }

    public override int Compare(IEnumerable<T> x, IEnumerable<T> y)
    {
        // from Comparer<T>.Compare(object, object)
        if (x == null)
        {
            return y == null ? 0 : -1;
        }
        if (y == null)
        {
            return 1;
        }
        if (ReferenceEquals(x, y))
        {
            return 0;
        }

        using (var xEnumerator = x.GetEnumerator())
        using (var yEnumerator = y.GetEnumerator())
        {
            while (true)
            {
                var xHasMore = xEnumerator.MoveNext();
                var yHasMore = yEnumerator.MoveNext();

                if (!xHasMore)
                {
                    return yHasMore ? -1 : 0;
                }
                if (!yHasMore)
                {
                    return 1;
                }

                var cmp = this.elementComparer.Compare(xEnumerator.Current, yEnumerator.Current);
                if (cmp != 0)
                {
                    return cmp;
                }
            }
        }
    }

    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }

        var that = obj as SequenceComparer<T>;
        return that != null && that.elementComparer.Equals(this.elementComparer);
    }

    public override int GetHashCode()
    {
        // derived only from the element comparer (which is all Equals() looks at) so that equal instances,
        // including the default instance, always hash alike
        return unchecked((3 * typeof(SequenceComparer<T>).GetHashCode()) + this.elementComparer.GetHashCode());
    }
}
#if MedallionCollections_USE_LOCAL_NAMESPACE
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Medallion.Tools.InlineNuGet", ""), global::System.Diagnostics.DebuggerNonUserCodeAttribute]
#if MedallionCollections_PUBLIC
static partial class Empty
/// <summary>A cached, empty instance of <see cref="IEnumerable"/></summary>
public static IEnumerable ObjectEnumerable
{
    get { return EmptyCollection<object>.Instance; }
}

/// <summary>A cached, empty, readonly instance of <see cref="ICollection"/></summary>
public static ICollection ObjectCollection
{
    get { return EmptyCollection<object>.Instance; }
}

/// <summary>A cached, empty, readonly instance of <see cref="IList"/></summary>
public static IList ObjectList
{
    get { return EmptyCollection<object>.Instance; }
}

/// <summary>A cached, empty, readonly instance of <see cref="IDictionary"/></summary>
public static IDictionary ObjectDictionary
{
    get { return EmptyDictionary<object, object>.Instance; }
}
/// <summary>A cached, empty instance of <see cref="IEnumerable{T}"/></summary>
public static IEnumerable<T> Enumerable<T>()
{
    return EmptyCollection<T>.Instance;
}

/// <summary>A cached, empty, readonly instance of <see cref="ICollection{T}"/></summary>
public static ICollection<T> Collection<T>()
{
    return EmptyCollection<T>.Instance;
}

/// <summary>A cached, empty instance of <see cref="IReadOnlyCollection{T}"/></summary>
public static IReadOnlyCollection<T> ReadOnlyCollection<T>()
{
    return EmptyCollection<T>.Instance;
}

/// <summary>A cached, empty array of <typeparamref name="T"/></summary>
public static T[] Array<T>()
{
    return EmptyArray<T>.Instance;
}

/// <summary>A cached, empty, readonly instance of <see cref="IList{T}"/></summary>
public static IList<T> List<T>()
{
    return EmptyCollection<T>.Instance;
}

/// <summary>A cached, empty instance of <see cref="IReadOnlyList{T}"/></summary>
public static IReadOnlyList<T> ReadOnlyList<T>()
{
    return EmptyCollection<T>.Instance;
}

/// <summary>A cached, empty, readonly instance of <see cref="ISet{T}"/></summary>
public static ISet<T> Set<T>()
{
    return EmptyCollection<T>.Instance;
}

/// <summary>A cached, empty, readonly instance of <see cref="IDictionary{TKey, TValue}"/></summary>
public static IDictionary<TKey, TValue> Dictionary<TKey, TValue>()
{
    return EmptyDictionary<TKey, TValue>.Instance;
}

/// <summary>A cached, empty instance of <see cref="IReadOnlyDictionary{TKey, TValue}"/></summary>
public static IReadOnlyDictionary<TKey, TValue> ReadOnlyDictionary<TKey, TValue>()
{
    return EmptyDictionary<TKey, TValue>.Instance;
}
#region ---- Empty Array ----
/// <summary>Holds one shared zero-length array per element type.</summary>
private static class EmptyArray<TElement>
{
    // reuse the array returned by Enumerable.Empty() when it happens to be one (so we share a single
    // instance with LINQ), but fall back to our own zero-length array rather than depending on that detail
    public static readonly TElement[] Instance = (System.Linq.Enumerable.Empty<TElement>() as TElement[]) ?? new TElement[0];
}
#region ---- Empty Collection ----
/// <summary>
/// A stateless, read-only empty collection. The single cached <see cref="Instance"/> implements the generic
/// list, read-only list and set interfaces as well as the non-generic <see cref="IList"/>. It also serves as
/// its own always-exhausted enumerator (both <c>GetEnumerator</c> implementations return <c>this</c>).
/// Mutating members throw via <c>ThrowReadOnly</c>; indexing throws <see cref="ArgumentOutOfRangeException"/>.
/// </summary>
private class EmptyCollection<TElement> : IList<TElement>, IReadOnlyList<TElement>, ISet<TElement>, IEnumerator<TElement>, IList
{
    public static readonly EmptyCollection<TElement> Instance = new EmptyCollection<TElement>();

    // protected rather than private so that a derived type (EmptyDictionary) can reuse this implementation
    protected EmptyCollection() { }

    TElement IReadOnlyList<TElement>.this[int index]
    {
        get { throw ThrowCannotIndex(); }
    }

    object IList.this[int index]
    {
        get { throw ThrowCannotIndex(); }
        set { throw ThrowReadOnly(); }
    }

    TElement IList<TElement>.this[int index]
    {
        get { throw ThrowCannotIndex(); }
        set { throw ThrowReadOnly(); }
    }

    int IReadOnlyCollection<TElement>.Count
    {
        get { return 0; }
    }

    int ICollection.Count
    {
        get { return 0; }
    }

    int ICollection<TElement>.Count
    {
        get { return 0; }
    }

    object IEnumerator.Current
    {
        // based on ((IEnumerator)new List<int>().GetEnumerator()).Current
        get { throw new InvalidOperationException("Enumeration has either not started or has already finished"); }
    }

    // based on new List<int>().GetEnumerator().Current
    TElement IEnumerator<TElement>.Current
    {
        get { return default(TElement); }
    }

    bool IList.IsFixedSize
    {
        get { return true; }
    }

    bool IList.IsReadOnly
    {
        get { return true; }
    }

    bool ICollection<TElement>.IsReadOnly
    {
        get { return true; }
    }

    bool ICollection.IsSynchronized
    {
        get { return false; }
    }

    object ICollection.SyncRoot
    {
        get { return this; }
    }

    int IList.Add(object value)
    {
        throw ThrowReadOnly();
    }

    bool ISet<TElement>.Add(TElement item)
    {
        throw ThrowReadOnly();
    }

    void ICollection<TElement>.Add(TElement item)
    {
        throw ThrowReadOnly();
    }

    void IList.Clear()
    {
        throw ThrowReadOnly();
    }

    void ICollection<TElement>.Clear()
    {
        throw ThrowReadOnly();
    }

    bool IList.Contains(object value)
    {
        return false;
    }

    bool ICollection<TElement>.Contains(TElement item)
    {
        return false;
    }

    // there is nothing to copy; only the arguments are validated
    void ICollection.CopyTo(Array array, int index)
    {
        if (array == null) { throw new ArgumentNullException("array"); }
        if (index < 0 || index > array.Length) { throw new ArgumentOutOfRangeException("index"); }
    }

    // there is nothing to copy; only the arguments are validated
    void ICollection<TElement>.CopyTo(TElement[] array, int arrayIndex)
    {
        if (array == null) { throw new ArgumentNullException("array"); }
        if (arrayIndex < 0 || arrayIndex > array.Length) { throw new ArgumentOutOfRangeException("arrayIndex"); }
    }

    // the enumerator holds no state, so there is nothing to release
    void IDisposable.Dispose()
    {
    }

    void ISet<TElement>.ExceptWith(IEnumerable<TElement> other)
    {
        throw ThrowReadOnly();
    }

    // the collection acts as its own (stateless, always-exhausted) enumerator
    IEnumerator IEnumerable.GetEnumerator()
    {
        return this;
    }

    IEnumerator<TElement> IEnumerable<TElement>.GetEnumerator()
    {
        return this;
    }

    int IList.IndexOf(object value)
    {
        return -1;
    }

    int IList<TElement>.IndexOf(TElement item)
    {
        return -1;
    }

    void IList.Insert(int index, object value)
    {
        throw ThrowReadOnly();
    }

    void IList<TElement>.Insert(int index, TElement item)
    {
        throw ThrowReadOnly();
    }

    void ISet<TElement>.IntersectWith(IEnumerable<TElement> other)
    {
        throw ThrowReadOnly();
    }

    // the empty set is a proper subset of every non-empty set
    bool ISet<TElement>.IsProperSubsetOf(IEnumerable<TElement> other)
    {
        if (other == null) { throw new ArgumentNullException("other"); }
        return other.Any();
    }

    bool ISet<TElement>.IsProperSupersetOf(IEnumerable<TElement> other)
    {
        if (other == null) { throw new ArgumentNullException("other"); }
        return false;
    }

    bool ISet<TElement>.IsSubsetOf(IEnumerable<TElement> other)
    {
        if (other == null) { throw new ArgumentNullException("other"); }
        return true;
    }

    // the empty set is a superset only of the empty set
    bool ISet<TElement>.IsSupersetOf(IEnumerable<TElement> other)
    {
        if (other == null) { throw new ArgumentNullException("other"); }
        return !other.Any();
    }

    bool IEnumerator.MoveNext()
    {
        return false;
    }

    bool ISet<TElement>.Overlaps(IEnumerable<TElement> other)
    {
        if (other == null) { throw new ArgumentNullException("other"); }
        return false;
    }

    void IList.Remove(object value)
    {
        throw ThrowReadOnly();
    }

    bool ICollection<TElement>.Remove(TElement item)
    {
        throw ThrowReadOnly();
    }

    void IList.RemoveAt(int index)
    {
        throw ThrowReadOnly();
    }

    void IList<TElement>.RemoveAt(int index)
    {
        throw ThrowReadOnly();
    }

    // the enumerator holds no position, so resetting is a no-op
    void IEnumerator.Reset()
    {
    }

    bool ISet<TElement>.SetEquals(IEnumerable<TElement> other)
    {
        if (other == null) { throw new ArgumentNullException("other"); }
        return !other.Any();
    }

    void ISet<TElement>.SymmetricExceptWith(IEnumerable<TElement> other)
    {
        throw ThrowReadOnly();
    }

    void ISet<TElement>.UnionWith(IEnumerable<TElement> other)
    {
        throw ThrowReadOnly();
    }

    /// <summary>
    /// Throws the exception used by every indexer getter. The <see cref="Exception"/> return type lets
    /// call sites write <c>throw ThrowCannotIndex();</c>; the method never actually returns.
    /// </summary>
    private static Exception ThrowCannotIndex()
    {
        // the single-string ArgumentOutOfRangeException constructor takes a parameter *name*, not a message,
        // so pass both explicitly; "index" matches the indexers' parameter name
        throw new ArgumentOutOfRangeException("index", "Cannot index into an empty collection");
    }
}
#region ---- Empty Dictionary ----
/// <summary>
/// The read-only empty dictionary. It inherits all list/set/collection behavior from
/// <see cref="EmptyCollection{TElement}"/> over <see cref="KeyValuePair{TKey, TValue}"/> and adds the
/// read-only, generic and non-generic dictionary interfaces. Like its base class it is its own enumerator.
/// Here it also acts as an <see cref="IDictionaryEnumerator"/> that never has a current entry.
/// </summary>
private sealed class EmptyDictionary<TKey, TValue> : EmptyCollection<KeyValuePair<TKey, TValue>>, IReadOnlyDictionary<TKey, TValue>, IDictionary<TKey, TValue>, IDictionary, IDictionaryEnumerator
{
    public static new readonly EmptyDictionary<TKey, TValue> Instance = new EmptyDictionary<TKey, TValue>();

    private EmptyDictionary() { }

    // ---- IReadOnlyDictionary<TKey, TValue> ----

    TValue IReadOnlyDictionary<TKey, TValue>.this[TKey key]
    {
        get { throw new KeyNotFoundException(); }
    }

    IEnumerable<TKey> IReadOnlyDictionary<TKey, TValue>.Keys
    {
        get { return Collection<TKey>(); }
    }

    IEnumerable<TValue> IReadOnlyDictionary<TKey, TValue>.Values
    {
        get { return ReadOnlyCollection<TValue>(); }
    }

    bool IReadOnlyDictionary<TKey, TValue>.ContainsKey(TKey key)
    {
        return false;
    }

    bool IReadOnlyDictionary<TKey, TValue>.TryGetValue(TKey key, out TValue value)
    {
        value = default(TValue);
        return false;
    }

    // ---- IDictionary<TKey, TValue> ----

    TValue IDictionary<TKey, TValue>.this[TKey key]
    {
        get { throw new KeyNotFoundException(); }
        set { throw ThrowReadOnly(); }
    }

    ICollection<TKey> IDictionary<TKey, TValue>.Keys
    {
        get { return Collection<TKey>(); }
    }

    ICollection<TValue> IDictionary<TKey, TValue>.Values
    {
        get { return Collection<TValue>(); }
    }

    void IDictionary<TKey, TValue>.Add(TKey key, TValue value)
    {
        throw ThrowReadOnly();
    }

    bool IDictionary<TKey, TValue>.ContainsKey(TKey key)
    {
        return false;
    }

    bool IDictionary<TKey, TValue>.Remove(TKey key)
    {
        throw ThrowReadOnly();
    }

    bool IDictionary<TKey, TValue>.TryGetValue(TKey key, out TValue value)
    {
        value = default(TValue);
        return false;
    }

    // ---- IDictionary (non-generic) ----

    // the non-generic indexer reports a missing key as null rather than throwing
    object IDictionary.this[object key]
    {
        get { return null; }
        set { throw ThrowReadOnly(); }
    }

    bool IDictionary.IsFixedSize
    {
        get { return true; }
    }

    bool IDictionary.IsReadOnly
    {
        get { return true; }
    }

    ICollection IDictionary.Keys
    {
        get { return ObjectCollection; }
    }

    ICollection IDictionary.Values
    {
        get { return ObjectCollection; }
    }

    void IDictionary.Add(object key, object value)
    {
        throw ThrowReadOnly();
    }

    void IDictionary.Clear()
    {
        throw ThrowReadOnly();
    }

    bool IDictionary.Contains(object key)
    {
        return false;
    }

    IDictionaryEnumerator IDictionary.GetEnumerator()
    {
        return this;
    }

    void IDictionary.Remove(object key)
    {
        throw ThrowReadOnly();
    }

    // ---- IDictionaryEnumerator: there is never a current entry ----

    object IDictionaryEnumerator.Key
    {
        get { throw new InvalidOperationException("Enumeration has either not started or has already finished"); }
    }

    object IDictionaryEnumerator.Value
    {
        get { throw new InvalidOperationException("Enumeration has either not started or has already finished"); }
    }

    DictionaryEntry IDictionaryEnumerator.Entry
    {
        get { throw new InvalidOperationException("Enumeration has either not started or has already finished"); }
    }
}
/// <summary>
/// Throws the <see cref="NotSupportedException"/> raised by every mutating member of the empty collections.
/// </summary>
/// <param name="memberName">supplied by the compiler with the calling member's name; it prefixes the message</param>
/// <returns>
/// Never returns normally. The <see cref="Exception"/> return type lets call sites write
/// <c>throw ThrowReadOnly();</c> so the compiler treats the call as a terminating path.
/// </returns>
private static Exception ThrowReadOnly([CallerMemberName] string memberName = null)
{
    var message = memberName + ": the collection is read-only";
    throw new NotSupportedException(message);
}
#if MedallionCollections_USE_LOCAL_NAMESPACE
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Medallion.Tools.InlineNuGet", ""), global::System.Diagnostics.DebuggerNonUserCodeAttribute]
#if MedallionCollections_PUBLIC
static partial class EqualityComparers
#region ---- Func Comparer ----
/// <summary>
/// Creates an <see cref="EqualityComparer{T}"/> using the given <paramref name="equals"/> function
/// for equality and the optional <paramref name="hash"/> function for hashing. If <paramref name="hash"/> is not
/// provided, every non-null value hashes to the same constant. That is always correct, but hash-based lookups then
/// have to compare against every element. Note that null values are handled directly by the comparer (null equals
/// only null, and hashes to 0) and will not be passed to these functions
/// </summary>
/// <param name="equals">the equality function applied to pairs of non-null values; must not be null</param>
/// <param name="hash">an optional hash function for non-null values; it must be consistent with <paramref name="equals"/></param>
/// <returns>a new comparer wrapping the given functions</returns>
/// <exception cref="ArgumentNullException">if <paramref name="equals"/> is null</exception>
public static EqualityComparer<T> Create<T>(Func<T, T, bool> equals, Func<T, int> hash = null)
{
    if (equals == null) { throw new ArgumentNullException("equals"); }

    return new FuncEqualityComparer<T>(equals, hash);
}
/// <summary>
/// An <see cref="EqualityComparer{T}"/> built from caller-supplied delegates. Nulls are resolved here
/// (null equals only null and hashes to 0), so the delegates only ever receive non-null values.
/// </summary>
private sealed class FuncEqualityComparer<T> : EqualityComparer<T>
{
    // used when no hash delegate is supplied: every non-null value gets the same hash code
    private static readonly Func<T, int> DefaultHash = _ => -1;

    private readonly Func<T, T, bool> equals;
    private readonly Func<T, int> hash;

    public FuncEqualityComparer(Func<T, T, bool> equals, Func<T, int> hash)
    {
        this.equals = equals;
        this.hash = hash ?? DefaultHash;
    }

    public override bool Equals(T x, T y)
    {
        // null checks consistent with Equals(object, object).
        // NOTE(review): for value types these comparisons are always false; the IL boxes, but the JIT is
        // believed to elide that box for unconstrained generics. Confirm if this ever shows up as hot.
        if (x == null) { return y == null; }
        if (y == null) { return false; }
        return this.equals(x, y);
    }

    public override int GetHashCode(T obj)
    {
        // nulls hash to 0, consistent with GetHashCode(object)
        if (obj == null) { return 0; }
        return this.hash(obj);
    }

    /// <summary>Two instances are equal when they wrap equal delegates.</summary>
    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }

        var other = obj as FuncEqualityComparer<T>;
        if (other == null) { return false; }
        return other.equals.Equals(this.equals) && other.hash.Equals(this.hash);
    }

    public override int GetHashCode()
    {
        var equalsHash = this.equals.GetHashCode();
        var hashHash = this.hash.GetHashCode();
        return unchecked((3 * equalsHash) + hashHash);
    }
}
#region ---- Key Comparer ----
/// <summary>
/// Creates an <see cref="EqualityComparer{T}"/> that compares values of type <typeparamref name="T"/> by the
/// <typeparamref name="TKey"/> that <paramref name="keySelector"/> projects from them. Those keys are compared
/// and hashed with <paramref name="keyComparer"/>, or with the default comparer for <typeparamref name="TKey"/>
/// when it is omitted. Null values never reach <paramref name="keySelector"/>: the comparer handles them directly
/// </summary>
/// <param name="keySelector">projects a non-null value to the key it is compared by; must not be null</param>
/// <param name="keyComparer">optional comparer for the projected keys</param>
/// <returns>a new key-projecting comparer</returns>
/// <exception cref="ArgumentNullException">if <paramref name="keySelector"/> is null</exception>
public static EqualityComparer<T> Create<T, TKey>(Func<T, TKey> keySelector, IEqualityComparer<TKey> keyComparer = null)
{
    if (keySelector == null)
    {
        throw new ArgumentNullException("keySelector");
    }

    return new KeyComparer<T, TKey>(keySelector, keyComparer);
}
/// <summary>
/// Compares values by a projected key. Null values are resolved here (null equals only null and hashes
/// to 0), so the key selector only sees non-null values.
/// </summary>
private sealed class KeyComparer<T, TKey> : EqualityComparer<T>
{
    private readonly Func<T, TKey> keySelector;
    private readonly IEqualityComparer<TKey> keyComparer;

    public KeyComparer(Func<T, TKey> keySelector, IEqualityComparer<TKey> keyComparer)
    {
        this.keySelector = keySelector;
        // fall back to the default comparison for the key type
        this.keyComparer = keyComparer ?? EqualityComparer<TKey>.Default;
    }

    public override bool Equals(T x, T y)
    {
        return x == null
            ? y == null
            : y != null && this.keyComparer.Equals(this.keySelector(x), this.keySelector(y));
    }

    public override int GetHashCode(T obj)
    {
        if (obj == null) { return 0; }

        var key = this.keySelector(obj);
        return this.keyComparer.GetHashCode(key);
    }

    /// <summary>Two instances are equal when their key selectors and key comparers are equal.</summary>
    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }

        var other = obj as KeyComparer<T, TKey>;
        if (other == null) { return false; }
        if (!other.keySelector.Equals(this.keySelector)) { return false; }
        return other.keyComparer.Equals(this.keyComparer);
    }

    public override int GetHashCode()
    {
        var selectorHash = this.keySelector.GetHashCode();
        var comparerHash = this.keyComparer.GetHashCode();
        return unchecked((3 * selectorHash) + comparerHash);
    }
}
#region ---- Reference Comparer ----
/// <summary>
/// Gets a cached <see cref="EqualityComparer{T}"/> that compares purely by reference, as
/// <see cref="object.ReferenceEquals(object, object)"/> does. Hashing uses
/// <see cref="RuntimeHelpers.GetHashCode(object)"/>, which ignores any <c>GetHashCode</c> override and so
/// matches the native identity hash
/// </summary>
/// <typeparam name="T">a reference type</typeparam>
/// <returns>the shared identity comparer for <typeparamref name="T"/></returns>
public static EqualityComparer<T> GetReferenceComparer<T>()
    where T : class
{
    return ReferenceEqualityComparer<T>.Instance;
}
/// <summary>Identity-based comparer with one cached instance per <typeparamref name="T"/>.</summary>
private sealed class ReferenceEqualityComparer<T> : EqualityComparer<T>
    where T : class
{
    public static readonly EqualityComparer<T> Instance = new ReferenceEqualityComparer<T>();

    public override bool Equals(T x, T y)
    {
        return ReferenceEquals(x, y);
    }

    public override int GetHashCode(T obj)
    {
        // RuntimeHelpers.GetHashCode accepts null and bypasses any GetHashCode override on T
        return RuntimeHelpers.GetHashCode(obj);
    }
}
#region ---- Collection Comparer ----
/// <summary>
/// Gets an <see cref="EqualityComparer{T}"/> that treats <see cref="IEnumerable{TElement}"/> instances as
/// unordered collections, comparing them as if with
/// <see cref="CollectionHelper.CollectionEquals{TElement}(IEnumerable{TElement}, IEnumerable{TElement}, IEqualityComparer{TElement})"/>.
/// Pass <paramref name="elementComparer"/> to control how individual elements are compared
/// </summary>
/// <param name="elementComparer">optional element comparer; null means the default comparer</param>
/// <returns>a cached instance for the default element comparison, otherwise a new comparer</returns>
public static EqualityComparer<IEnumerable<TElement>> GetCollectionComparer<TElement>(IEqualityComparer<TElement> elementComparer = null)
{
    // share one cached instance for the common "default element comparison" case
    var usesDefaultComparison = elementComparer == null || elementComparer == EqualityComparer<TElement>.Default;
    if (usesDefaultComparison)
    {
        return CollectionComparer<TElement>.DefaultInstance;
    }

    return new CollectionComparer<TElement>(elementComparer);
}
/// <summary>
/// Compares sequences as unordered collections via <c>CollectionHelper.CollectionEquals</c>, using the
/// wrapped element comparer for individual elements. Null sequences equal only null and hash to 0.
/// </summary>
private sealed class CollectionComparer<TElement> : EqualityComparer<IEnumerable<TElement>>
{
    private static EqualityComparer<IEnumerable<TElement>> defaultInstance;

    /// <summary>
    /// The lazily created instance that uses the default comparer for elements. Initialization is not
    /// synchronized, so concurrent first calls may create more than one such instance. Those duplicates
    /// still compare, and hash, equal to each other.
    /// </summary>
    public static EqualityComparer<IEnumerable<TElement>> DefaultInstance
    {
        get { return defaultInstance ?? (defaultInstance = new CollectionComparer<TElement>(EqualityComparer<TElement>.Default)); }
    }

    private readonly IEqualityComparer<TElement> elementComparer;

    public CollectionComparer(IEqualityComparer<TElement> elementComparer)
    {
        this.elementComparer = elementComparer;
    }

    public override bool Equals(IEnumerable<TElement> x, IEnumerable<TElement> y)
    {
        if (x == null) { return y == null; }
        if (y == null) { return false; }

        // avoid calling as extension to support DISABLE_EXTENSIONs in the Inline package
        return CollectionHelper.CollectionEquals(x, y, this.elementComparer);
    }

    public override int GetHashCode(IEnumerable<TElement> obj)
    {
        if (obj == null) { return 0; }

        // combine hashcodes with xor to be order-insensitive. Null elements contribute 0 and are not passed to
        // the element comparer, because some comparers (e.g. StringComparer.Ordinal) throw from GetHashCode(null).
        // The default comparer already hashes null to 0, so default-comparer hashes are unchanged.
        var hash = -1;
        foreach (var element in obj)
        {
            hash ^= element == null ? 0 : this.elementComparer.GetHashCode(element);
        }
        return hash;
    }

    /// <summary>Two instances are equal when their element comparers are equal.</summary>
    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }
        var that = obj as CollectionComparer<TElement>;
        return that != null && that.elementComparer.Equals(this.elementComparer);
    }

    public override int GetHashCode()
    {
        // must agree with Equals(object), which looks only at the element comparer. The type's own hash acts as a
        // salt so this differs from other comparer kinds that wrap the same element comparer.
        return unchecked((3 * typeof(CollectionComparer<TElement>).GetHashCode()) + this.elementComparer.GetHashCode());
    }
}
#region ---- Sequence Comparer ----
/// <summary>
/// Gets an <see cref="EqualityComparer{T}"/> that compares <see cref="IEnumerable{TElement}"/> instances
/// element-by-element in order, as if with
/// <see cref="Enumerable.SequenceEqual{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>.
/// Pass <paramref name="elementComparer"/> to control how individual elements are compared
/// </summary>
/// <param name="elementComparer">optional element comparer; null means the default comparer</param>
/// <returns>a cached instance for the default element comparison, otherwise a new comparer</returns>
public static EqualityComparer<IEnumerable<TElement>> GetSequenceComparer<TElement>(IEqualityComparer<TElement> elementComparer = null)
{
    // share one cached instance for the common "default element comparison" case
    var usesDefaultComparison = elementComparer == null || elementComparer == EqualityComparer<TElement>.Default;
    if (usesDefaultComparison)
    {
        return SequenceComparer<TElement>.DefaultInstance;
    }

    return new SequenceComparer<TElement>(elementComparer);
}
/// <summary>
/// Compares sequences element-by-element in order (via <see cref="Enumerable.SequenceEqual{TSource}(IEnumerable{TSource}, IEnumerable{TSource}, IEqualityComparer{TSource})"/>)
/// using the wrapped element comparer. Null sequences equal only null and hash to 0.
/// </summary>
private sealed class SequenceComparer<TElement> : EqualityComparer<IEnumerable<TElement>>
{
    private static EqualityComparer<IEnumerable<TElement>> defaultInstance;

    /// <summary>
    /// The lazily created instance that uses the default comparer for elements. Initialization is not
    /// synchronized, so concurrent first calls may create more than one such instance. Those duplicates
    /// still compare, and hash, equal to each other.
    /// </summary>
    public static EqualityComparer<IEnumerable<TElement>> DefaultInstance
    {
        get { return defaultInstance ?? (defaultInstance = new SequenceComparer<TElement>(EqualityComparer<TElement>.Default)); }
    }

    private readonly IEqualityComparer<TElement> elementComparer;

    public SequenceComparer(IEqualityComparer<TElement> elementComparer)
    {
        this.elementComparer = elementComparer;
    }

    public override bool Equals(IEnumerable<TElement> x, IEnumerable<TElement> y)
    {
        if (x == null) { return y == null; }
        if (y == null) { return false; }

        return x.SequenceEqual(y, this.elementComparer);
    }

    public override int GetHashCode(IEnumerable<TElement> obj)
    {
        if (obj == null) { return 0; }

        // hash combine logic based on .NET Tuple.CombineHashCodes. The arithmetic is unchecked so it wraps
        // instead of throwing OverflowException in projects compiled with overflow checking. Null elements
        // contribute 0 and are not passed to the element comparer, because some comparers
        // (e.g. StringComparer.Ordinal) throw from GetHashCode(null).
        var hash = -1;
        foreach (var element in obj)
        {
            var elementHash = element == null ? 0 : this.elementComparer.GetHashCode(element);
            hash = unchecked(((hash << 5) + hash) ^ elementHash);
        }
        return hash;
    }

    /// <summary>Two instances are equal when their element comparers are equal.</summary>
    public override bool Equals(object obj)
    {
        if (ReferenceEquals(obj, this)) { return true; }
        var that = obj as SequenceComparer<TElement>;
        return that != null && that.elementComparer.Equals(this.elementComparer);
    }

    public override int GetHashCode()
    {
        // must agree with Equals(object), which looks only at the element comparer. The type's own hash acts as a
        // salt so this differs from other comparer kinds that wrap the same element comparer.
        return unchecked((3 * typeof(SequenceComparer<TElement>).GetHashCode()) + this.elementComparer.GetHashCode());
    }
}
#if MedallionCollections_USE_LOCAL_NAMESPACE
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Medallion.Tools.InlineNuGet", ""), global::System.Diagnostics.DebuggerNonUserCodeAttribute]
#if MedallionCollections_PUBLIC
static partial class Traverse
#region ---- Along ----
/// <summary>
/// Enumerates the implicit sequence starting from <paramref name="root"/>
/// and following the chain of <paramref name="next"/> calls until a null value
/// is encountered. For example, this can be used to traverse a chain of exceptions:
/// <code>
/// var innermostException = Traverse.Along(exception, e => e.InnerException).Last();
/// </code>
/// </summary>
/// <param name="root">the first element; when null the sequence is empty</param>
/// <param name="next">maps each element to its successor; returning null ends the sequence</param>
/// <returns>a lazily evaluated sequence of <paramref name="root"/> and its successors</returns>
/// <exception cref="ArgumentNullException">thrown at call time (not on enumeration) when <paramref name="next"/> is null</exception>
public static IEnumerable<T> Along<T>(T root, Func<T, T> next)
    where T : class
{
    if (next == null)
    {
        throw new ArgumentNullException("next");
    }

    // validation lives outside the iterator so that it runs eagerly rather than on the first MoveNext()
    return AlongIterator(root, next);
}

private static IEnumerable<T> AlongIterator<T>(T root, Func<T, T> next)
    where T : class
{
    var node = root;
    while (node != null)
    {
        yield return node;
        node = next(node);
    }
}
#region ---- Breadth-First ----
/// <summary>
/// Enumerates the implicit tree described by <paramref name="root"/> and <paramref name="children"/>
/// in a breadth-first manner. For example, this could be used to enumerate the exceptions of an
/// <see cref="AggregateException"/>:
/// <code>
/// var allExceptions = Traverse.BreadthFirst((Exception)new AggregateException(), e => (e as AggregateException)?.InnerExceptions ?? Enumerable.Empty&lt;Exception&gt;());
/// </code>
/// </summary>
/// <param name="root">the first node yielded</param>
/// <param name="children">maps a node to its (non-null) sequence of children</param>
/// <returns>a lazily evaluated breadth-first sequence starting with <paramref name="root"/></returns>
/// <exception cref="ArgumentNullException">thrown at call time when <paramref name="children"/> is null</exception>
public static IEnumerable<T> BreadthFirst<T>(T root, Func<T, IEnumerable<T>> children)
{
    if (children == null) { throw new ArgumentNullException("children"); }

    return BreadthFirstIterator(root, children);
}

private static IEnumerable<T> BreadthFirstIterator<T>(T root, Func<T, IEnumerable<T>> children)
{
    // note that this implementation has two nice properties which require a bit more complexity
    // in the code: (1) children are yielded in order and (2) child enumerators are fully lazy

    yield return root;

    // the queue holds child sequences that have not been enumerated yet. children(x) is only called after x
    // has been yielded, and each sequence is only enumerated once it reaches the front of the queue
    var queue = new Queue<IEnumerable<T>>();
    queue.Enqueue(children(root));

    do
    {
        foreach (var child in queue.Dequeue())
        {
            yield return child;
            queue.Enqueue(children(child));
        }
    }
    while (queue.Count > 0);
}
#region ---- Depth-First ----
/// <summary>
/// Enumerates the implicit tree described by <paramref name="root"/> and <paramref name="children"/>
/// in a depth-first manner. For example, this could be used to enumerate the exceptions of an
/// <see cref="AggregateException"/>:
/// <code>
/// var allExceptions = Traverse.DepthFirst((Exception)new AggregateException(), e => (e as AggregateException)?.InnerExceptions ?? Enumerable.Empty&lt;Exception&gt;());
/// </code>
/// </summary>
/// <param name="root">the first node yielded</param>
/// <param name="children">maps a node to its sequence of children</param>
/// <returns>a lazily evaluated depth-first sequence starting with <paramref name="root"/></returns>
/// <exception cref="ArgumentNullException">thrown at call time when <paramref name="children"/> is null</exception>
public static IEnumerable<T> DepthFirst<T>(T root, Func<T, IEnumerable<T>> children)
{
    if (children == null)
    {
        throw new ArgumentNullException("children");
    }

    // validate eagerly here; the lazy traversal itself lives in the iterator
    return DepthFirstIterator(root, children);
}
private static IEnumerable<T> DepthFirstIterator<T>(T root, Func<T, IEnumerable<T>> children)
// note that this implementation has two nice properties which require a bit more complexity
// in the code: (1) children are yielded in order and (2) child enumerators are fully lazy
var current = root;
var stack = new Stack<IEnumerator<T>>();
while (true)
yield return current;
var childrenEnumerator = children(current).GetEnumerator();
if (childrenEnumerator.MoveNext())
// if we have children, the first child is our next current
// and push the new enumerator
current = childrenEnumerator.Current;
// otherwise, cleanup the empty enumerator and...
// search up the stack for an enumerator with elements left
while (true)
if (stack.Count == 0)
// we didn't find one, so we're all done
yield break;
var topEnumerator = stack.Peek();
if (topEnumerator.MoveNext())
current = topEnumerator.Current;
// guarantee that everything is cleaned up even
// if we don't enumerate all the way through
while (stack.Count > 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment