Skip to content

Instantly share code, notes, and snippets.

@ondravondra
Created November 2, 2012 12:49
Show Gist options
  • Save ondravondra/4001192 to your computer and use it in GitHub Desktop.
Save ondravondra/4001192 to your computer and use it in GitHub Desktop.
C# extension for executing upsert (MERGE SQL command) in EF with MSSQL
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
namespace EFExtensions
{
public static class EFExtensions
{
    /// <summary>
    /// Creates an upsert (SQL Server MERGE) operation for <paramref name="entity"/>.
    /// Nothing is sent to the database until the returned operation's
    /// Execute() or Run() is called; key columns are chosen with Key(...).
    /// </summary>
    /// <param name="context">The EF context whose database connection executes the MERGE.</param>
    /// <param name="entity">The entity whose property values are upserted.</param>
    /// <returns>A configurable, not-yet-executed upsert operation.</returns>
    public static EntityOp<TEntity> Upsert<TEntity>(this DbContext context, TEntity entity) where TEntity : class
    {
        EntityOp<TEntity> operation = new UpsertOp<TEntity>(context, entity);
        return operation;
    }
}
/// <summary>
/// Base class for a single-entity database operation configured fluently
/// (Key / ExcludeField) and run with Execute() or Run().
/// The target table is read from the entity type's [Table] attribute.
/// </summary>
/// <typeparam name="TEntity">Entity type; must carry a <see cref="TableAttribute"/>.</typeparam>
/// <typeparam name="TRet">Result type returned by <see cref="Execute"/>.</typeparam>
public abstract class EntityOp<TEntity, TRet>
{
    public readonly DbContext Context;
    public readonly TEntity Entity;

    /// <summary>
    /// Table name from [Table]; prefixed with "Schema." when the attribute specifies a schema.
    /// </summary>
    public readonly string TableName;

    private readonly List<string> keyNames = new List<string>();

    /// <summary>Property names registered through <see cref="Key{TKey}"/>, in registration order.</summary>
    public IEnumerable<string> KeyNames { get { return keyNames; } }

    private readonly List<string> excludeProperties = new List<string>();

    /// <summary>
    /// Extracts the property name from a lambda of the form x => x.Member.
    /// </summary>
    /// <exception cref="ArgumentNullException">The lambda is null.</exception>
    /// <exception cref="ArgumentException">
    /// The lambda body is not a member accessed directly on the lambda parameter
    /// (e.g. x => x.Nav.Id or a captured variable), which would otherwise yield a wrong name.
    /// </exception>
    private static string GetMemberName<T>(Expression<Func<TEntity, T>> selectMemberLambda)
    {
        if (selectMemberLambda == null)
        {
            throw new ArgumentNullException("selectMemberLambda");
        }

        Expression body = selectMemberLambda.Body;

        // Boxing or widening (e.g. Key<object>(x => x.Id)) wraps the member access in a Convert node.
        var unary = body as UnaryExpression;
        if (unary != null && (unary.NodeType == ExpressionType.Convert || unary.NodeType == ExpressionType.ConvertChecked))
        {
            body = unary.Operand;
        }

        var member = body as MemberExpression;
        if (member == null || !(member.Expression is ParameterExpression))
        {
            throw new ArgumentException("The parameter selectMemberLambda must be a member accessing lambda such as x => x.Id", "selectMemberLambda");
        }
        return member.Member.Name;
    }

    /// <param name="context">Context used to execute the operation.</param>
    /// <param name="entity">Entity whose values the operation uses.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entity"/> is null.</exception>
    /// <exception cref="ArgumentException"><typeparamref name="TEntity"/> has no [Table] attribute.</exception>
    protected EntityOp(DbContext context, TEntity entity)
    {
        if (context == null)
        {
            throw new ArgumentNullException("context");
        }
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }

        Context = context;
        Entity = entity;

        TableAttribute tableAttr = typeof(TEntity)
            .GetCustomAttributes(typeof(TableAttribute), false)
            .OfType<TableAttribute>()
            .FirstOrDefault();
        if (tableAttr == null)
        {
            throw new ArgumentException("TEntity is missing TableAttribute", "entity");
        }

        // Without the schema prefix, [Table("X", Schema = "s")] would resolve to the default schema.
        TableName = string.IsNullOrEmpty(tableAttr.Schema)
            ? tableAttr.Name
            : tableAttr.Schema + "." + tableAttr.Name;
    }

    /// <summary>Performs the operation and returns its result.</summary>
    public abstract TRet Execute();

    /// <summary>Performs the operation, discarding the result.</summary>
    public void Run()
    {
        Execute();
    }

    /// <summary>Registers a property as part of the key used to match existing rows.</summary>
    /// <returns>This operation, for chaining.</returns>
    public EntityOp<TEntity, TRet> Key<TKey>(Expression<Func<TEntity, TKey>> selectKey)
    {
        keyNames.Add(GetMemberName(selectKey));
        return this;
    }

    /// <summary>Excludes a property from <see cref="ColumnProperties"/>.</summary>
    /// <returns>This operation, for chaining.</returns>
    public EntityOp<TEntity, TRet> ExcludeField<TField>(Expression<Func<TEntity, TField>> selectField)
    {
        excludeProperties.Add(GetMemberName(selectField));
        return this;
    }

    /// <summary>
    /// Public instance, readable, non-indexer properties of <typeparamref name="TEntity"/>
    /// that are not marked [NotMapped] and were not excluded via <see cref="ExcludeField{TField}"/>.
    /// </summary>
    public IEnumerable<PropertyInfo> ColumnProperties
    {
        get
        {
            // Static properties, indexers and getter-less properties cannot be columns,
            // and PropertyInfo.GetValue(entity, null) throws for the latter two.
            return typeof(TEntity)
                .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                .Where(pr => pr.CanRead
                    && pr.GetIndexParameters().Length == 0
                    && !Attribute.IsDefined(pr, typeof(NotMappedAttribute))
                    && !excludeProperties.Contains(pr.Name));
        }
    }
}
/// <summary>
/// Adapter for operations that have no meaningful result: derived classes implement
/// <see cref="ExecuteNoRet"/>, and <see cref="Execute"/> always reports 0.
/// </summary>
public abstract class EntityOp<TEntity> : EntityOp<TEntity, int>
{
    public EntityOp(DbContext context, TEntity entity)
        : base(context, entity)
    {
    }

    /// <summary>Performs the actual work of the operation.</summary>
    protected abstract void ExecuteNoRet();

    /// <summary>Runs <see cref="ExecuteNoRet"/> and returns the placeholder value 0.</summary>
    public sealed override int Execute()
    {
        const int noResult = 0;
        ExecuteNoRet();
        return noResult;
    }
}
/// <summary>
/// Upserts one entity with a single SQL Server MERGE statement. Rows are matched on the
/// properties registered with Key(...). A match updates the non-key columns; no match
/// inserts all columns.
/// </summary>
public class UpsertOp<TEntity> : EntityOp<TEntity>
{
    public UpsertOp(DbContext context, TEntity entity) : base(context, entity) { }

    /// <summary>Builds the MERGE statement and executes it through the context's database.</summary>
    /// <exception cref="InvalidOperationException">
    /// No key was registered, or a key property is not among the upserted columns.
    /// </exception>
    protected override void ExecuteNoRet()
    {
        object[] parameters;
        string sql = BuildMergeSql(out parameters);
        Context.Database.ExecuteSqlCommand(sql, parameters);
    }

    /// <summary>
    /// Produces the MERGE text with {n} placeholders, plus the matching parameter values.
    /// </summary>
    private string BuildMergeSql(out object[] parameters)
    {
        var keys = KeyNames.ToList();
        if (keys.Count == 0)
        {
            // Otherwise the statement would contain "on ()", which SQL Server rejects with an obscure syntax error.
            throw new InvalidOperationException("At least one key must be registered with Key(...) before executing an upsert.");
        }

        var columns = new List<string>();
        var valueKeys = new List<string>();
        var values = new List<object>();
        foreach (var p in ColumnProperties)
        {
            columns.Add(p.Name);
            var val = p.GetValue(Entity, null);
            if (val != null)
            {
                // Placeholder index equals the parameter's position in the values array.
                valueKeys.Add("{" + values.Count + "}");
                values.Add(val);
            }
            else
            {
                // NOTE(review): a literal NULL in VALUES is typed by SQL Server as int, which may not
                // convert implicitly to every target column type (e.g. uniqueidentifier) — confirm for your schema.
                valueKeys.Add("null");
            }
        }

        foreach (var key in keys)
        {
            if (!columns.Contains(key))
            {
                throw new InvalidOperationException("Key property '" + key + "' is not among the upserted columns.");
            }
        }

        var quotedColumns = columns.Select(QuoteName).ToArray();

        // Key columns are left out of UPDATE SET: the ON clause already guarantees T.key = S.key,
        // and assigning an identity key column is an error in SQL Server.
        var updateColumns = columns.Where(c => !keys.Contains(c)).Select(QuoteName).ToArray();

        var sql = new StringBuilder();
        sql.Append("merge into ").Append(TableName).Append(" as T ");
        sql.Append("using (values (").Append(string.Join(",", valueKeys.ToArray())).Append(")) as S (");
        sql.Append(string.Join(",", quotedColumns)).Append(") ");
        sql.Append("on (");
        sql.Append(string.Join(" and ", keys.Select(k => "T." + QuoteName(k) + "=S." + QuoteName(k)).ToArray()));
        sql.Append(") ");
        if (updateColumns.Length > 0)
        {
            sql.Append("when matched then update set ");
            sql.Append(string.Join(",", updateColumns.Select(c => "T." + c + "=S." + c).ToArray()));
            sql.Append(" ");
        }
        sql.Append("when not matched then insert (");
        sql.Append(string.Join(",", quotedColumns));
        sql.Append(") values (");
        sql.Append(string.Join(",", quotedColumns.Select(c => "S." + c).ToArray()));
        sql.Append(");");

        parameters = values.ToArray();
        return sql.ToString();
    }

    /// <summary>
    /// Bracket-quotes an identifier (like T-SQL QUOTENAME) so reserved words such as
    /// "join" can be used as column names.
    /// </summary>
    private static string QuoteName(string identifier)
    {
        return "[" + identifier.Replace("]", "]]") + "]";
    }
}
}
@mcshaz
Copy link

mcshaz commented May 28, 2017

There are further tweaks in this GitHub file that address a number of problems I encountered:

  • IEnumerables are upserted in a single transaction, rather than 1 transaction per item.
  • column name mapping is handled (i.e. property name is not necessarily the database column name)
  • property name is a reserved sql word (eg join)
  • mapped database properties only (to me, `virtual` concerns how inherited properties are handled and does not indicate whether a property is mapped)
  • default to upserting on primary key(s)
  • composite keys handled
  • assume update will not alter primary key(s) for the record
  • insert inserts keys which are not database generated
  • all of the above are extracted from the DbContext, and therefore it doesn't matter if using fluent API or Property annotations

@howardhee
Copy link

howardhee commented Jul 14, 2017

I encountered the error message "TEntity is missing TableAttribute Parameter name: entity" when running the code below. How do I solve it?

var test = new ShopData
{
     Amount = 999999,
     CreateDate = DateTime.Now,
     Product = 11
};

using (var ctx = new ShopEntities())
{
       var op = ctx.Upsert(test);
       op.Execute();
       ctx.SaveChanges();
}

@Arithmomaniac
Copy link

Arithmomaniac commented Nov 20, 2017

@mcshaz Nice. Word of caution: SQL Server accepts at most 2100 parameters per request, so you need to split the upsert into batches accordingly. Something like:

public override int Execute()
{
    //you can't have more than 2100 values in a query, so you need to execute the merge in batches
    var batchSize = 2100 / _propNames.Count;
    return _entityList.Batch(batchSize).Sum(Execute); //Batch comes from MoreLinq
}

private int Execute(IEnumerable<TEntity> entities)
{
    //old code, but parameterized
}

(If you want to fit more items at once, pass in the values as one giant XML or JSON string, and unpack it before the merge statement.)

@rickenberg
Copy link

rickenberg commented Nov 9, 2018

I encountered error message "TEntity is missing TableAttribute Parameter name: entity" by running code below. How do I solve it?

var test = new ShopData
{
     Amount = 999999,
     CreateDate = DateTime.Now,
     Product = 11
};

using (var ctx = new ShopEntities())
{
       var op = ctx.Upsert(test);
       op.Execute();
       ctx.SaveChanges();
}

@howardhee - You can solve it by adding a Table attribute to your ShopData class, e.g. `[Table("YourTableName")]` from System.ComponentModel.DataAnnotations.Schema.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment