Skip to content

Instantly share code, notes, and snippets.

@peder
Last active March 7, 2017 23:56
Show Gist options
  • Save peder/145988296772e70e2f8f8d16d4dcafcb to your computer and use it in GitHub Desktop.
Save peder/145988296772e70e2f8f8d16d4dcafcb to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Web;
namespace Peder.ExtensionMethods
{
    /// <summary>
    /// Extension methods for Entity Framework <see cref="DbContext"/> instances.
    /// </summary>
    public static class EntityFrameworkApiExtensionMethods
    {
        /// <summary>
        /// Discards pending, unsaved changes tracked by <paramref name="dbContext"/>.
        /// Entity-level changes (added / modified / deleted entries) are rejected first,
        /// then added and deleted relationship entries are rolled back.
        /// </summary>
        /// <param name="dbContext">The context whose tracked changes are discarded.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is null.</exception>
        public static void DiscardChanges(this DbContext dbContext)
        {
            if (dbContext == null)
            {
                throw new ArgumentNullException("dbContext");
            }

            RejectScalarChanges(dbContext);
            RejectNavigationChanges(dbContext);
        }

        /// <summary>
        /// Moves every tracked entity entry back to Unchanged, or to Detached for entries
        /// that were Added.
        /// </summary>
        private static void RejectScalarChanges(DbContext dbContext)
        {
            // Snapshot first: the loop changes entry states, so it must not enumerate a
            // lazily evaluated sequence over the change tracker while doing so.
            var entries = dbContext.ChangeTracker.Entries().ToList();

            foreach (var entry in entries)
            {
                switch (entry.State)
                {
                    case EntityState.Modified:
                    case EntityState.Deleted:
                        // Deleted entries are routed through Modified before Unchanged; the
                        // original author notes this reverts changes made to a deleted entity.
                        entry.State = EntityState.Modified;
                        entry.State = EntityState.Unchanged;
                        break;
                    case EntityState.Added:
                        // Never-saved entities are simply no longer tracked.
                        entry.State = EntityState.Detached;
                        break;
                }
            }
        }

        /// <summary>
        /// Removes added relationship entries and restores deleted relationship entries
        /// to Unchanged, skipping relationships for which restoring is not safe.
        /// </summary>
        private static void RejectNavigationChanges(DbContext dbContext)
        {
            var stateManager = ((IObjectContextAdapter)dbContext).ObjectContext.ObjectStateManager;

            // Snapshot both relationship sets before mutating anything, so neither loop
            // enumerates a deferred query whose underlying state is being changed.
            var deletedRelationships = stateManager
                .GetObjectStateEntries(EntityState.Deleted)
                .Where(e => e.IsRelationship)
                .ToList();
            var addedRelationships = stateManager
                .GetObjectStateEntries(EntityState.Added)
                .Where(e => e.IsRelationship)
                .ToList();

            foreach (var relationship in addedRelationships)
            {
                relationship.Delete();
            }

            // The key-entry filter is evaluated once, after the added relationships are gone,
            // and materialized so it is not re-run while states are being changed below.
            var restorableRelationships = deletedRelationships
                .Where(e => e.State == EntityState.Deleted && !RelationshipContainsKeyEntry(e))
                .ToList();

            foreach (var relationship in restorableRelationships)
            {
                relationship.ChangeState(EntityState.Unchanged);
            }
        }

        /// <summary>
        /// Returns true when either end of the relationship <paramref name="stateEntry"/> is a
        /// key entry (a state entry with no entity object), or cannot be resolved to a state
        /// entry at all. In either case changing the relationship's state is not safe.
        /// </summary>
        /// <param name="stateEntry">A relationship entry; its two original values are the end keys.</param>
        private static bool RelationshipContainsKeyEntry(System.Data.Entity.Core.Objects.ObjectStateEntry stateEntry)
        {
            // Prevents the exception: "Cannot change state of a relationship if one of the ends
            // of the relationship is a KeyEntry". The exact conditions under which this occurs
            // are unknown, but it does happen.
            var stateManager = stateEntry.ObjectStateManager;

            for (var end = 0; end < 2; end++)
            {
                // Cast explicitly so the EntityKey overload of TryGetObjectStateEntry is used
                // rather than the object (entity) overload.
                var key = stateEntry.OriginalValues[end] as System.Data.Entity.Core.EntityKey;
                System.Data.Entity.Core.Objects.ObjectStateEntry endEntry;

                if (key == null
                    || !stateManager.TryGetObjectStateEntry(key, out endEntry)
                    || endEntry.Entity == null)
                {
                    return true;
                }
            }

            return false;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment