Skip to content

Instantly share code, notes, and snippets.

@dasch88
Forked from ondravondra/EFExtensions.cs
Last active September 4, 2018 01:02
Show Gist options
  • Save dasch88/c9d825048f958c7758ef69b09959a180 to your computer and use it in GitHub Desktop.
Save dasch88/c9d825048f958c7758ef69b09959a180 to your computer and use it in GitHub Desktop.
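C# extension for executing an upsert (SQL MERGE command) in Entity Framework against Microsoft SQL Server. It automatically retrieves the key columns from the entity's [Key] attributes and, for an IEnumerable, combines the per-entity MERGE statements into a single database command.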
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
namespace EFExtentions {
/// <summary>
/// Extension entry points that build upsert (T-SQL <c>MERGE</c>) operations on a <see cref="DbContext"/>.
/// The returned operation does nothing until <c>ExecuteAsync</c> or <c>RunAsync</c> is awaited.
/// </summary>
public static class UpsertExtension {
    /// <summary>
    /// Creates an upsert operation that inserts or overwrites every item of <paramref name="entity"/>.
    /// </summary>
    /// <param name="context">Context whose database connection runs the command.</param>
    /// <param name="entity">Entities to upsert.</param>
    /// <returns>An unexecuted upsert operation.</returns>
    public static EntityOp<TEntity> Upsert<TEntity>(this DbContext context, IEnumerable<TEntity> entity) where TEntity : class {
        EntityOp<TEntity> operation = new UpsertOp<TEntity>(context, entity);
        return operation;
    }

    /// <summary>
    /// Creates an upsert operation in which an existing row is only overwritten when the entity's
    /// value for <paramref name="dateVersionField"/> is greater than the stored value.
    /// </summary>
    /// <param name="context">Context whose database connection runs the command.</param>
    /// <param name="entity">Entities to upsert.</param>
    /// <param name="dateVersionField">Selects the <see cref="DateTime"/> property used as a version stamp.</param>
    /// <returns>An unexecuted upsert operation.</returns>
    public static EntityOp<TEntity> Upsert<TEntity>(this DbContext context, IEnumerable<TEntity> entity, Expression<Func<TEntity, DateTime>> dateVersionField) where TEntity : class {
        EntityOp<TEntity> operation = new UpsertOp<TEntity>(context, entity, dateVersionField);
        return operation;
    }
}
/// <summary>
/// Base class for a database operation over a list of <typeparamref name="TEntity"/> instances.
/// The constructor resolves the target table from <see cref="TableAttribute"/> and the key columns
/// from the properties marked with <see cref="KeyAttribute"/>.
/// </summary>
/// <typeparam name="TEntity">Entity type; must carry a <see cref="TableAttribute"/> and at least one
/// property marked with <see cref="KeyAttribute"/>.</typeparam>
/// <typeparam name="TRet">Result type of <see cref="ExecuteAsync"/>.</typeparam>
public abstract class EntityOp<TEntity, TRet> {
    public readonly DbContext Context;
    public readonly IEnumerable<TEntity> EntityList;

    /// <summary>
    /// Table name taken from <see cref="TableAttribute.Name"/>, prefixed with
    /// "<c>schema.</c>" when <see cref="TableAttribute.Schema"/> is set.
    /// </summary>
    public readonly string TableName;

    private readonly List<string> keyNames = new List<string>();

    /// <summary>Names of the properties marked with <see cref="KeyAttribute"/> (read-only view).</summary>
    public IEnumerable<string> KeyNames { get { return keyNames.AsReadOnly(); } }

    private readonly List<string> excludeProperties = new List<string>();

    /// <summary>Returns the name of the member accessed by a lambda such as <c>x =&gt; x.Id</c>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="selectMemberLambda"/> is null.</exception>
    /// <exception cref="ArgumentException">The lambda body is not a member access.</exception>
    protected static string GetMemberName<T>(Expression<Func<TEntity, T>> selectMemberLambda) {
        if (selectMemberLambda == null)
            throw new ArgumentNullException("selectMemberLambda");

        Expression body = selectMemberLambda.Body;
        // When T differs from the member's own type (e.g. x => (object)x.Id), the compiler wraps the
        // member access in a Convert node; look through it to reach the member.
        var conversion = body as UnaryExpression;
        if (conversion != null && (conversion.NodeType == ExpressionType.Convert || conversion.NodeType == ExpressionType.ConvertChecked))
            body = conversion.Operand;

        var member = body as MemberExpression;
        if (member == null) {
            throw new ArgumentException("The parameter selectMemberLambda must be a member accessing labda such as x => x.Id", "selectMemberLambda");
        }
        return member.Member.Name;
    }

    /// <summary>Reads the table and key mapping of <typeparamref name="TEntity"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entityList"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <typeparamref name="TEntity"/> has no <see cref="TableAttribute"/> or no <see cref="KeyAttribute"/> property.
    /// </exception>
    protected EntityOp(DbContext context, IEnumerable<TEntity> entityList) {
        if (context == null)
            throw new ArgumentNullException("context");
        if (entityList == null)
            throw new ArgumentNullException("entityList");

        Context = context;
        EntityList = entityList;

        TableAttribute tableAttr = typeof(TEntity)
            .GetCustomAttributes(typeof(TableAttribute), false)
            .OfType<TableAttribute>()
            .FirstOrDefault();
        if (tableAttr == null)
            throw new ArgumentException("TEntity is missing TableAttribute", "entityList");

        // Without the schema prefix an entity mapped to e.g. [Table("Orders", Schema = "sales")]
        // would be written to the default schema's Orders table instead.
        TableName = string.IsNullOrEmpty(tableAttr.Schema)
            ? tableAttr.Name
            : tableAttr.Schema + "." + tableAttr.Name;

        foreach (PropertyInfo p in typeof(TEntity).GetProperties()) {
            if (p.GetCustomAttributes(typeof(KeyAttribute), false).Length > 0)
                keyNames.Add(p.Name);
        }
        // Key columns drive the MERGE match condition; without any, the generated SQL is invalid.
        if (keyNames.Count == 0)
            throw new ArgumentException("TEntity is missing KeyAttribute(s)", "entityList");
    }

    /// <summary>Runs the operation against <see cref="Context"/>.</summary>
    public abstract Task<TRet> ExecuteAsync();

    /// <summary>Runs the operation and discards its result.</summary>
    public async Task RunAsync() {
        await ExecuteAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Leaves the selected property out of <see cref="ColumnProperties"/>.
    /// Returns this instance so calls can be chained.
    /// </summary>
    public EntityOp<TEntity, TRet> ExcludeField<TField>(Expression<Func<TEntity, TField>> selectField) {
        excludeProperties.Add(GetMemberName(selectField));
        return this;
    }

    /// <summary>
    /// Public properties of <typeparamref name="TEntity"/> treated as table columns. This skips excluded
    /// fields, properties marked <see cref="NotMappedAttribute"/>, indexers and write-only properties.
    /// Re-evaluated on every access.
    /// </summary>
    public IEnumerable<PropertyInfo> ColumnProperties {
        get {
            return typeof(TEntity).GetProperties()
                .Where(pr => pr.CanRead && pr.GetIndexParameters().Length == 0)
                .Where(pr => !pr.IsDefined(typeof(NotMappedAttribute), false))
                .Where(pr => !excludeProperties.Contains(pr.Name));
        }
    }
}
/// <summary>
/// Base for operations that produce no meaningful result. Subclasses implement
/// <see cref="ExecuteNoReturnAsync"/>; <see cref="ExecuteAsync"/> awaits it and then yields 0.
/// </summary>
public abstract class EntityOp<TEntity> : EntityOp<TEntity, int> {
    // Abstract type: only derived classes can construct it, so the constructor is protected.
    protected EntityOp(DbContext context, IEnumerable<TEntity> entityList) : base(context, entityList) { }

    /// <summary>Runs <see cref="ExecuteNoReturnAsync"/> and always returns 0.</summary>
    public sealed override async Task<int> ExecuteAsync() {
        // Library code: do not resume on the caller's synchronization context, which avoids
        // deadlocks when a caller blocks on the returned task.
        await ExecuteNoReturnAsync().ConfigureAwait(false);
        return 0;
    }

    /// <summary>Performs the actual work of the operation.</summary>
    protected abstract Task ExecuteNoReturnAsync();
}
/// <summary>
/// Inserts or updates every entity in the list using one T-SQL <c>MERGE</c> statement per entity.
/// All statements are sent to the database in a single command.
/// </summary>
/// <remarks>
/// Rows are matched on the <see cref="KeyAttribute"/> columns. Key columns are used only for matching
/// and inserting, never in the <c>UPDATE SET</c> list. When a date-version field is supplied, a matched
/// row is only updated if the entity's value is greater than the value stored in the table.
/// Each column value becomes one command parameter, plus one per entity for the version check.
/// NOTE(review): SQL Server caps a command at 2100 parameters, so very large lists presumably have
/// to be split by the caller.
/// </remarks>
public class UpsertOp<TEntity> : EntityOp<TEntity> {
    private readonly Expression<Func<TEntity, DateTime>> _dateVersionField;

    public UpsertOp(DbContext context, IEnumerable<TEntity> entityList) : base(context, entityList) { }

    /// <param name="context">Context whose database connection runs the command.</param>
    /// <param name="entityList">Entities to upsert.</param>
    /// <param name="dateVersionField">Selects a <see cref="DateTime"/> property. An existing row is only
    /// overwritten when the entity's value for it is greater than the stored value.</param>
    public UpsertOp(DbContext context, IEnumerable<TEntity> entityList, Expression<Func<TEntity, DateTime>> dateVersionField) : base(context, entityList) {
        _dateVersionField = dateVersionField;
    }

    /// <summary>Builds the MERGE batch for all entities and executes it as one command.</summary>
    /// <exception cref="InvalidOperationException">A key property was removed via <c>ExcludeField</c>.</exception>
    protected override async Task ExecuteNoReturnAsync() {
        // Materialize once so that the column names and the parameter values come from the same
        // property sequence, and reflection runs once rather than once per entity.
        PropertyInfo[] columns = ColumnProperties.ToArray();
        string[] keys = KeyNames.ToArray();

        // The ON clause references S.<key>; if a key is not a source column, the SQL fails with an
        // opaque error, so report it clearly instead.
        var columnNameSet = new HashSet<string>(columns.Select(c => c.Name));
        string missingKey = keys.FirstOrDefault(k => !columnNameSet.Contains(k));
        if (missingKey != null)
            throw new InvalidOperationException("Key property " + missingKey + " cannot be excluded from an upsert because it is used to match existing rows.");

        PropertyInfo versionProperty = FindDateVersionProperty();

        // Everything except the VALUES row and the version parameter is identical for every entity.
        var keySet = new HashSet<string>(keys);
        string columnList = string.Join(",", columns.Select(c => QuoteIdentifier(c.Name)));
        string sourceColumnList = string.Join(",", columns.Select(c => "S." + QuoteIdentifier(c.Name)));
        string matchCondition = string.Join(" and ", keys.Select(k => "T." + QuoteIdentifier(k) + "=S." + QuoteIdentifier(k)));
        // Keys already equal by the ON clause; rewriting them is pointless and fails for identity keys.
        string updateList = string.Join(",", columns
            .Where(c => !keySet.Contains(c.Name))
            .Select(c => "T." + QuoteIdentifier(c.Name) + "=S." + QuoteIdentifier(c.Name)));

        var sql = new StringBuilder();
        var parameters = new List<object>();
        foreach (TEntity entity in EntityList) {
            sql.Append("merge into ").Append(TableName).Append(" as T ");
            sql.Append("using (values (");
            AppendValuePlaceholders(entity, columns, sql, parameters);
            sql.Append(")) as S (").Append(columnList).Append(") ");
            sql.Append("on (").Append(matchCondition).Append(")");

            // With only key columns there is nothing to update, and an empty SET list is a syntax error.
            if (updateList.Length > 0) {
                sql.Append(" when matched");
                if (versionProperty != null) {
                    sql.Append(" and {").Append(parameters.Count).Append("} > T.").Append(QuoteIdentifier(versionProperty.Name));
                    parameters.Add(versionProperty.GetValue(entity, null));
                }
                sql.Append(" then update set ").Append(updateList);
            }

            sql.Append(" when not matched then insert (").Append(columnList)
               .Append(") values (").Append(sourceColumnList).Append(");");
            sql.AppendLine();
        }

        // No entities: an empty command cannot be executed, and upserting nothing is a no-op.
        if (sql.Length == 0)
            return;

        await Context.Database.ExecuteSqlCommandAsync(sql.ToString(), parameters.ToArray()).ConfigureAwait(false);
    }

    // Resolves the property selected by the date-version lambda, or null when no lambda was supplied
    // or no public property of TEntity has the selected name.
    private PropertyInfo FindDateVersionProperty() {
        if (_dateVersionField == null)
            return null;
        string name = GetMemberName(_dateVersionField);
        return typeof(TEntity).GetProperties().FirstOrDefault(p => p.Name == name);
    }

    // Appends one "{n}" placeholder per column (comma-separated) and records the matching value in
    // `parameters`; n is the value's index in that list. Null values are recorded as DBNull.Value.
    private static void AppendValuePlaceholders(TEntity entity, PropertyInfo[] columns, StringBuilder sql, List<object> parameters) {
        for (int i = 0; i < columns.Length; i++) {
            if (i > 0)
                sql.Append(',');
            sql.Append('{').Append(parameters.Count).Append('}');
            object value = columns[i].GetValue(entity, null);
            parameters.Add(value ?? DBNull.Value);
        }
    }

    // Wraps a column name in T-SQL brackets so reserved words (e.g. Key, Order, User) stay valid.
    private static string QuoteIdentifier(string name) {
        return "[" + name.Replace("]", "]]") + "]";
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment