Skip to content

Instantly share code, notes, and snippets.

Last active January 2, 2016 11:09
Show Gist options
  • Save knunery/8294501 to your computer and use it in GitHub Desktop.
Save knunery/8294501 to your computer and use it in GitHub Desktop.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using NHibernate;
using NHibernate.Collection;
using NHibernate.DebugHelpers;
using NHibernate.Engine;
using NHibernate.Loader;
using NHibernate.Persister.Collection;
using NHibernate.Type;
using NHibernate.Util;
namespace NHibernateSample
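// To enable this factory, add the following to your NHibernate configuration:
// configuration.Properties[NHibernate.Cfg.Environment.CollectionTypeFactoryClass]
//     = typeof(Net4CollectionTypeFactory).AssemblyQualifiedName;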
/// <summary>
/// Collection type factory that hands out <see cref="GenericSetType{T}"/> and
/// <see cref="GenericSortedSetType{T}"/> for set mappings instead of the types
/// produced by <see cref="DefaultCollectionTypeFactory"/>.
/// </summary>
public class Net4CollectionTypeFactory : DefaultCollectionTypeFactory
{
    /// <summary>
    /// Creates the collection type used for an unordered set mapping.
    /// </summary>
    /// <param name="role">The role the persistent collection is in.</param>
    /// <param name="propertyRef">Forwarded unchanged to <see cref="GenericSetType{T}"/>.</param>
    /// <param name="embedded">Not forwarded; the created type does not take it.</param>
    /// <returns>A new <see cref="GenericSetType{T}"/>.</returns>
    public override CollectionType Set<T>(string role, string propertyRef, bool embedded)
    {
        CollectionType setType = new GenericSetType<T>(role, propertyRef);
        return setType;
    }

    /// <summary>
    /// Creates the collection type used for a sorted set mapping.
    /// </summary>
    /// <param name="role">The role the persistent collection is in.</param>
    /// <param name="propertyRef">Forwarded unchanged to <see cref="GenericSortedSetType{T}"/>.</param>
    /// <param name="embedded">Not forwarded; the created type does not take it.</param>
    /// <param name="comparer">The comparer handed to the sorted set type.</param>
    /// <returns>A new <see cref="GenericSortedSetType{T}"/>.</returns>
    public override CollectionType SortedSet<T>(string role, string propertyRef, bool embedded, IComparer<T> comparer)
    {
        CollectionType sortedSetType = new GenericSortedSetType<T>(role, propertyRef, comparer);
        return sortedSetType;
    }
}
/// <summary>
/// A <see cref="GenericSetType{T}"/> whose backing collection is a
/// <see cref="SortedSet{T}"/> ordered by a caller-supplied comparer.
/// </summary>
public class GenericSortedSetType<T> : GenericSetType<T>
{
    private readonly IComparer<T> comparer;

    /// <summary>
    /// Initializes the type for a specific role and remembers the comparer.
    /// </summary>
    /// <param name="role">The role the persistent collection is in.</param>
    /// <param name="propertyRef">Passed through to the base set type.</param>
    /// <param name="comparer">The comparer used when instantiating the sorted set.</param>
    public GenericSortedSetType(string role, string propertyRef, IComparer<T> comparer)
        : base(role, propertyRef)
    {
        this.comparer = comparer;
    }

    /// <summary>
    /// The comparer supplied at construction time.
    /// </summary>
    public IComparer<T> Comparer
    {
        get { return comparer; }
    }

    /// <summary>
    /// Creates the empty backing collection.
    /// </summary>
    /// <param name="anticipatedSize">Ignored; the sorted set is created without a size hint.</param>
    /// <returns>A new, empty <see cref="SortedSet{T}"/> using <see cref="Comparer"/>.</returns>
    public override object Instantiate(int anticipatedSize)
    {
        var sorted = new SortedSet<T>(comparer);
        return sorted;
    }
}
/// <summary>
/// An <see cref="IType"/> that maps an <see cref="ISet{T}"/> collection
/// to the database.
/// </summary>
public class GenericSetType<T> : SetType
{
    /// <summary>
    /// Initializes a new instance of a <see cref="GenericSetType{T}"/> class for
    /// a specific role.
    /// </summary>
    /// <param name="role">The role the persistent collection is in.</param>
    /// <param name="propertyRef">The name of the property in the
    /// owner object containing the collection ID, or <see langword="null" /> if it is
    /// the primary key.</param>
    public GenericSetType(string role, string propertyRef)
        : base(role, propertyRef, false)
    {
    }

    /// <summary>
    /// The CLR type exposed by mapped properties: <see cref="ISet{T}"/>.
    /// </summary>
    public override Type ReturnedClass
    {
        get { return typeof (ISet<T>); }
    }

    /// <summary>
    /// Instantiates a new <see cref="IPersistentCollection"/> for the set.
    /// </summary>
    /// <param name="session">The current <see cref="ISessionImplementor"/> for the set.</param>
    /// <param name="persister">The current <see cref="ICollectionPersister" /> for the set (not used here).</param>
    /// <param name="key">The collection key (not used here).</param>
    /// <returns>A <see cref="PersistentGenericSet{T}"/> bound to <paramref name="session"/> with no underlying set yet.</returns>
    public override IPersistentCollection Instantiate(ISessionImplementor session, ICollectionPersister persister,
                                                      object key)
    {
        return new PersistentGenericSet<T>(session);
    }

    /// <summary>
    /// Wraps an <see cref="ISet{T}"/> in a <see cref="PersistentGenericSet{T}"/>.
    /// </summary>
    /// <param name="session">The <see cref="ISessionImplementor"/> for the collection to be a part of.</param>
    /// <param name="collection">The unwrapped collection. An <see cref="ISet{T}"/> is wrapped directly;
    /// any other <see cref="ICollection{T}"/> is first copied into a new <see cref="HashSet{T}"/>.</param>
    /// <returns>
    /// A <see cref="PersistentGenericSet{T}"/> that wraps the non NHibernate set.
    /// </returns>
    /// <exception cref="HibernateException">
    /// <paramref name="collection"/> is neither an <see cref="ISet{T}"/> nor an <see cref="ICollection{T}"/>.
    /// </exception>
    public override IPersistentCollection Wrap(ISessionImplementor session, object collection)
    {
        var set = collection as ISet<T>;
        if (set == null)
        {
            var stronglyTypedCollection = collection as ICollection<T>;
            if (stronglyTypedCollection == null)
            {
                throw new HibernateException(Role + " must be an implementation of ISet<T> or ICollection<T>");
            }

            // Not a set: copy the elements so set semantics (uniqueness) apply.
            set = new HashSet<T>(stronglyTypedCollection);
        }

        return new PersistentGenericSet<T>(session, set);
    }

    /// <summary>
    /// Creates the empty backing collection.
    /// </summary>
    /// <param name="anticipatedSize">Ignored; the hash set is created without a size hint.</param>
    /// <returns>A new, empty <see cref="HashSet{T}"/>.</returns>
    public override object Instantiate(int anticipatedSize)
    {
        return new HashSet<T>();
    }

    /// <summary>
    /// Removes every element from <paramref name="collection"/>.
    /// </summary>
    /// <param name="collection">The collection to empty; must be an <see cref="ISet{T}"/>.</param>
    protected override void Clear(object collection)
    {
        ((ISet<T>) collection).Clear();
    }

    /// <summary>
    /// Adds <paramref name="element"/> to <paramref name="collection"/>.
    /// </summary>
    /// <param name="collection">The target collection; must be an <see cref="ISet{T}"/>.</param>
    /// <param name="element">The element to add; must be convertible to <typeparamref name="T"/>.</param>
    protected override void Add(object collection, object element)
    {
        ((ISet<T>) collection).Add((T) element);
    }
}
/// <summary>
/// A persistent wrapper for an <see cref="ISet{T}"/>
/// </summary>
[DebuggerTypeProxy(typeof (CollectionProxy<>))]
public class PersistentGenericSet<T> : AbstractPersistentCollection, ISet<T>
{
    /// <summary>
    /// The <see cref="ISet{T}"/> that NHibernate is wrapping.
    /// </summary>
    protected ISet<T> set;

    /// <summary>
    /// A temporary list that holds the objects while the PersistentSet is being
    /// populated from the database.
    /// </summary>
    /// <remarks>
    /// This is necessary to ensure that the object being added to the PersistentSet doesn't
    /// have its <c>GetHashCode()</c> and <c>Equals()</c> methods called during the load
    /// process.
    /// </remarks>
    [NonSerialized] private IList<T> tempList;

    public PersistentGenericSet()
    {
        // needed for serialization
    }

    /// <summary>
    /// Constructor matching super.
    /// Instantiates a lazy set (the underlying set is un-initialized).
    /// </summary>
    /// <param name="session">The session to which this set will belong. </param>
    public PersistentGenericSet(ISessionImplementor session) : base(session)
    {
    }

    /// <summary>
    /// Instantiates a non-lazy set (the underlying set is constructed
    /// from the incoming set reference).
    /// </summary>
    /// <param name="session">The session to which this set will belong. </param>
    /// <param name="original">The underlying set data. </param>
    public PersistentGenericSet(ISessionImplementor session, ISet<T> original)
        : base(session)
    {
        // Sets can be just a view of a part of another collection.
        // do we need to copy it to be sure it won't be changing
        // underneath us?
        // ie. this.set.addAll(set);
        set = original;
        SetInitialized();
        IsDirectlyAccessible = true;
    }

    /// <summary>Always <c>false</c>: rows of a set are never updated in place.</summary>
    public override bool RowUpdatePossible
    {
        get { return false; }
    }

    /// <summary>Whether the underlying set currently holds no elements.</summary>
    public override bool Empty
    {
        get { return set.Count == 0; }
    }

    /// <summary>
    /// Whether the set is empty, answered from the cached size when
    /// <c>ReadSize()</c> reports one is available.
    /// </summary>
    public bool IsEmpty
    {
        get { return ReadSize() ? CachedSize == 0 : (set.Count == 0); }
    }

    /// <summary>Returns this instance.</summary>
    public object SyncRoot
    {
        get { return this; }
    }

    /// <summary>Always <c>false</c>.</summary>
    public bool IsSynchronized
    {
        get { return false; }
    }

    #region ISet<T> Members

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        Read();
        return set.GetEnumerator();
    }

    /// <summary>
    /// Determines whether <paramref name="o"/> is in the set, using the answer from
    /// <c>ReadElementExistence</c> when it provides one and the underlying set otherwise.
    /// </summary>
    public bool Contains(T o)
    {
        bool? exists = ReadElementExistence(o);
        return exists == null ? set.Contains(o) : exists.Value;
    }

    /// <summary>Copies the elements into <paramref name="array"/> starting at <paramref name="arrayIndex"/>.</summary>
    public void CopyTo(T[] array, int arrayIndex)
    {
        Read();
        set.CopyTo(array, arrayIndex);
    }

    /// <summary>
    /// Adds <paramref name="o"/> to the set, queueing the operation when the operation
    /// queue is enabled and the element is known not to exist yet.
    /// </summary>
    /// <returns><c>true</c> if the element was added (or queued for addition); otherwise <c>false</c>.</returns>
    public bool Add(T o)
    {
        bool? exists = IsOperationQueueEnabled ? ReadElementExistence(o) : null;
        if (!exists.HasValue)
        {
            Initialize(true);
            if (set.Add(o))
            {
                Dirty();
                return true;
            }
            return false;
        }

        if (exists.Value)
        {
            return false;
        }

        QueueOperation(new SimpleAddDelayedOperation(this, o));
        return true;
    }

    /// <summary>Adds every element of <paramref name="other"/>; marks the collection dirty if it grew.</summary>
    public void UnionWith(IEnumerable<T> other)
    {
        Initialize(true);
        int countBefore = set.Count;
        set.UnionWith(other);
        if (set.Count != countBefore)
        {
            Dirty();
        }
    }

    /// <summary>Keeps only elements also in <paramref name="other"/>; marks the collection dirty if it shrank.</summary>
    public void IntersectWith(IEnumerable<T> other)
    {
        Initialize(true);
        int countBefore = set.Count;
        set.IntersectWith(other);
        if (set.Count != countBefore)
        {
            Dirty();
        }
    }

    /// <summary>Removes every element of <paramref name="other"/>; marks the collection dirty if it shrank.</summary>
    public void ExceptWith(IEnumerable<T> other)
    {
        Initialize(true);
        int countBefore = set.Count;
        set.ExceptWith(other);
        if (set.Count != countBefore)
        {
            Dirty();
        }
    }

    /// <summary>
    /// Keeps the elements that are in exactly one of this set and <paramref name="other"/>.
    /// </summary>
    public void SymmetricExceptWith(IEnumerable<T> other)
    {
        Initialize(true);

        // Materialize once: the count alone cannot reveal a change here (one element may be
        // removed while another is added), but any non-empty input always alters the set.
        var items = other as ICollection<T> ?? other.ToList();
        if (items.Count == 0)
        {
            return;
        }

        set.SymmetricExceptWith(items);
        Dirty();
    }

    /// <summary>Determines whether this set is a subset of <paramref name="other"/>.</summary>
    public bool IsSubsetOf(IEnumerable<T> other)
    {
        Read();
        return set.IsSubsetOf(other);
    }

    /// <summary>Determines whether this set is a superset of <paramref name="other"/>.</summary>
    public bool IsSupersetOf(IEnumerable<T> other)
    {
        Read();
        return set.IsSupersetOf(other);
    }

    /// <summary>Determines whether this set is a proper superset of <paramref name="other"/>.</summary>
    public bool IsProperSupersetOf(IEnumerable<T> other)
    {
        Read();
        return set.IsProperSupersetOf(other);
    }

    /// <summary>Determines whether this set is a proper subset of <paramref name="other"/>.</summary>
    public bool IsProperSubsetOf(IEnumerable<T> other)
    {
        Read();
        return set.IsProperSubsetOf(other);
    }

    /// <summary>Determines whether this set and <paramref name="other"/> share at least one element.</summary>
    public bool Overlaps(IEnumerable<T> other)
    {
        Read();
        return set.Overlaps(other);
    }

    /// <summary>Determines whether this set and <paramref name="other"/> contain the same elements.</summary>
    public bool SetEquals(IEnumerable<T> other)
    {
        Read();
        return set.SetEquals(other);
    }

    /// <summary>
    /// Removes <paramref name="o"/> from the set, queueing the operation when the put
    /// queue is enabled and the element is known to exist.
    /// </summary>
    /// <returns><c>true</c> if the element was removed (or queued for removal); otherwise <c>false</c>.</returns>
    public bool Remove(T o)
    {
        bool? exists = PutQueueEnabled ? ReadElementExistence(o) : null;
        if (!exists.HasValue)
        {
            Initialize(true);
            if (set.Remove(o))
            {
                Dirty();
                return true;
            }
            return false;
        }

        if (exists.Value)
        {
            QueueOperation(new SimpleRemoveDelayedOperation(this, o));
            return true;
        }

        return false;
    }

    void ICollection<T>.Add(T item)
    {
        Add(item);
    }

    /// <summary>
    /// Removes all elements, queueing the clear when the clear queue is enabled.
    /// </summary>
    public void Clear()
    {
        if (ClearQueueEnabled)
        {
            QueueOperation(new ClearDelayedOperation(this));
        }
        else
        {
            Initialize(true);
            if (set.Count != 0)
            {
                set.Clear();
                Dirty();
            }
        }
    }

    /// <summary>
    /// The number of elements, answered from the cached size when
    /// <c>ReadSize()</c> reports one is available.
    /// </summary>
    public int Count
    {
        get { return ReadSize() ? CachedSize : set.Count; }
    }

    /// <summary>Always <c>false</c>.</summary>
    public bool IsReadOnly
    {
        get { return false; }
    }

    public IEnumerator GetEnumerator()
    {
        Read();
        return set.GetEnumerator();
    }

    #endregion

    #region DelayedOperations

    #region Nested type: ClearDelayedOperation

    /// <summary>Queued operation that clears the enclosing set when operated.</summary>
    protected sealed class ClearDelayedOperation : IDelayedOperation
    {
        private readonly PersistentGenericSet<T> enclosingInstance;

        public ClearDelayedOperation(PersistentGenericSet<T> enclosingInstance)
        {
            this.enclosingInstance = enclosingInstance;
        }

        #region IDelayedOperation Members

        public object AddedInstance
        {
            get { return null; }
        }

        public object Orphan
        {
            get { throw new NotSupportedException("queued clear cannot be used with orphan delete"); }
        }

        public void Operate()
        {
            enclosingInstance.set.Clear();
        }

        #endregion
    }

    #endregion

    #region Nested type: SimpleAddDelayedOperation

    /// <summary>Queued operation that adds one value to the enclosing set when operated.</summary>
    protected sealed class SimpleAddDelayedOperation : IDelayedOperation
    {
        private readonly PersistentGenericSet<T> enclosingInstance;
        private readonly T value;

        public SimpleAddDelayedOperation(PersistentGenericSet<T> enclosingInstance, T value)
        {
            this.enclosingInstance = enclosingInstance;
            this.value = value;
        }

        #region IDelayedOperation Members

        public object AddedInstance
        {
            get { return value; }
        }

        public object Orphan
        {
            get { return null; }
        }

        public void Operate()
        {
            enclosingInstance.set.Add(value);
        }

        #endregion
    }

    #endregion

    #region Nested type: SimpleRemoveDelayedOperation

    /// <summary>Queued operation that removes one value from the enclosing set when operated.</summary>
    protected sealed class SimpleRemoveDelayedOperation : IDelayedOperation
    {
        private readonly PersistentGenericSet<T> enclosingInstance;
        private readonly T value;

        public SimpleRemoveDelayedOperation(PersistentGenericSet<T> enclosingInstance, T value)
        {
            this.enclosingInstance = enclosingInstance;
            this.value = value;
        }

        #region IDelayedOperation Members

        public object AddedInstance
        {
            get { return null; }
        }

        public object Orphan
        {
            get { return value; }
        }

        public void Operate()
        {
            enclosingInstance.set.Remove(value);
        }

        #endregion
    }

    #endregion

    #endregion

    /// <summary>
    /// Builds a snapshot holding a deep copy of every element currently in the set.
    /// </summary>
    public override ICollection GetSnapshot(ICollectionPersister persister)
    {
        var entityMode = Session.EntityMode;
        var clonedSet = new SetSnapShot<T>(set.Count);
        foreach (object current in set)
        {
            object copied = persister.ElementType.DeepCopy(current, entityMode, persister.Factory);
            clonedSet.Add((T) copied);
        }
        return clonedSet;
    }

    /// <summary>
    /// Returns the snapshot elements considered orphaned relative to the current set.
    /// When either the set or the snapshot is empty, the copied snapshot is returned as is.
    /// </summary>
    public override ICollection GetOrphans(object snapshot, string entityName)
    {
        var sn = new SetSnapShot<object>((IEnumerable<object>) snapshot);
        if (set.Count == 0) return sn;
        if (((ICollection) sn).Count == 0) return sn;
        return GetOrphans(sn, set.ToArray(), entityName, Session);
    }

    /// <summary>
    /// Determines whether the set matches its stored snapshot: same element count, and every
    /// element has a snapshot counterpart that the element type does not report as dirty.
    /// </summary>
    public override bool EqualsSnapshot(ICollectionPersister persister)
    {
        var elementType = persister.ElementType;
        var snapshot = (ISetSnapshot<T>) GetSnapshot();
        if (((ICollection) snapshot).Count != set.Count)
        {
            return false;
        }

        foreach (T current in set)
        {
            T oldValue = snapshot[current];
            if (oldValue == null || elementType.IsDirty(oldValue, current, Session))
            {
                return false;
            }
        }
        return true;
    }

    public override bool IsSnapshotEmpty(object snapshot)
    {
        return ((ICollection) snapshot).Count == 0;
    }

    /// <summary>
    /// Creates the underlying set through the persister's collection type before loading.
    /// </summary>
    public override void BeforeInitialize(ICollectionPersister persister, int anticipatedSize)
    {
        set = (ISet<T>) persister.CollectionType.Instantiate(anticipatedSize);
    }

    /// <summary>
    /// Initializes this PersistentSet from the cached values.
    /// </summary>
    /// <param name="persister">The CollectionPersister to use to reassemble the PersistentSet.</param>
    /// <param name="disassembled">The disassembled PersistentSet.</param>
    /// <param name="owner">The owner object.</param>
    public override void InitializeFromCache(ICollectionPersister persister, object disassembled, object owner)
    {
        var array = (object[]) disassembled;
        int size = array.Length;
        BeforeInitialize(persister, size);
        for (int i = 0; i < size; i++)
        {
            var element = (T) persister.ElementType.Assemble(array[i], Session, owner);
            if (element != null)
            {
                set.Add(element);
            }
        }
        SetInitialized();
    }

    public override string ToString()
    {
        Read();
        return StringHelper.CollectionToString(set.ToArray());
    }

    /// <summary>
    /// Reads one element from the data reader and stores it in the temporary list
    /// created by <see cref="BeginRead"/>; null elements are not stored.
    /// </summary>
    public override object ReadFrom(IDataReader rs, ICollectionPersister role, ICollectionAliases descriptor, object owner)
    {
        var element = (T) role.ReadElement(rs, owner, descriptor.SuffixedElementAliases, Session);
        if (element != null)
        {
            tempList.Add(element);
        }
        return element;
    }

    /// <summary>
    /// Set up the temporary List that will be used in the EndRead()
    /// to fully create the set.
    /// </summary>
    public override void BeginRead()
    {
        base.BeginRead();
        tempList = new List<T>();
    }

    /// <summary>
    /// Takes the contents stored in the temporary list created during <c>BeginRead()</c>
    /// that was populated during <c>ReadFrom()</c> and write it to the underlying
    /// PersistentSet.
    /// </summary>
    public override bool EndRead(ICollectionPersister persister)
    {
        foreach (T item in tempList)
        {
            set.Add(item);
        }
        tempList = null;
        SetInitialized();
        return true;
    }

    public override IEnumerable Entries(ICollectionPersister persister)
    {
        return set;
    }

    /// <summary>
    /// Disassembles every element through the persister's element type into an object array.
    /// </summary>
    public override object Disassemble(ICollectionPersister persister)
    {
        var result = new object[set.Count];
        int i = 0;
        foreach (object obj in set)
        {
            result[i++] = persister.ElementType.Disassemble(obj, Session, null);
        }
        return result;
    }

    /// <summary>
    /// Collects the snapshot elements to delete: those no longer in the set, plus the old
    /// values of elements the element type reports as dirty.
    /// </summary>
    public override IEnumerable GetDeletes(ICollectionPersister persister, bool indexIsFormula)
    {
        IType elementType = persister.ElementType;
        var sn = (ISetSnapshot<T>) GetSnapshot();
        var deletes = new List<T>(((ICollection<T>) sn).Count);

        deletes.AddRange(sn.Where(obj => !set.Contains(obj)));

        deletes.AddRange(from obj in set
                         let oldValue = sn[obj]
                         where oldValue != null && elementType.IsDirty(obj, oldValue, Session)
                         select oldValue);

        return deletes;
    }

    public override bool NeedsInserting(object entry, int i, IType elemType)
    {
        var sn = (ISetSnapshot<T>) GetSnapshot();
        object oldKey = sn[(T) entry];
        // note that it might be better to iterate the snapshot but this is safe,
        // assuming the user implements equals() properly, as required by the PersistentSet
        // contract!
        return oldKey == null || elemType.IsDirty(oldKey, entry, Session);
    }

    public override bool NeedsUpdating(object entry, int i, IType elemType)
    {
        return false;
    }

    public override object GetIndex(object entry, int i, ICollectionPersister persister)
    {
        throw new NotSupportedException("Sets don't have indexes");
    }

    public override object GetElement(object entry)
    {
        return entry;
    }

    public override object GetSnapshotElement(object entry, int i)
    {
        throw new NotSupportedException("Sets don't support updating by element");
    }

    /// <summary>Publicly exposes the base class <c>Read()</c>.</summary>
    public new void Read()
    {
        base.Read();
    }

    /// <summary>
    /// Compares this set with another <see cref="ISet{T}"/> by content, ignoring
    /// enumeration order.
    /// </summary>
    public override bool Equals(object other)
    {
        var that = other as ISet<T>;
        if (that == null)
        {
            return false;
        }
        Read();
        return set.SetEquals(that);
    }

    public override int GetHashCode()
    {
        Read();
        return set.GetHashCode();
    }

    public override bool EntryExists(object entry, int i)
    {
        return true;
    }

    /// <summary>Whether <paramref name="collection"/> is the very instance this wrapper holds.</summary>
    public override bool IsWrapper(object collection)
    {
        return ReferenceEquals(set, collection);
    }

    public void CopyTo(Array array, int index)
    {
        // NH : we really need to initialize the set ?
        Read();
        T[] items = set.ToArray();
        Array.Copy(items, 0, array, index, items.Length);
    }

    #region Nested type: ISetSnapshot

    /// <summary>
    /// A snapshot collection that can return its stored counterpart of a given element.
    /// </summary>
    private interface ISetSnapshot<TItem> : ICollection<TItem>, ICollection
    {
        TItem this[TItem element] { get; }
    }

    #endregion

    #region Nested type: SetSnapShot

    /// <summary>
    /// List-backed <see cref="ISetSnapshot{TItem}"/>. Elements can be added but
    /// <see cref="Clear"/> and <see cref="Remove"/> throw.
    /// </summary>
    private class SetSnapShot<TItem> : ISetSnapshot<TItem>
    {
        private readonly List<TItem> elements;

        private SetSnapShot()
        {
            elements = new List<TItem>();
        }

        public SetSnapShot(int capacity)
        {
            elements = new List<TItem>(capacity);
        }

        public SetSnapShot(IEnumerable<TItem> collection)
        {
            elements = new List<TItem>(collection);
        }

        #region ISetSnapshot<TItem> Members

        public IEnumerator<TItem> GetEnumerator()
        {
            return elements.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public void Add(TItem item)
        {
            elements.Add(item);
        }

        public void Clear()
        {
            throw new InvalidOperationException();
        }

        public bool Contains(TItem item)
        {
            return elements.Contains(item);
        }

        public void CopyTo(TItem[] array, int arrayIndex)
        {
            elements.CopyTo(array, arrayIndex);
        }

        public bool Remove(TItem item)
        {
            throw new InvalidOperationException();
        }

        public void CopyTo(Array array, int index)
        {
            ((ICollection) elements).CopyTo(array, index);
        }

        int ICollection.Count
        {
            get { return elements.Count; }
        }

        public object SyncRoot
        {
            get { return ((ICollection) elements).SyncRoot; }
        }

        public bool IsSynchronized
        {
            get { return ((ICollection) elements).IsSynchronized; }
        }

        int ICollection<TItem>.Count
        {
            get { return elements.Count; }
        }

        public bool IsReadOnly
        {
            get { return ((ICollection<TItem>) elements).IsReadOnly; }
        }

        /// <summary>
        /// Returns the stored element equal to <paramref name="element"/>, or
        /// <c>default(TItem)</c> when the snapshot has no such element.
        /// </summary>
        public TItem this[TItem element]
        {
            get
            {
                int idx = elements.IndexOf(element);
                if (idx >= 0)
                {
                    return elements[idx];
                }
                return default(TItem);
            }
        }

        #endregion
    }

    #endregion
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment