Skip to content

Instantly share code, notes, and snippets.

@AArnott
Created May 26, 2009 23:09
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save AArnott/118348 to your computer and use it in GitHub Desktop.
Save AArnott/118348 to your computer and use it in GitHub Desktop.
Get the benefit of deferred execution from generator methods without paying to regenerate the sequence each time it is enumerated
//-----------------------------------------------------------------------
// <copyright file="EnumerableCache.cs" company="Andrew Arnott">
// Copyright (c) Andrew Arnott. All rights reserved.
// This code is released under the Microsoft Public License (Ms-PL).
// </copyright>
//-----------------------------------------------------------------------
namespace IEnumeratorCache {
using System;
using System.Collections;
using System.Collections.Generic;
/// <summary>
/// Extension methods for <see cref="IEnumerable&lt;T&gt;"/> types.
/// </summary>
public static class EnumerableCacheExtensions {
    /// <summary>
    /// Caches the results of enumerating over a given object so that subsequent enumerations
    /// don't require interacting with the object a second time.
    /// </summary>
    /// <typeparam name="T">The type of element found in the enumeration.</typeparam>
    /// <param name="sequence">The enumerable object.</param>
    /// <returns>
    /// Either a new enumerable object that caches enumerated results, or the original, <paramref name="sequence"/>
    /// object if no caching is necessary to avoid additional CPU work.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sequence"/> is <c>null</c>.</exception>
    /// <remarks>
    /// <para>This is designed for use on the results of generator methods (the ones with <c>yield return</c> in them)
    /// so that only those elements in the sequence that are needed are ever generated, while not requiring
    /// regeneration of elements that are enumerated over multiple times.</para>
    /// <para>This can be a huge performance gain if enumerating multiple times over an expensive generator method.</para>
    /// <para>Some enumerable types such as collections, lists, and already-cached generators do not require
    /// any (additional) caching, and this method will simply return those objects rather than caching them
    /// to avoid double-caching.</para>
    /// </remarks>
    public static IEnumerable<T> CacheGeneratedResults<T>(this IEnumerable<T> sequence) {
        // Validate here so the exception names the parameter the caller actually passed.
        if (sequence == null) {
            throw new ArgumentNullException("sequence");
        }

        // Don't create a cache for types that don't need it.
        // (IList<T> derives from ICollection<T>, so a separate IList<T> check is unnecessary.)
        if (sequence is ICollection<T> ||
            sequence is Array ||
            sequence is EnumerableCache<T>) {
            return sequence;
        }

        return new EnumerableCache<T>(sequence);
    }

    /// <summary>
    /// A wrapper for <see cref="IEnumerable&lt;T&gt;"/> types that returns a caching <see cref="IEnumerator&lt;T&gt;"/>
    /// from its <see cref="IEnumerable&lt;T&gt;.GetEnumerator"/> method.
    /// </summary>
    /// <typeparam name="T">The type of element in the sequence.</typeparam>
    /// <remarks>
    /// The wrapped generator is enumerated at most once. Its enumerator is created on the first request
    /// for an element and disposed as soon as it reports the end of the sequence. If no consumer ever
    /// enumerates to the end, the generator's enumerator is not disposed by this class.
    /// </remarks>
    private sealed class EnumerableCache<T> : IEnumerable<T> {
        /// <summary>
        /// The results from enumeration of the live object that have been collected thus far.
        /// Only accessed while holding <see cref="generatorLock"/>.
        /// </summary>
        private readonly List<T> cache = new List<T>();

        /// <summary>
        /// The original generator method or other enumerable object whose contents should only be enumerated once.
        /// </summary>
        private readonly IEnumerable<T> generator;

        /// <summary>
        /// The sync object guarding <see cref="cache"/>, <see cref="generatorEnumerator"/> and
        /// <see cref="generatorExhausted"/>.
        /// </summary>
        /// <remarks>
        /// Although individual enumerators are not thread-safe, this <see cref="IEnumerable&lt;T&gt;"/> should be
        /// thread safe so that multiple enumerators can be created from it and used from different threads.
        /// </remarks>
        private readonly object generatorLock = new object();

        /// <summary>
        /// The enumerator we're using over the generator method's results.
        /// <c>null</c> until the first element is requested and again after the generator is exhausted.
        /// </summary>
        private IEnumerator<T> generatorEnumerator;

        /// <summary>
        /// Set once the generator's enumerator has reported the end of the sequence, so that the
        /// generator is never started a second time.
        /// </summary>
        private bool generatorExhausted;

        /// <summary>
        /// Initializes a new instance of the EnumerableCache class.
        /// </summary>
        /// <param name="generator">The generator.</param>
        /// <exception cref="ArgumentNullException"><paramref name="generator"/> is <c>null</c>.</exception>
        internal EnumerableCache(IEnumerable<T> generator) {
            if (generator == null) {
                throw new ArgumentNullException("generator");
            }

            this.generator = generator;
        }

        /// <summary>
        /// Returns an enumerator that iterates through the collection.
        /// </summary>
        /// <returns>
        /// A <see cref="T:System.Collections.Generic.IEnumerator`1"/> that can be used to iterate through the collection.
        /// </returns>
        public IEnumerator<T> GetEnumerator() {
            // No shared state is touched here; the generator's enumerator is created lazily,
            // under the lock, when the first element is actually requested.
            return new EnumeratorCache(this);
        }

        /// <summary>
        /// Returns an enumerator that iterates through a collection.
        /// </summary>
        /// <returns>
        /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection.
        /// </returns>
        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() {
            return this.GetEnumerator();
        }

        /// <summary>
        /// Gets the element at a given position, pulling one more element from the generator
        /// when that position has not been cached yet.
        /// </summary>
        /// <param name="index">
        /// The zero-based position to read. Callers advance one step at a time, so this is never
        /// more than one past the last cached element.
        /// </param>
        /// <param name="item">Receives the element, or the default value of <typeparamref name="T"/> when there is none.</param>
        /// <returns><c>true</c> if an element exists at <paramref name="index"/>; <c>false</c> if the sequence ends before it.</returns>
        internal bool TryGetItem(int index, out T item) {
            lock (this.generatorLock) {
                if (index >= this.cache.Count && !this.TryGenerateNext()) {
                    item = default(T);
                    return false;
                }

                item = this.cache[index];
                return true;
            }
        }

        /// <summary>
        /// Pulls the next element from the generator into the cache.
        /// Must be called while holding <see cref="generatorLock"/>.
        /// </summary>
        /// <returns><c>true</c> if an element was appended to the cache; <c>false</c> if the generator has no more elements.</returns>
        private bool TryGenerateNext() {
            if (this.generatorExhausted) {
                return false;
            }

            if (this.generatorEnumerator == null) {
                this.generatorEnumerator = this.generator.GetEnumerator();
            }

            if (this.generatorEnumerator.MoveNext()) {
                this.cache.Add(this.generatorEnumerator.Current);
                return true;
            }

            // End of sequence: record it first so that a throwing Dispose cannot cause the
            // generator to be restarted, then release the generator's enumerator.
            this.generatorExhausted = true;
            IEnumerator<T> finished = this.generatorEnumerator;
            this.generatorEnumerator = null;
            finished.Dispose();
            return false;
        }

        /// <summary>
        /// An enumerator that uses cached enumeration results whenever they are available,
        /// and caches whatever results it has to pull from the original <see cref="IEnumerable&lt;T&gt;"/> object.
        /// </summary>
        private sealed class EnumeratorCache : IEnumerator<T> {
            /// <summary>
            /// The parent enumeration wrapper class that stores the cached results.
            /// </summary>
            private readonly EnumerableCache<T> parent;

            /// <summary>
            /// The position of this enumerator in the cached list; -1 before the first element.
            /// It is only advanced once an element is known to exist at the new position.
            /// </summary>
            private int cachePosition = -1;

            /// <summary>
            /// The element at <see cref="cachePosition"/>, copied out of the shared cache so that
            /// <see cref="Current"/> never has to read the shared list.
            /// </summary>
            private T current;

            /// <summary>
            /// Whether <see cref="current"/> holds a valid element (false before the first
            /// <see cref="MoveNext"/>, after the end of the sequence, and after <see cref="Reset"/>).
            /// </summary>
            private bool hasCurrent;

            /// <summary>
            /// Initializes a new instance of the <see cref="EnumerableCache&lt;T&gt;.EnumeratorCache"/> class.
            /// </summary>
            /// <param name="parent">The parent cached enumerable whose GetEnumerator method is calling this constructor.</param>
            /// <exception cref="ArgumentNullException"><paramref name="parent"/> is <c>null</c>.</exception>
            internal EnumeratorCache(EnumerableCache<T> parent) {
                if (parent == null) {
                    throw new ArgumentNullException("parent");
                }

                this.parent = parent;
            }

            /// <summary>
            /// Gets the element in the collection at the current position of the enumerator.
            /// </summary>
            /// <returns>
            /// The element in the collection at the current position of the enumerator.
            /// </returns>
            /// <exception cref="InvalidOperationException">
            /// The enumerator is positioned before the first element or after the last element.
            /// </exception>
            public T Current {
                get {
                    if (!this.hasCurrent) {
                        throw new InvalidOperationException();
                    }

                    return this.current;
                }
            }

            /// <summary>
            /// Gets the element in the collection at the current position of the enumerator.
            /// </summary>
            /// <returns>
            /// The element in the collection at the current position of the enumerator.
            /// </returns>
            object System.Collections.IEnumerator.Current {
                get { return this.Current; }
            }

            /// <summary>
            /// Releases this enumerator. There is nothing to release: the generator's enumerator is
            /// shared by all enumerators and owned by the parent, so it must not be disposed here.
            /// </summary>
            public void Dispose() {
            }

            /// <summary>
            /// Advances the enumerator to the next element of the collection.
            /// </summary>
            /// <returns>
            /// true if the enumerator was successfully advanced to the next element; false if the enumerator has passed the end of the collection.
            /// </returns>
            /// <remarks>
            /// Any exception thrown by the underlying generator propagates to the caller and leaves
            /// this enumerator's position unchanged.
            /// </remarks>
            public bool MoveNext() {
                int next = this.cachePosition + 1;
                T item;
                if (this.parent.TryGetItem(next, out item)) {
                    this.cachePosition = next;
                    this.current = item;
                    this.hasCurrent = true;
                    return true;
                }

                // Past the end: Current must throw, and the position is left alone so that
                // repeated calls keep returning false without the counter growing.
                this.current = default(T);
                this.hasCurrent = false;
                return false;
            }

            /// <summary>
            /// Sets the enumerator to its initial position, which is before the first element in the collection.
            /// </summary>
            public void Reset() {
                this.cachePosition = -1;
                this.current = default(T);
                this.hasCurrent = false;
            }
        }
    }
}
}
//-----------------------------------------------------------------------
// <copyright file="EnumerableCacheTests.cs" company="Andrew Arnott">
// Copyright (c) Andrew Arnott. All rights reserved.
// This code is released under the Microsoft Public License (Ms-PL).
// </copyright>
//-----------------------------------------------------------------------
namespace IEnumeratorCache {
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;
/// <summary>
/// Tests for cached enumeration.
/// </summary>
[TestClass]
public class EnumerableCacheTests {
    /// <summary>
    /// Counts how often the body of <see cref="NumberGenerator"/> begins executing.
    /// </summary>
    private int generatorInvocations;

    /// <summary>
    /// Counts how often the body of <see cref="NumberGenerator"/> runs all the way to its end.
    /// </summary>
    private int generatorCompleted;

    /// <summary>
    /// Gets or sets the context supplied by the test framework.
    /// </summary>
    public TestContext TestContext { get; set; }

    /// <summary>
    /// Clears both generator counters before each test.
    /// </summary>
    [TestInitialize]
    public void Setup() {
        this.generatorCompleted = 0;
        this.generatorInvocations = 0;
    }

    /// <summary>
    /// Verifies that an uncached generator runs once per enumeration, while a cached one
    /// runs only once and still yields the same elements every time.
    /// </summary>
    [TestMethod]
    public void EnumerableCache() {
        // Baseline: every enumeration restarts the generator.
        IEnumerable<int> uncached = this.NumberGenerator();
        List<int> baselineFirst = uncached.ToList();
        List<int> baselineSecond = uncached.ToList();
        Assert.AreEqual(2, this.generatorInvocations);
        CollectionAssert.AreEqual(baselineFirst, baselineSecond);

        // Cache behavior: start counting afresh, then enumerate the cached wrapper twice.
        this.Setup();
        IEnumerable<int> cached = this.NumberGenerator().CacheGeneratedResults();
        List<int> cachedFirst = cached.ToList();
        List<int> cachedSecond = cached.ToList();
        Assert.AreEqual(1, this.generatorInvocations);
        Assert.AreEqual(1, this.generatorCompleted);
        CollectionAssert.AreEqual(baselineFirst, cachedFirst);
        CollectionAssert.AreEqual(baselineFirst, cachedSecond);
    }

    /// <summary>
    /// Verifies that wrapping does not start the generator and that taking a prefix
    /// does not run the generator to completion.
    /// </summary>
    [TestMethod]
    public void GeneratesOnlyRequiredElements() {
        IEnumerable<int> cached = this.NumberGenerator().CacheGeneratedResults();
        Assert.AreEqual(0, this.generatorInvocations);

        cached.Take(2).ToList();

        Assert.AreEqual(1, this.generatorInvocations);
        Assert.AreEqual(0, this.generatorCompleted, "Only taking part of the list should not have completed the generator.");
    }

    /// <summary>
    /// Verifies that caching an already-cached sequence returns the same wrapper.
    /// </summary>
    [TestMethod]
    public void PassThruDoubleCache() {
        IEnumerable<int> inner = this.NumberGenerator().CacheGeneratedResults();
        IEnumerable<int> outer = inner.CacheGeneratedResults();
        Assert.AreSame(inner, outer, "Two caches were set up rather than just sharing the first one.");
    }

    /// <summary>
    /// Verifies that a list is returned as-is rather than wrapped.
    /// </summary>
    [TestMethod]
    public void PassThruList() {
        List<int> source = this.NumberGenerator().ToList();
        IEnumerable<int> result = source.CacheGeneratedResults();
        Assert.AreSame(source, result);
    }

    /// <summary>
    /// Verifies that an array is returned as-is rather than wrapped.
    /// </summary>
    [TestMethod]
    public void PassThruArray() {
        int[] source = this.NumberGenerator().ToArray();
        IEnumerable<int> result = source.CacheGeneratedResults();
        Assert.AreSame(source, result);
    }

    /// <summary>
    /// Verifies that a collection is returned as-is rather than wrapped.
    /// </summary>
    [TestMethod]
    public void PassThruCollection() {
        Collection<int> source = new Collection<int>();
        IEnumerable<int> result = source.CacheGeneratedResults();
        Assert.AreSame(source, result);
    }

    /// <summary>
    /// Tests calling IEnumerator.Current before first call to MoveNext.
    /// </summary>
    [TestMethod]
    [ExpectedException(typeof(InvalidOperationException))]
    public void EnumerableCacheCurrentThrowsBefore() {
        IEnumerator<int> enumerator = this.NumberGenerator().CacheGeneratedResults().GetEnumerator();
        int unused = enumerator.Current;
    }

    /// <summary>
    /// Tests calling IEnumerator.Current after MoveNext returns false.
    /// </summary>
    [TestMethod]
    [ExpectedException(typeof(InvalidOperationException))]
    public void EnumerableCacheCurrentThrowsAfter() {
        IEnumerator<int> enumerator = this.NumberGenerator().CacheGeneratedResults().GetEnumerator();

        // Drain the enumerator; the loop ends on the first MoveNext that returns false.
        bool advanced;
        do {
            advanced = enumerator.MoveNext();
        } while (advanced);

        int unused = enumerator.Current;
    }

    /// <summary>
    /// A generator yielding 10 through 14 that records when its body starts and when it finishes.
    /// </summary>
    /// <returns>The lazily generated sequence.</returns>
    private IEnumerable<int> NumberGenerator() {
        this.generatorInvocations++;

        int value = 10;
        while (value < 15) {
            yield return value;
            value++;
        }

        this.generatorCompleted++;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment