Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
MemoryDbSet for DbContext Unit Testing
using System.Collections.Generic;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
namespace Barkeep.Tests
{
/// <summary>
/// In-memory stand-in for an Entity Framework <see cref="DbSet{TEntity}"/> for unit tests.
/// Entities live in a <see cref="HashSet{T}"/>, and LINQ queries run over that set with
/// LINQ-to-Objects.
/// </summary>
/// <typeparam name="T">Entity type stored in the set.</typeparam>
public class MemoryDbSet<T> : DbSet<T>, IQueryable<T> where T : class
{
    #region Private Members
    // Name of the int property that Find() treats as the primary key.
    private const string KeyPropertyName = "ID";

    // Backing store. It uses the default equality comparer for T.
    private readonly HashSet<T> _set;
    #endregion

    #region Constructors
    /// <summary>Creates an empty set.</summary>
    public MemoryDbSet() : this(Enumerable.Empty<T>())
    {
    }

    /// <summary>
    /// Creates a set pre-populated with <paramref name="entities"/>. Duplicate entries collapse
    /// into one.
    /// </summary>
    /// <param name="entities">Initial contents of the set.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="entities"/> is null.</exception>
    public MemoryDbSet(IEnumerable<T> entities)
    {
        if (entities == null)
        {
            throw new System.ArgumentNullException("entities");
        }

        _set = new HashSet<T>(entities);

        // The queryable wraps the live HashSet instance, so later Add/Remove calls
        // are visible to queries built from Provider/Expression.
        var queryable = _set.AsQueryable();
        Provider = queryable.Provider;
        Expression = queryable.Expression;
    }
    #endregion

    #region Public Methods
    /// <summary>Adds <paramref name="entity"/> to the set and returns it.</summary>
    public override T Add(T entity)
    {
        _set.Add(entity);
        return entity;
    }

    /// <summary>
    /// Adds every entity in <paramref name="entities"/> to the set.
    /// </summary>
    /// <returns>The input entities, materialized once.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="entities"/> is null.</exception>
    public override IEnumerable<T> AddRange(IEnumerable<T> entities)
    {
        if (entities == null)
        {
            throw new System.ArgumentNullException("entities");
        }

        // Materialize once so a lazy sequence is not enumerated more than once.
        var added = entities.ToList();
        foreach (var entity in added)
        {
            _set.Add(entity);
        }
        return added;
    }

    /// <summary>
    /// Removes <paramref name="entity"/> from the set if it is present, and returns it either way.
    /// </summary>
    public override T Remove(T entity)
    {
        _set.Remove(entity);
        return entity;
    }

    /// <summary>
    /// Looks up an entity by a single int key. The key is matched against each entity's public
    /// int property named "ID".
    /// </summary>
    /// <param name="keyValues">Exactly one boxed <see cref="int"/> key.</param>
    /// <returns>
    /// The matching entity. Returns null when nothing matches, or when <paramref name="keyValues"/>
    /// is not a single int, because such a key cannot match.
    /// </returns>
    /// <exception cref="System.InvalidOperationException">
    /// An entity in the set has no public int "ID" property, or more than one entity has the key.
    /// </exception>
    public override T Find(params object[] keyValues)
    {
        if (keyValues == null || keyValues.Length != 1 || !(keyValues[0] is int))
        {
            return null;
        }

        var id = (int)keyValues[0];
        return _set.SingleOrDefault(entity => ReadIntKey(entity) == id);
    }

    /// <summary>
    /// Removes every entity in <paramref name="entities"/> from the set.
    /// </summary>
    /// <returns>The input entities, not the remaining set, matching the EF6 contract.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="entities"/> is null.</exception>
    public override IEnumerable<T> RemoveRange(IEnumerable<T> entities)
    {
        if (entities == null)
        {
            throw new System.ArgumentNullException("entities");
        }

        // Snapshot first. This avoids re-enumerating a lazy sequence, and it lets the
        // caller pass this set itself without mutating it mid-enumeration.
        var removed = entities.ToList();
        foreach (var entity in removed)
        {
            _set.Remove(entity);
        }
        return removed;
    }

    /// <summary>Enumerates the entities currently in the set.</summary>
    public IEnumerator<T> GetEnumerator()
    {
        return _set.GetEnumerator();
    }

    // The base DbQuery's explicit IEnumerable/IQueryable members are not backed by a real
    // query in a test double, so the non-generic enumerator and ElementType are supplied here.
    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    System.Type IQueryable.ElementType
    {
        get { return typeof(T); }
    }

    /// <summary>LINQ-to-Objects provider over the backing set.</summary>
    public IQueryProvider Provider { get; private set; }

    /// <summary>Expression tree rooted at the backing set.</summary>
    public Expression Expression { get; private set; }
    #endregion

    #region Private Helpers
    // Reads the public int "ID" property from the entity's runtime type. Throws a descriptive
    // error, instead of a NullReferenceException, when the type has no such property.
    private static int ReadIntKey(T entity)
    {
        var property = entity.GetType().GetProperty(KeyPropertyName, typeof(int));
        if (property == null)
        {
            throw new System.InvalidOperationException(
                "MemoryDbSet.Find requires entity type '" + entity.GetType().Name +
                "' to expose a public int property named '" + KeyPropertyName + "'.");
        }
        return (int)property.GetValue(entity);
    }
    #endregion
}
#region Extensions
/// <summary>
/// Convenience helpers for building in-memory DbSet test doubles.
/// </summary>
public static class Extensions
{
    /// <summary>
    /// Copies <paramref name="list"/> into a new <c>MemoryDbSet&lt;T&gt;</c> and returns it
    /// typed as a <c>DbSet&lt;T&gt;</c>.
    /// </summary>
    /// <typeparam name="T">Entity type.</typeparam>
    /// <param name="list">Entities that seed the set.</param>
    /// <returns>A freshly constructed in-memory set.</returns>
    public static DbSet<T> ToMockDbSet<T>(this IEnumerable<T> list) where T : class
    {
        DbSet<T> mockSet = new MemoryDbSet<T>(list);
        return mockSet;
    }
}
#endregion
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment