Skip to content

Instantly share code, notes, and snippets.

@Usualdosage
Created July 15, 2021 18:44
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save Usualdosage/a1717f776b593d177e410d40cc23bdd9 to your computer and use it in GitHub Desktop.
Save Usualdosage/a1717f776b593d177e410d40cc23bdd9 to your computer and use it in GitHub Desktop.
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query;
using Moq;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
namespace MyProject.Extensions
{
/// <summary>
/// Extension methods that expose in-memory sequences as query sources usable with
/// asynchronous query operators, intended for test code.
/// </summary>
public static class EnumerableExtensions
{
    /// <summary>
    /// Wraps <paramref name="source"/> in an <see cref="AsyncQueryableWrapper{T}"/>.
    /// </summary>
    /// <typeparam name="T">Element type of the sequence.</typeparam>
    /// <param name="source">The in-memory sequence to wrap.</param>
    /// <returns>The wrapper, typed as <see cref="IQueryable{T}"/>.</returns>
    public static IQueryable<T> AsAsyncQueryable<T>(this IEnumerable<T> source)
    {
        return new AsyncQueryableWrapper<T>(source);
    }

    /// <summary>
    /// Wraps an existing queryable in an <see cref="AsyncQueryableWrapper{T}"/>.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <param name="source">The queryable to wrap.</param>
    /// <returns>The wrapper, typed as <see cref="IQueryable{T}"/>.</returns>
    public static IQueryable<T> AsAsyncQueryable<T>(this IQueryable<T> source)
    {
        return new AsyncQueryableWrapper<T>(source);
    }

    /// <summary>
    /// Builds a <see cref="Mock{T}"/> of <see cref="DbSet{T}"/> whose <see cref="IQueryable{T}"/> and
    /// <see cref="IAsyncEnumerable{T}"/> members are served from <paramref name="source"/>.
    /// </summary>
    /// <typeparam name="T">Entity type of the set.</typeparam>
    /// <param name="source">The in-memory items the mocked set should expose.</param>
    /// <returns>The configured mock; use its <c>Object</c> property to obtain the set.</returns>
    public static Mock<DbSet<T>> ToAsyncDbSetMock<T>(this IEnumerable<T> source) where T : class
    {
        var data = source.AsQueryable();
        var mockSet = new Mock<DbSet<T>>();

        // Match every token: a setup written with a concrete token value only matches calls
        // that pass an equal token, leaving calls with any other token unconfigured.
        // The factory lambda hands out a fresh enumerator per call so the set can be
        // enumerated more than once.
        mockSet.As<IAsyncEnumerable<T>>()
            .Setup(m => m.GetAsyncEnumerator(It.IsAny<CancellationToken>()))
            .Returns(() => new TestAsyncEnumerator<T>(data.GetEnumerator()));

        var queryable = mockSet.As<IQueryable<T>>();
        queryable.Setup(m => m.Provider).Returns(new TestAsyncQueryProvider<T>(data.Provider));
        queryable.Setup(m => m.Expression).Returns(data.Expression);
        queryable.Setup(m => m.ElementType).Returns(data.ElementType);
        queryable.Setup(m => m.GetEnumerator()).Returns(() => data.GetEnumerator());

        return mockSet;
    }
}
/// <summary>
/// Adapts an <see cref="IQueryable{T}"/> (or an in-memory sequence converted with
/// <c>AsQueryable</c>) so that it also implements <see cref="IDbAsyncEnumerable{T}"/>.
/// Query metadata is forwarded from the wrapped source; the provider is wrapped in an
/// <see cref="AsyncQueryProvider{TEntity}"/>.
/// </summary>
/// <typeparam name="T">Element type of the wrapped query.</typeparam>
internal class AsyncQueryableWrapper<T> : IDbAsyncEnumerable<T>, IQueryable<T>
{
    private readonly IQueryable<T> _source;

    /// <summary>Wraps an existing queryable as-is.</summary>
    public AsyncQueryableWrapper(IQueryable<T> source) => _source = source;

    /// <summary>Wraps an in-memory sequence after converting it with <c>AsQueryable</c>.</summary>
    public AsyncQueryableWrapper(IEnumerable<T> source) => _source = source.AsQueryable();

    /// <summary>Element type reported by the wrapped source.</summary>
    public Type ElementType
    {
        get { return _source.ElementType; }
    }

    /// <summary>Expression tree reported by the wrapped source.</summary>
    public Expression Expression
    {
        get { return _source.Expression; }
    }

    /// <summary>A new <see cref="AsyncQueryProvider{TEntity}"/> around the wrapped source's provider.</summary>
    public IQueryProvider Provider
    {
        get { return new AsyncQueryProvider<T>(_source.Provider); }
    }

    /// <summary>Returns an async enumerator over the wrapped source's synchronous enumerator.</summary>
    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        IEnumerable<T> sequence = this;
        return new AsyncEnumerator<T>(sequence.GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() => GetAsyncEnumerator();

    /// <summary>Returns the wrapped source's enumerator.</summary>
    public IEnumerator<T> GetEnumerator() => _source.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
/// <summary>
/// An <see cref="EnumerableQuery{T}"/> that additionally implements
/// <see cref="IDbAsyncEnumerable{T}"/> and reports an <see cref="AsyncQueryProvider{TEntity}"/>
/// as its query provider.
/// </summary>
/// <typeparam name="T">Element type of the query.</typeparam>
internal class AsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
    /// <summary>Creates the query over an in-memory sequence.</summary>
    public AsyncEnumerable(IEnumerable<T> enumerable)
        : base(enumerable)
    {
    }

    /// <summary>Creates the query from an expression tree.</summary>
    public AsyncEnumerable(Expression expression)
        : base(expression)
    {
    }

    /// <summary>A new <see cref="AsyncQueryProvider{TEntity}"/> that wraps this instance.</summary>
    IQueryProvider IQueryable.Provider
    {
        get { return new AsyncQueryProvider<T>(this); }
    }

    /// <summary>Returns an async enumerator over this query's synchronous enumerator.</summary>
    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        IEnumerable<T> items = this;
        return new AsyncEnumerator<T>(items.GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() => GetAsyncEnumerator();
}
/// <summary>
/// <see cref="IDbAsyncQueryProvider"/> that executes queries through a wrapped synchronous
/// <see cref="IQueryProvider"/> and returns the results as already-completed tasks.
/// </summary>
/// <typeparam name="TEntity">Element type used when a query's element type cannot be read from its expression.</typeparam>
internal class AsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    /// <summary>Creates the provider around the synchronous provider that performs execution.</summary>
    internal AsyncQueryProvider(IQueryProvider inner) => _inner = inner;

    /// <summary>
    /// Creates an <see cref="AsyncEnumerable{T}"/> for the expression. When the expression's type is
    /// generic, its first generic argument is used as the element type; otherwise
    /// <typeparamref name="TEntity"/> is used.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        Type expressionType = expression.Type;
        if (expressionType.IsGenericType)
        {
            Type elementType = expressionType.GetGenericArguments()[0];
            Type queryType = typeof(AsyncEnumerable<>).MakeGenericType(elementType);
            return (IQueryable)Activator.CreateInstance(queryType, expression);
        }

        return new AsyncEnumerable<TEntity>(expression);
    }

    /// <summary>Creates an <see cref="AsyncEnumerable{T}"/> of <typeparamref name="TElement"/> for the expression.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) =>
        new AsyncEnumerable<TElement>(expression);

    /// <summary>Executes the expression on the wrapped provider.</summary>
    public object Execute(Expression expression) => _inner.Execute(expression);

    /// <summary>Executes the expression on the wrapped provider with a typed result.</summary>
    public TResult Execute<TResult>(Expression expression) => _inner.Execute<TResult>(expression);

    /// <summary>
    /// Executes the expression synchronously and wraps the result in a completed task.
    /// The cancellation token is not consulted.
    /// </summary>
    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        object outcome = Execute(expression);
        return Task.FromResult(outcome);
    }

    /// <summary>
    /// Executes the expression synchronously and wraps the typed result in a completed task.
    /// The cancellation token is not consulted.
    /// </summary>
    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        TResult outcome = Execute<TResult>(expression);
        return Task.FromResult(outcome);
    }
}
/// <summary>
/// <see cref="IDbAsyncEnumerator{T}"/> that forwards every operation to a synchronous
/// <see cref="IEnumerator{T}"/>.
/// </summary>
/// <typeparam name="T">Element type being enumerated.</typeparam>
internal class AsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    /// <summary>Creates the adapter over the given synchronous enumerator.</summary>
    public AsyncEnumerator(IEnumerator<T> inner) => _inner = inner;

    /// <summary>The wrapped enumerator's current element.</summary>
    public T Current
    {
        get { return _inner.Current; }
    }

    object IDbAsyncEnumerator.Current
    {
        get { return Current; }
    }

    /// <summary>
    /// Advances the wrapped enumerator and returns the outcome as a completed task.
    /// The cancellation token is not consulted.
    /// </summary>
    public Task<bool> MoveNextAsync(CancellationToken cancellationToken)
    {
        bool advanced = _inner.MoveNext();
        return Task.FromResult(advanced);
    }

    /// <summary>Disposes the wrapped enumerator.</summary>
    public void Dispose() => _inner.Dispose();
}
/// <summary>
/// <see cref="IAsyncQueryProvider"/> that executes queries through a wrapped synchronous
/// <see cref="IQueryProvider"/> and hands the results back as already-completed tasks.
/// </summary>
/// <typeparam name="TEntity">Element type used when a query's element type cannot be read from its expression.</typeparam>
internal class TestAsyncQueryProvider<TEntity> : IAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    /// <summary>Creates the provider around the synchronous provider that performs execution.</summary>
    internal TestAsyncQueryProvider(IQueryProvider inner)
    {
        _inner = inner;
    }

    /// <summary>
    /// Creates a <see cref="TestAsyncEnumerable{T}"/> for the expression. When the expression's type
    /// is generic, its first generic argument is used as the element type (so projected queries get
    /// the projected element type); otherwise <typeparamref name="TEntity"/> is used.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        var expressionType = expression.Type;
        if (!expressionType.IsGenericType)
        {
            return new TestAsyncEnumerable<TEntity>(expression);
        }

        var elementType = expressionType.GetGenericArguments()[0];
        var queryType = typeof(TestAsyncEnumerable<>).MakeGenericType(elementType);
        return (IQueryable)Activator.CreateInstance(queryType, expression);
    }

    /// <summary>Creates a <see cref="TestAsyncEnumerable{T}"/> of <typeparamref name="TElement"/> for the expression.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new TestAsyncEnumerable<TElement>(expression);
    }

    /// <summary>Executes the expression on the wrapped provider.</summary>
    public object Execute(Expression expression)
    {
        return _inner.Execute(expression);
    }

    /// <summary>Executes the expression on the wrapped provider with a typed result.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        return _inner.Execute<TResult>(expression);
    }

    /// <summary>Wraps the expression in a <see cref="TestAsyncEnumerable{T}"/> of <typeparamref name="TResult"/>.</summary>
    public IAsyncEnumerable<TResult> ExecuteAsync<TResult>(Expression expression)
    {
        return new TestAsyncEnumerable<TResult>(expression);
    }

    /// <summary>
    /// Executes the expression synchronously and returns the result wrapped in a completed task.
    /// </summary>
    /// <typeparam name="TResult">The task type to return; must be a constructed <see cref="Task{TResult}"/>.</typeparam>
    /// <param name="expression">The query expression to execute.</param>
    /// <param name="cancellationToken">Accepted for signature compatibility; execution is synchronous and does not consult it.</param>
    /// <returns>A completed task of the requested type holding the execution result.</returns>
    /// <exception cref="NotSupportedException"><typeparamref name="TResult"/> is not a <see cref="Task{TResult}"/>.</exception>
    public TResult ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        var taskType = typeof(TResult);
        if (!taskType.IsGenericType || taskType.GetGenericTypeDefinition() != typeof(Task<>))
        {
            throw new NotSupportedException(
                "ExecuteAsync expects TResult to be Task<T>, but was " + taskType + ".");
        }

        // Execute directly (not via reflection) so query exceptions reach the caller unwrapped.
        object result = Execute(expression);

        // Build Task<T> from the declared T rather than the runtime type of the result:
        // this keeps null results, derived types and nullable value types working.
        var declaredResultType = taskType.GetGenericArguments()[0];
        var fromResult = typeof(Task)
            .GetMethod(nameof(Task.FromResult))
            .MakeGenericMethod(declaredResultType);

        return (TResult)fromResult.Invoke(null, new[] { result });
    }

    TResult IAsyncQueryProvider.ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        return ExecuteAsync<TResult>(expression, cancellationToken);
    }
}
/// <summary>
/// An <see cref="EnumerableQuery{T}"/> that additionally implements
/// <see cref="IAsyncEnumerable{T}"/> and reports a <see cref="TestAsyncQueryProvider{TEntity}"/>
/// as its query provider.
/// </summary>
/// <typeparam name="T">Element type of the query.</typeparam>
internal class TestAsyncEnumerable<T> : EnumerableQuery<T>, IAsyncEnumerable<T>, IQueryable<T>
{
    /// <summary>Creates the query over an in-memory sequence.</summary>
    public TestAsyncEnumerable(IEnumerable<T> enumerable)
        : base(enumerable)
    {
    }

    /// <summary>Creates the query from an expression tree.</summary>
    public TestAsyncEnumerable(Expression expression)
        : base(expression)
    {
    }

    /// <summary>A new <see cref="TestAsyncQueryProvider{TEntity}"/> that wraps this instance.</summary>
    IQueryProvider IQueryable.Provider
    {
        get { return new TestAsyncQueryProvider<T>(this); }
    }

    /// <summary>
    /// Returns an async enumerator over this query's synchronous enumerator.
    /// The token is not consulted.
    /// </summary>
    public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken token)
    {
        IEnumerable<T> items = this;
        return new TestAsyncEnumerator<T>(items.GetEnumerator());
    }
}
/// <summary>
/// <see cref="IAsyncEnumerator{T}"/> that forwards every operation to a synchronous
/// <see cref="IEnumerator{T}"/>. All operations complete synchronously.
/// </summary>
/// <typeparam name="T">Element type being enumerated.</typeparam>
internal class TestAsyncEnumerator<T> : IAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    /// <summary>Creates the adapter over the given synchronous enumerator.</summary>
    public TestAsyncEnumerator(IEnumerator<T> inner)
    {
        _inner = inner;
    }

    /// <summary>The wrapped enumerator's current element.</summary>
    public T Current => _inner.Current;

    /// <summary>
    /// Advances the wrapped enumerator and returns the outcome as a completed task.
    /// The cancellation token is not consulted.
    /// </summary>
    public Task<bool> MoveNext(CancellationToken cancellationToken)
    {
        return Task.FromResult(_inner.MoveNext());
    }

    /// <summary>Advances the wrapped enumerator and returns the outcome as a completed value task.</summary>
    public ValueTask<bool> MoveNextAsync()
    {
        // Wrap the bool directly; there is no need to allocate a Task first.
        return new ValueTask<bool>(_inner.MoveNext());
    }

    /// <summary>Disposes the wrapped enumerator before returning a completed value task.</summary>
    public ValueTask DisposeAsync()
    {
        // Dispose inline: the work is synchronous, so scheduling it on the thread pool
        // only adds latency and moves disposal to another thread.
        _inner.Dispose();
        return default;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment