Skip to content

Instantly share code, notes, and snippets.

@mattpodwysocki
Created February 7, 2018 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattpodwysocki/66a887fcefa3201146d9acedf5cf4c12 to your computer and use it in GitHub Desktop.
Save mattpodwysocki/66a887fcefa3201146d9acedf5cf4c12 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Data;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
namespace MobCat.Data.Repositories
{
/// <summary>
/// Abstraction over the EF Core context used by the repository and unit-of-work types:
/// set access, database-level operations, saving, and synchronisation between an entity's
/// <see cref="IObjectState.ObjectState"/> and the change tracker.
/// </summary>
public interface IDataContext : IDisposable
{
    /// <summary>Database-level operations (connections and transactions).</summary>
    DatabaseFacade Database { get; }

    /// <summary>Returns the set for entity type <typeparamref name="T"/>.</summary>
    DbSet<T> Set<T>() where T : class;

    /// <summary>Saves pending changes and returns the count reported by the context.</summary>
    int SaveChanges();

    /// <summary>Asynchronously saves pending changes and returns the count reported by the context.</summary>
    Task<int> SaveChangesAsync(CancellationToken cancellationToken = default(CancellationToken));

    /// <summary>Pushes <paramref name="entity"/>'s ObjectState into its change-tracker entry.</summary>
    void SyncObjectState<TEntity>(TEntity entity) where TEntity : class, IObjectState;

    /// <summary>Copies change-tracker states back onto the tracked entities' ObjectState.</summary>
    void SyncObjectsStatePostCommit();
}
/// <summary>
/// EF Core context that, around every save, copies each <see cref="IObjectState"/> entity's
/// <see cref="IObjectState.ObjectState"/> into the change tracker and afterwards copies the
/// tracker's resulting state back onto the entity.
/// </summary>
public class DataContext : DbContext, IDataContext
{
    private readonly Guid _instanceId;

    /// <summary>Creates the context with the supplied options.</summary>
    public DataContext(DbContextOptions options) : base(options)
    {
        _instanceId = Guid.NewGuid();
    }

    /// <summary>Identifier generated once when this instance is constructed.</summary>
    public Guid InstanceId => _instanceId;

    /// <summary>Saves changes, accepting them on success. Routed through the overload below.</summary>
    public override int SaveChanges() => SaveChanges(acceptAllChangesOnSuccess: true);

    /// <summary>
    /// Syncs ObjectState into the tracker, saves, then syncs the tracker back.
    /// Overriding this overload (rather than only the parameterless one) means callers that
    /// pass <paramref name="acceptAllChangesOnSuccess"/> explicitly are synchronised too.
    /// </summary>
    /// <returns>The value returned by <see cref="DbContext.SaveChanges(bool)"/>.</returns>
    public override int SaveChanges(bool acceptAllChangesOnSuccess)
    {
        SyncObjectsStatePreCommit();
        var changes = base.SaveChanges(acceptAllChangesOnSuccess);
        SyncObjectsStatePostCommit();
        return changes;
    }

    /// <summary>Asynchronously saves changes, accepting them on success. Routed through the overload below.</summary>
    public override Task<int> SaveChangesAsync(CancellationToken cancellationToken = default(CancellationToken))
        => SaveChangesAsync(acceptAllChangesOnSuccess: true, cancellationToken: cancellationToken);

    /// <summary>Async counterpart of <see cref="SaveChanges(bool)"/>.</summary>
    /// <returns>The value returned by <see cref="DbContext.SaveChangesAsync(bool, CancellationToken)"/>.</returns>
    public override async Task<int> SaveChangesAsync(bool acceptAllChangesOnSuccess, CancellationToken cancellationToken = default(CancellationToken))
    {
        SyncObjectsStatePreCommit();
        // Library code: no need to resume on the caller's synchronisation context.
        var changes = await base.SaveChangesAsync(acceptAllChangesOnSuccess, cancellationToken).ConfigureAwait(false);
        SyncObjectsStatePostCommit();
        return changes;
    }

    /// <summary>Sets each IObjectState entity's tracker state from its ObjectState.</summary>
    private void SyncObjectsStatePreCommit()
    {
        // Entries<IObjectState>() skips tracked entities that don't implement the interface
        // (a blind cast would throw InvalidCastException). ToList() snapshots the entries because
        // changing State (e.g. Added -> Deleted detaches the entry) while enumerating the live
        // change-tracker sequence can invalidate the enumeration.
        foreach (var entry in ChangeTracker.Entries<IObjectState>().ToList())
        {
            entry.State = StateConverter.ConvertState(entry.Entity.ObjectState);
        }
    }

    /// <summary>Copies each tracked IObjectState entity's tracker state back onto its ObjectState.</summary>
    public void SyncObjectsStatePostCommit()
    {
        foreach (var entry in ChangeTracker.Entries<IObjectState>())
        {
            entry.Entity.ObjectState = StateConverter.ConvertState(entry.State);
        }
    }

    /// <summary>Sets <paramref name="entity"/>'s tracker state from its ObjectState.</summary>
    public void SyncObjectState<TEntity>(TEntity entity) where TEntity : class, IObjectState
    {
        Entry(entity).State = StateConverter.ConvertState(entity.ObjectState);
    }
}
/// <summary>
/// Persistence state an entity carries on itself, independent of the change tracker.
/// Values are pinned explicitly; <see cref="Unchanged"/> is 0 and therefore the default.
/// </summary>
public enum ObjectState
{
    /// <summary>No pending change.</summary>
    Unchanged = 0,

    /// <summary>To be inserted.</summary>
    Added = 1,

    /// <summary>To be updated.</summary>
    Modified = 2,

    /// <summary>To be deleted.</summary>
    Deleted = 3,
}
/// <summary>Converts between <see cref="ObjectState"/> and EF Core's <see cref="EntityState"/>.</summary>
public static class StateConverter
{
    /// <summary>Maps an entity's ObjectState to the corresponding change-tracker state.</summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="state"/> is not a defined ObjectState value. Undefined values used to map
    /// silently to Unchanged, which dropped the intended write; this mirrors the reverse conversion.
    /// </exception>
    public static EntityState ConvertState(ObjectState state)
    {
        switch (state)
        {
            case ObjectState.Unchanged:
                return EntityState.Unchanged;
            case ObjectState.Added:
                return EntityState.Added;
            case ObjectState.Modified:
                return EntityState.Modified;
            case ObjectState.Deleted:
                return EntityState.Deleted;
            default:
                throw new ArgumentOutOfRangeException(nameof(state));
        }
    }

    /// <summary>
    /// Maps a change-tracker state to an ObjectState. Detached has no ObjectState counterpart
    /// and is reported as Unchanged.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="state"/> is not a handled EntityState value.</exception>
    public static ObjectState ConvertState(EntityState state)
    {
        switch (state)
        {
            case EntityState.Detached:
                return ObjectState.Unchanged;
            case EntityState.Unchanged:
                return ObjectState.Unchanged;
            case EntityState.Added:
                return ObjectState.Added;
            case EntityState.Deleted:
                return ObjectState.Deleted;
            case EntityState.Modified:
                return ObjectState.Modified;
            default:
                throw new ArgumentOutOfRangeException(nameof(state));
        }
    }
}
/// <summary>
/// Implemented by entities whose persistence state is set explicitly (the repository assigns
/// ObjectState in Insert/Update/Delete) and is copied to and from the change tracker by the
/// data context around saves.
/// </summary>
public interface IObjectState
{
    // NOTE(review): attributes on interface members are not inherited by the implementing
    // class's property, so EF Core most likely does not see this [NotMapped] and may try to map
    // ObjectState as a column. Implementers probably need [NotMapped] on their own property
    // (or Ignore() in OnModelCreating) — confirm against the model.
    [NotMapped]
    ObjectState ObjectState { get; set; }
}
/// <summary>
/// Generic repository over entities that carry their own <see cref="ObjectState"/>.
/// Write operations only stage changes; persistence happens when the context is saved.
/// </summary>
/// <typeparam name="TEntity">Entity type managed by the repository.</typeparam>
public interface IEntityRepository<TEntity> where TEntity : class, IObjectState
{
    /// <summary>Looks up an entity by primary-key values.</summary>
    Task<TEntity> FindByIdAsync(params object[] keyValues);

    /// <summary>Looks up an entity by primary-key values, observing <paramref name="cancellationToken"/>.</summary>
    Task<TEntity> FindByIdAsync(CancellationToken cancellationToken, params object[] keyValues);

    /// <summary>
    /// Queries entities matching <paramref name="predicate"/> with optional primary and secondary ordering.
    /// <typeparamref name="TOrderKey"/> and <typeparamref name="TThenKey"/> cannot be inferred from
    /// omitted selectors, so callers supply them explicitly in that case.
    /// </summary>
    Task<List<TEntity>> FindAsync<TOrderKey, TThenKey>(
        Expression<Func<TEntity, bool>> predicate,
        Expression<Func<TEntity, TOrderKey>> orderBy = null,
        Expression<Func<TEntity, TOrderKey>> orderByDescending = null,
        Expression<Func<TEntity, TThenKey>> thenBy = null,
        Expression<Func<TEntity, TThenKey>> thenByDescending = null,
        CancellationToken cancellationToken = default(CancellationToken));

    /// <summary>Stages <paramref name="entity"/> for insertion.</summary>
    void Insert(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> for update.</summary>
    void Update(TEntity entity);

    /// <summary>Stages <paramref name="entity"/> for deletion.</summary>
    void Delete(TEntity entity);

    /// <summary>Finds the entity by key and stages it for deletion; false when not found.</summary>
    Task<bool> DeleteAsync(params object[] keyValues);

    /// <summary>Finds the entity by key and stages it for deletion, observing <paramref name="cancellationToken"/>; false when not found.</summary>
    Task<bool> DeleteAsync(CancellationToken cancellationToken, params object[] keyValues);
}
/// <summary>
/// EF Core implementation of <see cref="IEntityRepository{TEntity}"/>. Write operations set the
/// entity's ObjectState, attach it to the set and push that state into the change tracker;
/// nothing is persisted until the context is saved.
/// </summary>
/// <typeparam name="TEntity">Entity type managed by the repository.</typeparam>
public class EntityRepository<TEntity> : IEntityRepository<TEntity> where TEntity : class, IObjectState
{
    readonly IDataContext _context;
    readonly DbSet<TEntity> _set;

    /// <summary>Creates a repository over <paramref name="context"/>'s set for <typeparamref name="TEntity"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public EntityRepository(IDataContext context)
    {
        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }
        _context = context;
        _set = context.Set<TEntity>();
    }

    /// <summary>Stages <paramref name="entity"/> for deletion.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public void Delete(TEntity entity) => Track(entity, ObjectState.Deleted);

    /// <summary>Finds the entity by key and stages it for deletion.</summary>
    public Task<bool> DeleteAsync(params object[] keyValues) => DeleteAsync(CancellationToken.None, keyValues);

    /// <summary>Finds the entity by key and stages it for deletion.</summary>
    /// <returns>false when no entity has the given key; otherwise true.</returns>
    public async Task<bool> DeleteAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        // EF Core's overload is FindAsync(object[], CancellationToken). The former call
        // FindAsync(cancellationToken, keyValues) bound to FindAsync(params object[]) and passed
        // the token and the key array as two key values.
        var entity = await _set.FindAsync(keyValues, cancellationToken).ConfigureAwait(false);
        if (entity == null)
        {
            return false;
        }
        Delete(entity);
        return true;
    }

    /// <summary>
    /// Queries the set. <paramref name="predicate"/> is applied when non-null. The primary order is
    /// <paramref name="orderBy"/>, or <paramref name="orderByDescending"/> when orderBy is null. The
    /// secondary order is <paramref name="thenBy"/>, or <paramref name="thenByDescending"/> when thenBy
    /// is null. Secondary selectors are ignored when no primary ordering is given.
    /// </summary>
    public Task<List<TEntity>> FindAsync<TOrderKey, TThenKey>(
        Expression<Func<TEntity, bool>> predicate,
        Expression<Func<TEntity, TOrderKey>> orderBy = null,
        Expression<Func<TEntity, TOrderKey>> orderByDescending = null,
        Expression<Func<TEntity, TThenKey>> thenBy = null,
        Expression<Func<TEntity, TThenKey>> thenByDescending = null,
        CancellationToken cancellationToken = default(CancellationToken))
    {
        IQueryable<TEntity> query = _set;
        if (predicate != null)
        {
            query = query.Where(predicate);
        }

        // Keep the ordered query typed as IOrderedQueryable so ThenBy needs no cast.
        IOrderedQueryable<TEntity> ordered = null;
        if (orderBy != null)
        {
            ordered = query.OrderBy(orderBy);
        }
        else if (orderByDescending != null)
        {
            ordered = query.OrderByDescending(orderByDescending);
        }

        if (ordered != null)
        {
            if (thenBy != null)
            {
                ordered = ordered.ThenBy(thenBy);
            }
            else if (thenByDescending != null)
            {
                ordered = ordered.ThenByDescending(thenByDescending);
            }
            query = ordered;
        }

        return query.ToListAsync(cancellationToken);
    }

    /// <summary>Looks up an entity by primary-key values; null when not found.</summary>
    public Task<TEntity> FindByIdAsync(params object[] keyValues) => FindByIdAsync(CancellationToken.None, keyValues);

    /// <summary>Looks up an entity by primary-key values; null when not found.</summary>
    public async Task<TEntity> FindByIdAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        // Correct overload (see DeleteAsync). Awaiting also works where FindAsync returns a ValueTask.
        return await _set.FindAsync(keyValues, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>Stages <paramref name="entity"/> for insertion.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public void Insert(TEntity entity) => Track(entity, ObjectState.Added);

    /// <summary>Stages <paramref name="entity"/> for update.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public void Update(TEntity entity) => Track(entity, ObjectState.Modified);

    /// <summary>Shared write path: set ObjectState, attach, then push the state into the tracker.</summary>
    private void Track(TEntity entity, ObjectState state)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }
        entity.ObjectState = state;
        _set.Attach(entity);
        _context.SyncObjectState(entity);
    }
}
/// <summary>
/// Coordinates repositories that share one data context, plus saving and explicit transactions.
/// </summary>
public interface IUnitOfWork : IDisposable
{
    /// <summary>Returns the repository for <typeparamref name="TEntity"/>.</summary>
    IEntityRepository<TEntity> Repository<TEntity>() where TEntity : class, IObjectState;

    /// <summary>Starts a transaction with the given isolation level.</summary>
    void BeginTransaction(IsolationLevel isolationLevel = IsolationLevel.Unspecified);

    /// <summary>Asynchronously starts a transaction with the given isolation level.</summary>
    Task BeginTransactionAsync(IsolationLevel isolationLevel = IsolationLevel.Unspecified, CancellationToken cancellationToken = default(CancellationToken));

    /// <summary>Saves pending changes and returns the count reported by the context.</summary>
    int Save();

    /// <summary>Asynchronously saves pending changes and returns the count reported by the context.</summary>
    Task<int> SaveAsync(CancellationToken cancellationToken = default(CancellationToken));

    /// <summary>Commits the transaction started by BeginTransaction.</summary>
    void Commit();

    /// <summary>Rolls back the transaction started by BeginTransaction.</summary>
    void Rollback();
}
/// <summary>
/// Unit of work over a single <see cref="IDataContext"/>. It caches one repository per entity type,
/// saves through the context, and manages at most one explicit transaction at a time. Disposing
/// the unit of work also disposes the context.
/// </summary>
public class UnitOfWork : IUnitOfWork
{
    IDataContext _dataContext;
    IDbContextTransaction _transaction;
    // Keyed by the entity Type itself: the short type name collides for same-named types in
    // different namespaces.
    Dictionary<Type, object> _repositories;
    bool _disposed = false;

    /// <summary>Creates a unit of work over <paramref name="dataContext"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="dataContext"/> is null.</exception>
    public UnitOfWork(IDataContext dataContext)
    {
        if (dataContext == null)
        {
            throw new ArgumentNullException(nameof(dataContext));
        }
        _dataContext = dataContext;
        _repositories = new Dictionary<Type, object>();
    }

    /// <summary>Opens the connection and starts a transaction.</summary>
    /// <exception cref="InvalidOperationException">A transaction is already active.</exception>
    public void BeginTransaction(IsolationLevel isolationLevel = IsolationLevel.Unspecified)
    {
        EnsureNoActiveTransaction();
        _dataContext.Database.OpenConnection();
        _transaction = _dataContext.Database.BeginTransaction(isolationLevel);
    }

    /// <summary>Asynchronously opens the connection and starts a transaction.</summary>
    /// <exception cref="InvalidOperationException">A transaction is already active.</exception>
    public async Task BeginTransactionAsync(IsolationLevel isolationLevel = IsolationLevel.Unspecified, CancellationToken cancellationToken = default(CancellationToken))
    {
        EnsureNoActiveTransaction();
        await _dataContext.Database.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        _transaction = await _dataContext.Database.BeginTransactionAsync(isolationLevel, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>Commits the active transaction, then releases it and the connection opened for it.</summary>
    /// <exception cref="InvalidOperationException">No transaction is active.</exception>
    public void Commit()
    {
        GetActiveTransaction().Commit();
        // Only released on success, so a caller can still Rollback() after a failed commit.
        EndTransaction();
    }

    /// <summary>
    /// Rolls back the active transaction, releases it and its connection, then copies the
    /// change-tracker states back onto the tracked entities.
    /// </summary>
    /// <exception cref="InvalidOperationException">No transaction is active.</exception>
    public void Rollback()
    {
        GetActiveTransaction().Rollback();
        EndTransaction();
        _dataContext.SyncObjectsStatePostCommit();
    }

    /// <summary>Saves pending changes through the context.</summary>
    public int Save()
    {
        ThrowIfDisposed();
        return _dataContext.SaveChanges();
    }

    /// <summary>Asynchronously saves pending changes through the context.</summary>
    public Task<int> SaveAsync(CancellationToken cancellationToken = default(CancellationToken))
    {
        ThrowIfDisposed();
        return _dataContext.SaveChangesAsync(cancellationToken);
    }

    /// <summary>Returns the cached repository for <typeparamref name="TEntity"/>, creating it on first use.</summary>
    public IEntityRepository<TEntity> Repository<TEntity>() where TEntity : class, IObjectState
    {
        ThrowIfDisposed();
        object repository;
        if (!_repositories.TryGetValue(typeof(TEntity), out repository))
        {
            // The former code passed the open generic *interface* IEntityRepository<> to
            // Activator.CreateInstance, which always throws; construct the concrete type directly.
            repository = new EntityRepository<TEntity>(_dataContext);
            _repositories.Add(typeof(TEntity), repository);
        }
        return (IEntityRepository<TEntity>)repository;
    }

    /// <summary>Releases the transaction (if any) and the data context.</summary>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }
        if (disposing)
        {
            // Dispose the transaction before the context whose connection it uses.
            _transaction?.Dispose();
            _transaction = null;
            _dataContext?.Dispose();
            _dataContext = null;
            _repositories = null;
        }
        _disposed = true;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>Returns the active transaction or throws when there is none.</summary>
    private IDbContextTransaction GetActiveTransaction()
    {
        ThrowIfDisposed();
        if (_transaction == null)
        {
            throw new InvalidOperationException("No transaction is active; call BeginTransaction first.");
        }
        return _transaction;
    }

    /// <summary>Throws when a transaction is already active (before the connection is opened again).</summary>
    private void EnsureNoActiveTransaction()
    {
        ThrowIfDisposed();
        if (_transaction != null)
        {
            throw new InvalidOperationException("A transaction is already active; commit or roll it back first.");
        }
    }

    /// <summary>Disposes the finished transaction and closes the connection opened by BeginTransaction.</summary>
    private void EndTransaction()
    {
        _transaction.Dispose();
        _transaction = null;
        _dataContext.Database.CloseConnection();
    }

    /// <summary>Throws ObjectDisposedException once this instance has been disposed.</summary>
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(GetType().FullName);
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment