Skip to content

Instantly share code, notes, and snippets.

@neilgaietto
Forked from ondravondra/EFExtensions.cs
Last active July 18, 2018 20:36
Show Gist options
  • Save neilgaietto/a337095565301c6e55d1b87bb4f5e04a to your computer and use it in GitHub Desktop.
Save neilgaietto/a337095565301c6e55d1b87bb4f5e04a to your computer and use it in GitHub Desktop.
C# extension for executing an async upsert (SQL MERGE command) in EF with MSSQL. Supports multi-column keys.
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
namespace EFExtensions
{
/// <summary>
/// Extension methods that expose the entity operations of this namespace
/// directly on a <see cref="DbContext"/>.
/// </summary>
public static class DbContextExtensions
{
    /// <summary>
    /// Creates an upsert operation for <paramref name="entity"/>.
    /// Nothing is executed here; the returned operation is configured via
    /// its fluent methods and then executed by the caller.
    /// </summary>
    /// <typeparam name="TEntity">Entity type being upserted.</typeparam>
    /// <param name="context">Context handed to the operation.</param>
    /// <param name="entity">Entity instance handed to the operation.</param>
    /// <returns>A new <c>UpsertOp</c> wrapping the context and the entity.</returns>
    public static EntityOp<TEntity> Upsert<TEntity>(this DbContext context, TEntity entity)
        where TEntity : class
    {
        EntityOp<TEntity> operation = new UpsertOp<TEntity>(context, entity);
        return operation;
    }
}
/// <summary>
/// Base class for a fluent, single-entity database operation that yields a
/// result of type <typeparamref name="TRet"/>.
/// Holds the context, the entity, the table name taken from the entity's
/// <see cref="TableAttribute"/>, the key member names and the list of
/// properties excluded from the operation.
/// </summary>
/// <typeparam name="TEntity">Entity type; must be decorated with <see cref="TableAttribute"/>.</typeparam>
/// <typeparam name="TRet">Type produced by <see cref="Execute"/>.</typeparam>
public abstract class EntityOp<TEntity, TRet>
{
    /// <summary>Context supplied to the constructor.</summary>
    public readonly DbContext Context;

    /// <summary>Entity supplied to the constructor.</summary>
    public readonly TEntity Entity;

    /// <summary>Value of <see cref="TableAttribute.Name"/> declared on <typeparamref name="TEntity"/>.</summary>
    public readonly string TableName;

    private readonly List<string> keyNames = new List<string>();

    /// <summary>Member names registered through <see cref="Key{TKey}"/>, in registration order.</summary>
    public IEnumerable<string> KeyNames { get { return keyNames; } }

    private readonly List<string> excludeProperties = new List<string>();

    /// <summary>
    /// Extracts the member name(s) referenced by a selector lambda.
    /// Accepts a single member access (<c>x => x.Id</c>) or an anonymous-type
    /// projection (<c>x => new { x.A, x.B }</c>).
    /// </summary>
    /// <exception cref="ArgumentNullException">The selector is null.</exception>
    /// <exception cref="ArgumentException">The selector body has an unsupported shape.</exception>
    private static List<string> GetMemberNames<T>(Expression<Func<TEntity, T>> selectMemberLambda)
    {
        if (selectMemberLambda == null)
        {
            throw new ArgumentNullException("selectMemberLambda");
        }

        Expression body = selectMemberLambda.Body;

        // A selector typed wider than the member (e.g. Func<TEntity, object> over an
        // int property) wraps the member access in a Convert node; look through it.
        while (body.NodeType == ExpressionType.Convert || body.NodeType == ExpressionType.ConvertChecked)
        {
            body = ((UnaryExpression)body).Operand;
        }

        var member = body as MemberExpression;
        if (member != null)
        {
            return new List<string> { member.Member.Name };
        }

        // NewExpression.Members is null for a plain constructor call (non-anonymous
        // type); treat that as an unsupported selector instead of dereferencing null.
        var newMember = body as NewExpression;
        if (newMember != null && newMember.Members != null)
        {
            return newMember.Members.Select(x => x.Name).ToList();
        }

        throw new ArgumentException("The parameter selectMemberLambda must be a member accessing lambda such as x => x.Id", "selectMemberLambda");
    }

    /// <summary>
    /// Decides whether a property type is treated as a column value:
    /// value types (including nullable ones and enums), strings and byte arrays.
    /// Other reference types and interfaces are treated as non-columns.
    /// </summary>
    private static bool IsColumnType(Type type)
    {
        return type.IsValueType || type == typeof(string) || type == typeof(byte[]);
    }

    /// <summary>
    /// Stores the context and entity and resolves <see cref="TableName"/> from
    /// the <see cref="TableAttribute"/> declared directly on <typeparamref name="TEntity"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entity"/> is null.</exception>
    /// <exception cref="ArgumentException"><typeparamref name="TEntity"/> has no <see cref="TableAttribute"/>.</exception>
    public EntityOp(DbContext context, TEntity entity)
    {
        if (context == null)
        {
            throw new ArgumentNullException("context");
        }
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }

        Context = context;
        Entity = entity;

        // inherit: false - only an attribute declared on TEntity itself counts.
        object[] mappingAttrs = typeof(TEntity).GetCustomAttributes(typeof(TableAttribute), false);
        TableAttribute tableAttr = null;
        if (mappingAttrs.Length > 0)
        {
            tableAttr = mappingAttrs[0] as TableAttribute;
        }
        if (tableAttr == null)
        {
            throw new ArgumentException("TEntity is missing TableAttribute", "entity");
        }
        TableName = tableAttr.Name;
    }

    /// <summary>Executes the operation and returns its result.</summary>
    public abstract Task<TRet> Execute();

    /// <summary>Executes the operation and discards its result.</summary>
    public async Task Run()
    {
        await Execute();
    }

    /// <summary>
    /// Adds the member(s) selected by <paramref name="selectKey"/> to <see cref="KeyNames"/>.
    /// May be called repeatedly; names accumulate.
    /// </summary>
    /// <returns>This instance, for chaining.</returns>
    public EntityOp<TEntity, TRet> Key<TKey>(Expression<Func<TEntity, TKey>> selectKey)
    {
        keyNames.AddRange(GetMemberNames(selectKey));
        return this;
    }

    /// <summary>
    /// Excludes the member(s) selected by <paramref name="selectField"/> from
    /// <see cref="ColumnProperties"/>.
    /// </summary>
    /// <returns>This instance, for chaining.</returns>
    public EntityOp<TEntity, TRet> ExcludeField<TField>(Expression<Func<TEntity, TField>> selectField)
    {
        excludeProperties.AddRange(GetMemberNames(selectField));
        return this;
    }

    /// <summary>
    /// Public properties of <typeparamref name="TEntity"/> that take part in the
    /// operation: everything not excluded via <see cref="ExcludeField{TField}"/>,
    /// not marked <see cref="NotMappedAttribute"/>, and whose type is a value type,
    /// <see cref="string"/> or <c>byte[]</c>.
    /// </summary>
    public IEnumerable<PropertyInfo> ColumnProperties
    {
        get
        {
            // The previous filter (!PropertyType.IsClass) dropped string and byte[]
            // properties, because both are classes, and let interface-typed
            // properties through, because IsClass is false for interfaces.
            return typeof(TEntity).GetProperties().Where(pr =>
                !excludeProperties.Contains(pr.Name)
                && !pr.IsDefined(typeof(NotMappedAttribute), true)
                && IsColumnType(pr.PropertyType));
        }
    }
}
/// <summary>
/// Convenience base for operations that produce no meaningful result.
/// Derived classes implement <see cref="ExecuteNoRet"/>; <see cref="Execute"/>
/// awaits it and always reports <c>0</c>.
/// </summary>
/// <typeparam name="TEntity">Entity type the operation works on.</typeparam>
public abstract class EntityOp<TEntity> : EntityOp<TEntity, int>
{
    /// <summary>Forwards the context and the entity to the base operation.</summary>
    public EntityOp(DbContext context, TEntity entity)
        : base(context, entity)
    {
    }

    /// <summary>
    /// Runs <see cref="ExecuteNoRet"/> to completion and returns <c>0</c>.
    /// </summary>
    public sealed override async Task<int> Execute()
    {
        const int noResult = 0;

        Task pending = ExecuteNoRet();
        await pending;

        return noResult;
    }

    /// <summary>Performs the actual work of the operation.</summary>
    protected abstract Task ExecuteNoRet();
}
/// <summary>
/// Upsert implemented as a single T-SQL <c>MERGE</c> statement: the entity's
/// column values form a one-row source that is matched against the target
/// table on the configured key columns; a match updates the row, otherwise
/// a new row is inserted.
/// </summary>
/// <typeparam name="TEntity">Entity type being upserted.</typeparam>
public class UpsertOp<TEntity> : EntityOp<TEntity>
{
    /// <summary>Forwards the context and the entity to the base operation.</summary>
    public UpsertOp(DbContext context, TEntity entity) : base(context, entity) { }

    /// <summary>
    /// Builds the MERGE statement from <c>ColumnProperties</c> and <c>KeyNames</c>
    /// and executes it through <c>Context.Database.ExecuteSqlCommandAsync</c>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No key was registered, or no column property is left to write; either
    /// case would otherwise produce a syntactically invalid statement.
    /// </exception>
    protected override async Task ExecuteNoRet()
    {
        // Key names are bracket-quoted exactly like every other column reference
        // so that names colliding with reserved words work in the ON clause too.
        var keyColumns = KeyNames.Select(kn => "[" + kn + "]").ToList();
        if (keyColumns.Count == 0)
        {
            throw new InvalidOperationException("Upsert requires at least one key; call Key(...) before executing.");
        }

        int notNullFields = 0;
        var valueKeyList = new List<string>();
        var columnList = new List<string>();
        var valueList = new List<object>();
        foreach (var p in ColumnProperties)
        {
            columnList.Add(p.Name);
            var val = p.GetValue(Entity, null);
            if (val != null)
            {
                // Non-null values are passed as positional arguments: {0}, {1}, ...
                valueKeyList.Add("{" + (notNullFields++) + "}");
                valueList.Add(val);
            }
            else
            {
                // Nulls are written inline, so they do not consume a placeholder index.
                valueKeyList.Add("null");
            }
        }

        if (columnList.Count == 0)
        {
            throw new InvalidOperationException("Upsert has no columns to write for the entity.");
        }

        var columns = columnList.Select(x => "[" + x + "]").ToArray();

        StringBuilder sql = new StringBuilder();
        sql.Append("merge into [");
        sql.Append(TableName);
        sql.Append("] as T ");

        // Single-row source built from the entity's values.
        sql.Append("using (values (");
        sql.Append(string.Join(",", valueKeyList));
        sql.Append(")) as S (");
        sql.Append(string.Join(",", columns));
        sql.Append(") ");

        sql.Append("on (");
        sql.Append(string.Join(" and ", keyColumns.Select(kc => "T." + kc + "=S." + kc)));
        sql.Append(") ");

        sql.Append("when matched then update set ");
        sql.Append(string.Join(",", columns.Select(c => "T." + c + "=S." + c)));

        sql.Append(" when not matched then insert (");
        sql.Append(string.Join(",", columns));
        sql.Append(") values (S.");
        sql.Append(string.Join(",S.", columns));
        sql.Append(");");

        await Context.Database.ExecuteSqlCommandAsync(sql.ToString(), valueList.ToArray());
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment