Skip to content

Instantly share code, notes, and snippets.

@mjs3339
Last active June 5, 2018 00:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mjs3339/efbb60db2ecae4680f1e57637d196374 to your computer and use it in GitHub Desktop.
Save mjs3339/efbb60db2ecae4680f1e57637d196374 to your computer and use it in GitHub Desktop.
C# ConcurrentHashSet class: 20% faster than a ConcurrentDictionary with a dummy value
/// <summary>
/// A thread-safe set of <typeparamref name="T"/> values.
/// </summary>
/// <remarks>
/// <para>
/// Elements are stored in an array of singly linked bucket chains. The buckets are guarded by an
/// array of lock objects: bucket <c>b</c> is protected by lock <c>b % lockCount</c>.
/// <see cref="Add(T)"/> and <see cref="TryRemove(T)"/> take only the one lock covering their bucket.
/// <see cref="Contains(T)"/> and enumeration read the chains without taking any lock. Operations
/// that need a consistent view of the whole set take every lock: <see cref="Count"/>,
/// <see cref="IsEmpty"/>, <see cref="Clear"/>, <see cref="ToArray"/>, the indexer and resizing.
/// </para>
/// <para>
/// NOTE(review): the type is marked <c>[Serializable]</c>, but the nested <c>Tables</c> and
/// <c>Node</c> types are not. Binary serialization of an instance is therefore expected to fail.
/// Confirm whether the attribute is still wanted.
/// </para>
/// </remarks>
/// <typeparam name="T">The element type.</typeparam>
[Serializable]
public class ConcurrentHashSet<T> : IReadOnlyCollection<T>, ICollection<T>
{
    // Bucket count used by the parameterless/comparer-only constructors and by Clear().
    private const int DefaultCapacity = 16;

    // GrowTable stops doubling the lock array once it reaches this size.
    private const int MaxLockCount = 1024;

    // Upper bound on the bucket array length produced by GrowTable.
    private const int MaxBucketCount = int.MaxValue - 0x100000;

    // Used for both hashing and equality of elements.
    private readonly IEqualityComparer<T> _comparer;

    // When true, GrowTable also doubles the lock array (up to MaxLockCount locks).
    private readonly bool _growLockArray;

    // Per-lock element count which, once exceeded by an add, triggers GrowTable.
    private int _budget;

    // Current buckets, locks and per-lock counts. A new Tables instance is swapped in
    // by GrowTable and Clear; readers capture it once into a local.
    private volatile Tables _tables;

    /// <summary>
    /// Creates an empty set with the default concurrency level, <c>16</c> buckets,
    /// a lock array that may grow, and the default equality comparer.
    /// </summary>
    public ConcurrentHashSet()
        : this(DefaultConcurrencyLevel, DefaultCapacity, true, null)
    {
    }

    /// <summary>Creates an empty set with the default concurrency level and the given initial bucket count.</summary>
    /// <param name="capacity">Initial number of buckets (raised to the concurrency level if smaller).</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="capacity"/> is negative.</exception>
    public ConcurrentHashSet(int capacity)
        : this(DefaultConcurrencyLevel, capacity, false, null)
    {
    }

    /// <summary>Creates an empty set with the given concurrency level and initial bucket count.</summary>
    /// <param name="concurrencyLevel">Number of locks; must be at least 1.</param>
    /// <param name="capacity">Initial number of buckets (raised to <paramref name="concurrencyLevel"/> if smaller).</param>
    /// <exception cref="ArgumentOutOfRangeException">An argument is out of range.</exception>
    public ConcurrentHashSet(int concurrencyLevel, int capacity)
        : this(concurrencyLevel, capacity, false, null)
    {
    }

    /// <summary>Creates a set containing the elements of <paramref name="collection"/>.</summary>
    /// <param name="collection">Source elements; must not be null, contain null, or contain duplicates.</param>
    /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="collection"/> contains a null element or a duplicate.</exception>
    public ConcurrentHashSet(IEnumerable<T> collection)
        : this(collection, null)
    {
    }

    /// <summary>Creates an empty set that compares elements with <paramref name="comparer"/>.</summary>
    /// <param name="comparer">Equality comparer, or null for <see cref="EqualityComparer{T}.Default"/>.</param>
    public ConcurrentHashSet(IEqualityComparer<T> comparer)
        : this(DefaultConcurrencyLevel, DefaultCapacity, true, comparer)
    {
    }

    /// <summary>Creates a set containing the elements of <paramref name="collection"/>, using <paramref name="comparer"/>.</summary>
    /// <param name="collection">Source elements; must not be null, contain null, or contain duplicates.</param>
    /// <param name="comparer">Equality comparer, or null for <see cref="EqualityComparer{T}.Default"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="collection"/> contains a null element or a duplicate.</exception>
    public ConcurrentHashSet(IEnumerable<T> collection, IEqualityComparer<T> comparer)
        : this(comparer)
    {
        if (collection == null)
            throw new ArgumentNullException(nameof(collection), "Collection is null.");
        InitializeFromCollection(collection);
    }

    /// <summary>Creates a set with the given concurrency level containing the elements of <paramref name="collection"/>.</summary>
    /// <param name="concurrencyLevel">Number of locks; must be at least 1.</param>
    /// <param name="collection">Source elements; must not be null, contain null, or contain duplicates.</param>
    /// <param name="comparer">Equality comparer, or null for <see cref="EqualityComparer{T}.Default"/>.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="concurrencyLevel"/> is less than 1.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="collection"/> contains a null element or a duplicate.</exception>
    public ConcurrentHashSet(int concurrencyLevel, IEnumerable<T> collection, IEqualityComparer<T> comparer)
        : this(concurrencyLevel, DefaultCapacity, false, comparer)
    {
        if (collection == null)
            throw new ArgumentNullException(nameof(collection), "Collection is null.");
        InitializeFromCollection(collection);
    }

    /// <summary>Creates an empty set with the given concurrency level, initial bucket count and comparer.</summary>
    /// <param name="concurrencyLevel">Number of locks; must be at least 1.</param>
    /// <param name="capacity">Initial number of buckets (raised to <paramref name="concurrencyLevel"/> if smaller).</param>
    /// <param name="comparer">Equality comparer, or null for <see cref="EqualityComparer{T}.Default"/>.</param>
    /// <exception cref="ArgumentOutOfRangeException">An argument is out of range.</exception>
    public ConcurrentHashSet(int concurrencyLevel, int capacity, IEqualityComparer<T> comparer)
        : this(concurrencyLevel, capacity, false, comparer)
    {
    }

    // Primary constructor: validates arguments and allocates the initial tables.
    private ConcurrentHashSet(int concurrencyLevel, int capacity, bool growLockArray, IEqualityComparer<T> comparer)
    {
        if (concurrencyLevel < 1)
            throw new ArgumentOutOfRangeException(nameof(concurrencyLevel), concurrencyLevel,
                "Concurrency Level needs to be 1 or higher.");
        if (capacity < 0)
            throw new ArgumentOutOfRangeException(nameof(capacity), capacity,
                "Capacity needs to be zero or a positive number.");

        // Ensure the initial bucket array is at least as long as the lock array.
        if (capacity < concurrencyLevel)
            capacity = concurrencyLevel;

        var locks = new object[concurrencyLevel];
        for (var i = 0; i < locks.Length; i++)
            locks[i] = new object();

        var countPerLock = new int[locks.Length];
        var buckets = new Node[capacity];
        _tables = new Tables(buckets, locks, countPerLock);
        _growLockArray = growLockArray;
        _budget = buckets.Length / locks.Length;
        _comparer = comparer ?? EqualityComparer<T>.Default;
    }

    // Default number of locks: one per logical processor.
    private static int DefaultConcurrencyLevel => Environment.ProcessorCount;

    /// <summary>True when the set holds no elements. Takes every lock.</summary>
    public bool IsEmpty
    {
        get
        {
            var locksAcquired = 0;
            try
            {
                AcquireAllLocks(ref locksAcquired);
                return _tables.CountPerLock.All(c => c == 0);
            }
            finally
            {
                ReleaseLocks(0, locksAcquired);
            }
        }
    }

    /// <summary>
    /// Gets the element at position <paramref name="index"/> in the set's current enumeration order.
    /// </summary>
    /// <remarks>
    /// Takes every lock and walks the bucket chains, so this is O(n). Positions are only
    /// meaningful while the set is not being modified. Iterating <c>for (i = 0; i &lt; Count; i++)</c>
    /// over a concurrently changing set can skip or repeat elements, or go out of range.
    /// </remarks>
    /// <param name="index">Zero-based position; must be in <c>[0, Count)</c>.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="index"/> is negative or not less than the element count.</exception>
    public T this[int index]
    {
        get
        {
            var locksAcquired = 0;
            try
            {
                AcquireAllLocks(ref locksAcquired);
                var count = GetCountNoLocks();
                if (index < 0 || index >= count)
                    throw new ArgumentOutOfRangeException(nameof(index), index,
                        $"Getter: Index out of bounds {index} must be non-negative and less than {count}");

                var remaining = index;
                foreach (var head in _tables.Buckets)
                {
                    for (var current = head; current != null; current = current.Next)
                    {
                        if (remaining == 0)
                            return current.Item;
                        remaining--;
                    }
                }

                // Unreachable while all locks are held: the per-lock counts equal the number of nodes.
                throw new InvalidOperationException("Per-lock counts do not match the bucket contents.");
            }
            finally
            {
                ReleaseLocks(0, locksAcquired);
            }
        }
    }

    /// <summary>
    /// Removes every element and resets the bucket array to <c>16</c> buckets. The existing lock array is kept.
    /// </summary>
    public void Clear()
    {
        var locksAcquired = 0;
        try
        {
            AcquireAllLocks(ref locksAcquired);
            var newTables = new Tables(new Node[DefaultCapacity], _tables.Locks, new int[_tables.CountPerLock.Length]);
            _tables = newTables;
            _budget = Math.Max(1, newTables.Buckets.Length / newTables.Locks.Length);
        }
        finally
        {
            ReleaseLocks(0, locksAcquired);
        }
    }

    /// <summary>Returns true if the set contains <paramref name="item"/>. Does not take any lock.</summary>
    /// <remarks>
    /// NOTE(review): always returns false for null. <see cref="Add(T)"/> does not reject null,
    /// so <c>Add(null)</c> can succeed while <c>Contains(null)</c> still reports false. Confirm the intended null policy.
    /// </remarks>
    public bool Contains(T item)
    {
        if (item == null) return false;
        var hashcode = _comparer.GetHashCode(item);

        // Capture the tables once so a concurrent resize cannot mix arrays mid-lookup.
        var tables = _tables;
        var bucketNumber = (hashcode & int.MaxValue) % tables.Buckets.Length;
        var current = Volatile.Read(ref tables.Buckets[bucketNumber]);
        while (current != null)
        {
            if (hashcode == current.Hashcode && _comparer.Equals(current.Item, item))
                return true;
            current = current.Next;
        }
        return false;
    }

    /// <inheritdoc/>
    void ICollection<T>.Add(T item)
    {
        Add(item);
    }

    /// <inheritdoc/>
    bool ICollection<T>.IsReadOnly => false;

    /// <summary>Copies every element into <paramref name="array"/> starting at <paramref name="arrayIndex"/>. Takes every lock.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="arrayIndex"/> is negative.</exception>
    /// <exception cref="ArgumentException">The destination does not have room for all elements.</exception>
    void ICollection<T>.CopyTo(T[] array, int arrayIndex)
    {
        if (array == null)
            throw new ArgumentNullException(nameof(array), "The array is null.");
        if (arrayIndex < 0)
            throw new ArgumentOutOfRangeException(nameof(arrayIndex), arrayIndex, "The array index is out of range.");

        var locksAcquired = 0;
        try
        {
            AcquireAllLocks(ref locksAcquired);
            var count = GetCountNoLocks();

            // Compare via subtraction so a large arrayIndex + count cannot overflow.
            if (array.Length - count < arrayIndex)
                throw new ArgumentException(
                    "The number of elements in the set is greater than the available space from arrayIndex to the end of the destination array.",
                    nameof(array));
            CopyTo(array, arrayIndex);
        }
        finally
        {
            ReleaseLocks(0, locksAcquired);
        }
    }

    /// <inheritdoc/>
    bool ICollection<T>.Remove(T item)
    {
        return TryRemove(item);
    }

    /// <summary>The number of elements in the set. Takes every lock.</summary>
    /// <exception cref="OverflowException">The total exceeds <see cref="int.MaxValue"/>.</exception>
    public int Count
    {
        get
        {
            var locksAcquired = 0;
            try
            {
                AcquireAllLocks(ref locksAcquired);
                return GetCountNoLocks();
            }
            finally
            {
                ReleaseLocks(0, locksAcquired);
            }
        }
    }

    /// <inheritdoc/>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>
    /// Enumerates the elements without taking any lock.
    /// </summary>
    /// <remarks>
    /// The bucket array is captured when enumeration begins. Elements added or removed concurrently
    /// may or may not be observed, and a resize after enumeration starts is not seen.
    /// </remarks>
    public IEnumerator<T> GetEnumerator()
    {
        var buckets = _tables.Buckets;
        for (var i = 0; i < buckets.Length; i++)
        {
            var current = Volatile.Read(ref buckets[i]);
            while (current != null)
            {
                yield return current.Item;
                current = current.Next;
            }
        }
    }

    // Bulk-loads a source collection during construction (no locks taken: the instance is not yet shared).
    // Rejects null elements and duplicates.
    private void InitializeFromCollection(IEnumerable<T> collection)
    {
        foreach (var item in collection)
        {
            if (item == null)
                throw new ArgumentException("Item in collection is null.", nameof(collection));
            if (!AddInternal(item, false))
                throw new ArgumentException("Source Contains Duplicates.", nameof(collection));
        }
        if (_budget != 0)
            return;
        _budget = _tables.Buckets.Length / _tables.Locks.Length;
    }

    /// <summary>Adds <paramref name="item"/> if it is not already present.</summary>
    /// <returns>True if the item was added; false if an equal item was already in the set.</returns>
    public bool Add(T item)
    {
        return AddInternal(item, true);
    }

    /// <summary>Removes <paramref name="item"/> if present. Takes only the lock covering its bucket.</summary>
    /// <returns>True if an equal item was found and removed; otherwise false.</returns>
    public bool TryRemove(T item)
    {
        var hashcode = _comparer.GetHashCode(item);
        while (true)
        {
            var tables = _tables;
            GetBucketAndlockNumber(hashcode, out var bucketNumber, out var lockNumber, tables.Buckets.Length,
                tables.Locks.Length);
            lock (tables.Locks[lockNumber])
            {
                // The table was replaced (resize/Clear) between capture and lock: retry against the new one.
                if (tables != _tables)
                    continue;

                Node previous = null;
                for (var current = tables.Buckets[bucketNumber]; current != null; current = current.Next)
                {
                    if (hashcode == current.Hashcode && _comparer.Equals(current.Item, item))
                    {
                        // Unlink; head removal uses Volatile.Write to publish to lock-free readers.
                        if (previous == null)
                            Volatile.Write(ref tables.Buckets[bucketNumber], current.Next);
                        else
                            previous.Next = current.Next;
                        tables.CountPerLock[lockNumber]--;
                        return true;
                    }
                    previous = current;
                }
            }
            return false;
        }
    }

    // Inserts item at the head of its bucket chain unless an equal item exists.
    // acquireLock=false is used only during construction. May trigger GrowTable after releasing the lock.
    private bool AddInternal(T item, bool acquireLock)
    {
        var hashcode = _comparer.GetHashCode(item);
        while (true)
        {
            var tables = _tables;
            GetBucketAndlockNumber(hashcode, out var bucketNumber, out var lockNumber, tables.Buckets.Length,
                tables.Locks.Length);
            var resize = false;
            var lockTaken = false;
            try
            {
                if (acquireLock)
                    Monitor.Enter(tables.Locks[lockNumber], ref lockTaken);

                // The table was replaced (resize/Clear) between capture and lock: retry against the new one.
                if (tables != _tables)
                    continue;

                for (var node = tables.Buckets[bucketNumber]; node != null; node = node.Next)
                    if (hashcode == node.Hashcode && _comparer.Equals(node.Item, item))
                        return false;

                // Publish the new head with Volatile.Write so lock-free readers see a fully built node.
                Volatile.Write(ref tables.Buckets[bucketNumber], new Node(item, hashcode, tables.Buckets[bucketNumber]));
                checked
                {
                    tables.CountPerLock[lockNumber]++;
                }
                if (tables.CountPerLock[lockNumber] > _budget)
                    resize = true;
            }
            finally
            {
                if (lockTaken)
                    Monitor.Exit(tables.Locks[lockNumber]);
            }

            // Grow outside the stripe lock; GrowTable acquires all locks itself.
            if (resize)
                GrowTable(tables);
            return true;
        }
    }

    // Maps a hash code to its bucket index and to the index of the lock guarding that bucket.
    private static void GetBucketAndlockNumber(int hashcode, out int bucketNumber, out int lockNumber, int bucketCount,
        int lockCount)
    {
        bucketNumber = (hashcode & int.MaxValue) % bucketCount;
        lockNumber = bucketNumber % lockCount;
    }

    // Rehashes into a larger bucket array (and optionally a larger lock array), or just raises
    // the per-lock budget when the table is still sparsely filled. 'tables' is the instance the
    // caller observed; if it has since been replaced, nothing is done.
    private void GrowTable(Tables tables)
    {
        var locksAcquired = 0;
        try
        {
            // Lock 0 is always taken first (here and in AcquireAllLocks), serializing whole-table operations.
            AcquireLocks(0, 1, ref locksAcquired);

            // Another thread already resized or cleared since the caller observed 'tables'.
            if (tables != _tables)
                return;

            // Read while holding only lock 0, so this total is approximate.
            var approxCount = tables.CountPerLock.Aggregate<int, long>(0, (sum, c) => sum + c);

            // Under 25% full overall: one stripe exceeding its budget suggests skewed hashing rather
            // than a full table, so double the budget (saturating at int.MaxValue) instead of resizing.
            if (approxCount < tables.Buckets.Length / 4)
            {
                _budget = 2 * _budget;
                if (_budget < 0)
                    _budget = int.MaxValue;
                return;
            }

            var newLength = 0;
            var maxTableSize = false;
            try
            {
                checked
                {
                    // Roughly double, then step by 2 until the length has no factor of 3, 5 or 7
                    // (intended to spread hash % length more evenly).
                    newLength = tables.Buckets.Length * 2 + 1;
                    while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0)
                        newLength += 2;
                    if (newLength > MaxBucketCount)
                        maxTableSize = true;
                }
            }
            catch (OverflowException)
            {
                maxTableSize = true;
            }

            // At the size cap, stop future budget-triggered growth.
            if (maxTableSize)
            {
                newLength = MaxBucketCount;
                _budget = int.MaxValue;
            }

            AcquireLocks(1, tables.Locks.Length, ref locksAcquired);

            // Optionally double the lock array; existing locks keep their slots so ReleaseLocks
            // (which indexes the current table's lock array) still releases the monitors held here.
            var newLocks = tables.Locks;
            if (_growLockArray && tables.Locks.Length < MaxLockCount)
            {
                newLocks = new object[tables.Locks.Length * 2];
                Array.Copy(tables.Locks, 0, newLocks, 0, tables.Locks.Length);
                for (var i = tables.Locks.Length; i < newLocks.Length; i++)
                    newLocks[i] = new object();
            }

            // Rehash every node into fresh Node instances so lock-free readers of the old
            // chains are not disturbed.
            var newBuckets = new Node[newLength];
            var newCountPerLock = new int[newLocks.Length];
            foreach (var head in tables.Buckets)
            {
                var current = head;
                while (current != null)
                {
                    var next = current.Next;
                    GetBucketAndlockNumber(current.Hashcode, out var newBucketNumber, out var newLockNumber,
                        newBuckets.Length, newLocks.Length);
                    newBuckets[newBucketNumber] = new Node(current.Item, current.Hashcode, newBuckets[newBucketNumber]);
                    checked
                    {
                        newCountPerLock[newLockNumber]++;
                    }
                    current = next;
                }
            }

            _budget = Math.Max(1, newBuckets.Length / newLocks.Length);
            _tables = new Tables(newBuckets, newLocks, newCountPerLock);
        }
        finally
        {
            ReleaseLocks(0, locksAcquired);
        }
    }

    // Acquires every lock of the current table, lock 0 first. Once lock 0 is held, no resize can
    // be swapping the lock array, so reading _tables.Locks.Length afterwards is stable.
    private void AcquireAllLocks(ref int locksAcquired)
    {
        AcquireLocks(0, 1, ref locksAcquired);
        AcquireLocks(1, _tables.Locks.Length, ref locksAcquired);
    }

    // Acquires locks [fromInclusive, toExclusive) of the current table in index order, incrementing
    // locksAcquired for each one actually taken so the caller's finally can release exactly those.
    private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired)
    {
        var locks = _tables.Locks;
        for (var i = fromInclusive; i < toExclusive; i++)
        {
            var lockTaken = false;
            try
            {
                Monitor.Enter(locks[i], ref lockTaken);
            }
            finally
            {
                if (lockTaken)
                    locksAcquired++;
            }
        }
    }

    // Releases locks [fromInclusive, toExclusive) of the current table.
    private void ReleaseLocks(int fromInclusive, int toExclusive)
    {
        for (var i = fromInclusive; i < toExclusive; i++)
            Monitor.Exit(_tables.Locks[i]);
    }

    // Sums the per-lock counts of the current table. Exact only while the caller holds all locks.
    // Checked arithmetic: overflow throws OverflowException instead of wrapping negative.
    private int GetCountNoLocks()
    {
        var count = 0;
        foreach (var perLock in _tables.CountPerLock)
            count = checked(count + perLock);
        return count;
    }

    // Copies all elements in enumeration order into array starting at index.
    // Caller must hold all locks and have verified the destination has room.
    private void CopyTo(T[] array, int index)
    {
        var buckets = _tables.Buckets;
        foreach (var head in buckets)
            for (var current = head; current != null; current = current.Next)
            {
                array[index] = current.Item;
                index++;
            }
    }

    /// <summary>Returns a snapshot of the elements as a new array. Takes every lock.</summary>
    /// <exception cref="OverflowException">The total exceeds <see cref="int.MaxValue"/>.</exception>
    public T[] ToArray()
    {
        var locksAcquired = 0;
        try
        {
            AcquireAllLocks(ref locksAcquired);
            var length = GetCountNoLocks();
            if (length == 0)
                return Array.Empty<T>();
            var array = new T[length];
            CopyTo(array, 0);
            return array;
        }
        finally
        {
            ReleaseLocks(0, locksAcquired);
        }
    }

    // Immutable-shape container for one generation of the hash table. CountPerLock[i] is the
    // number of elements in buckets guarded by Locks[i].
    private class Tables
    {
        public readonly Node[] Buckets;
        public readonly object[] Locks;
        public volatile int[] CountPerLock;

        public Tables(Node[] buckets, object[] locks, int[] countPerLock)
        {
            Buckets = buckets;
            Locks = locks;
            CountPerLock = countPerLock;
        }
    }

    // One entry in a bucket chain. Item and Hashcode are fixed; Next is volatile because
    // TryRemove relinks it while lock-free readers may be walking the chain.
    private class Node
    {
        public readonly int Hashcode;
        public readonly T Item;
        public volatile Node Next;

        public Node(T item, int hashcode, Node next)
        {
            Item = item;
            Hashcode = hashcode;
            Next = next;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment