Skip to content

Instantly share code, notes, and snippets.

@Zhentar
Created August 4, 2018 02:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Zhentar/eac2d9078860c29c58575e04fbe1deca to your computer and use it in GitHub Desktop.
Save Zhentar/eac2d9078860c29c58575e04fbe1deca to your computer and use it in GitHub Desktop.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
/// <summary>
/// Base for open-addressing hash tables using linear probing over a power-of-two bucket range
/// followed by a small overflow area, so a probe sequence never wraps around.
/// </summary>
/// <remarks>
/// A slot is treated as free when its value equals <c>default(TValue)</c>. An entry whose value
/// is (or is left) <c>default(TValue)</c> is therefore indistinguishable from a free slot: it is
/// skipped by enumeration, may be reused by another key, and is dropped on resize, although
/// <see cref="EntryCount"/> still counts it. Callers that obtain a slot from
/// <see cref="FindEntry{THandler}"/> with <c>insertIfNotFound</c> should store a non-default value
/// into it before touching the dictionary again, because a resize invalidates the returned
/// reference. No synchronization is performed.
/// </remarks>
public abstract class BaseDictionary<TKey, TValue>
{
    // 2^6 = 64 buckets, the same default initial size the derived dictionaries use.
    private const int DefaultPowerOf2Size = 6;
    // Largest supported exponent; 1 << 31 would overflow int.
    private const int MaxPowerOf2Size = 30;
    // Probe result meaning "key absent and no free slot inside the probe window".
    private const uint NoSlot = uint.MaxValue;

    /// <summary>Creates a table with 64 buckets.</summary>
    protected BaseDictionary() => Initialize(DefaultPowerOf2Size);

    /// <summary>Creates a table with at least <paramref name="initialSize"/> buckets, rounded up to a power of two.</summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="initialSize"/> is negative.</exception>
    protected BaseDictionary(int initialSize) => Initialize(PowerOf2SizeFor(initialSize));

    /// <summary>
    /// Hashing/equality strategy. It is passed as a constrained generic argument, presumably so the
    /// JIT can specialize and inline the calls for struct handlers.
    /// </summary>
    protected interface IKeyHandler<TKey>
    {
        uint MapKeyToBucket(TKey key);
        bool KeysEqual(TKey lhs, TKey rhs);
    }

    /// <summary>Key handler backed by <see cref="EqualityComparer{T}.Default"/>.</summary>
    protected struct EquatableKeyHandler : IKeyHandler<TKey>
    {
        public uint MapKeyToBucket(TKey key) => (uint)EqualityComparer<TKey>.Default.GetHashCode(key);
        public bool KeysEqual(TKey lhs, TKey rhs) => EqualityComparer<TKey>.Default.Equals(lhs, rhs);
    }

    /// <summary>One table slot; free when <see cref="Value"/> equals <c>default(TValue)</c>.</summary>
    private struct Entry
    {
        public TKey Key;
        public TValue Value;
    }

    private int _sizePowTwo;
    private uint _maxDisplacement;
    private uint _mask;
    private Entry[] _entries;
    // Scratch location handed out by FindEntry on a miss, so callers never alias a live slot.
    private TValue _missingValue;

    /// <summary>Replaces the table with an empty one of 2^<paramref name="powerOf2Size"/> buckets.</summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="powerOf2Size"/> is outside 0..30.</exception>
    protected void Initialize(int powerOf2Size)
    {
        if (powerOf2Size < 0 || powerOf2Size > MaxPowerOf2Size)
        {
            throw new ArgumentOutOfRangeException(nameof(powerOf2Size));
        }
        SetTable(powerOf2Size, new Entry[TableLength(powerOf2Size)]);
    }

    /// <summary>Number of keys inserted and not yet removed.</summary>
    protected uint EntryCount { get; private set; }

    /// <summary>Discards every entry, keeping the current table size.</summary>
    protected void RemoveAll()
    {
        Initialize(_sizePowTwo);
        EntryCount = 0;
    }

    /// <summary>Removes <paramref name="key"/> if present.</summary>
    /// <returns><c>true</c> if an entry was removed.</returns>
    protected bool RemoveEntry<THandler>(TKey key, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        if (!Probe(key, keyHandler, out uint hole))
        {
            return false;
        }

        // Backward-shift deletion: pull later members of the same run into the hole whenever that
        // keeps them at or after their home bucket, so no probe sequence is broken by the gap.
        Entry[] entries = _entries;
        uint mask = _mask;
        for (uint next = hole + 1; next < (uint)entries.Length; next++)
        {
            if (IsEmpty(entries[next].Value))
            {
                break;
            }
            uint home = keyHandler.MapKeyToBucket(entries[next].Key) & mask;
            if (home <= hole)
            {
                entries[hole] = entries[next];
                hole = next;
            }
        }
        // Clear key and value so the slot is free and no stale key stays referenced.
        entries[hole] = default;
        EntryCount--;
        return true;
    }

    /// <summary>Looks up <paramref name="key"/>, optionally reserving a slot for it when absent.</summary>
    /// <param name="insertIfNotFound">When the key is absent, claim a slot (growing the table if needed).</param>
    /// <param name="found"><c>true</c> if the key was already present.</param>
    /// <returns>
    /// A reference to the key's value slot. On a miss with <paramref name="insertIfNotFound"/> = <c>false</c>,
    /// a reference to a scratch location holding <c>default(TValue)</c>.
    /// </returns>
    protected ref TValue FindEntry<THandler>(TKey key, bool insertIfNotFound, out bool found, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        found = Probe(key, keyHandler, out uint slot);
        if (found)
        {
            return ref _entries[slot].Value;
        }
        if (!insertIfNotFound)
        {
            _missingValue = default;
            return ref _missingValue;
        }
        return ref InsertEntry(key, slot, keyHandler);
    }

    /// <summary>
    /// Scans the probe window of <paramref name="key"/>. Returns <c>true</c> with the key's absolute slot,
    /// or <c>false</c> with the first free slot in the window (or <see cref="NoSlot"/> if the window is full).
    /// </summary>
    private bool Probe<THandler>(TKey key, THandler keyHandler, out uint slot) where THandler : IKeyHandler<TKey>
    {
        uint bucket = keyHandler.MapKeyToBucket(key) & _mask;
        ref Entry bucketEntry = ref _entries[bucket];
        int window = (int)_maxDisplacement;
        for (int offset = 0; offset < window; offset++)
        {
            // bucket + offset < bucket count + window <= array length, so the unchecked access stays in bounds.
            ref Entry candidate = ref Unsafe.Add(ref bucketEntry, offset);
            // Check emptiness first: a free slot carries a stale or default key that must never match.
            if (IsEmpty(candidate.Value))
            {
                slot = bucket + (uint)offset;
                return false;
            }
            if (keyHandler.KeysEqual(candidate.Key, key))
            {
                slot = bucket + (uint)offset;
                return true;
            }
        }
        slot = NoSlot;
        return false;
    }

    /// <summary>Claims <paramref name="slot"/> for <paramref name="key"/>, growing the table first when it is <see cref="NoSlot"/>.</summary>
    private ref TValue InsertEntry<THandler>(TKey key, uint slot, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        if (slot == NoSlot)
        {
            Resize(keyHandler);
            return ref FindEntry(key, true, out _, keyHandler);
        }
        EntryCount++;
        ref Entry entry = ref _entries[slot];
        entry.Key = key;
        return ref entry.Value;
    }

    /// <summary>
    /// Grows the table until every live entry fits within its probe window. The current table is only
    /// replaced once a rehash fully succeeds.
    /// </summary>
    /// <exception cref="InvalidOperationException">The keys collide too much to fit at the maximum size.</exception>
    private void Resize<THandler>(THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        for (int powerOf2Size = _sizePowTwo + 1; powerOf2Size <= MaxPowerOf2Size; powerOf2Size++)
        {
            Entry[] grown = TryRehash(_entries, powerOf2Size, keyHandler);
            if (grown != null)
            {
                SetTable(powerOf2Size, grown);
                return;
            }
        }
        throw new InvalidOperationException("Too many keys collide for the table to hold them within its maximum size.");
    }

    /// <summary>Rehashes the live entries of <paramref name="source"/> into a new table, or returns <c>null</c> if one would exceed its probe window.</summary>
    private static Entry[] TryRehash<THandler>(Entry[] source, int powerOf2Size, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        var target = new Entry[TableLength(powerOf2Size)];
        uint mask = (1u << powerOf2Size) - 1;
        uint window = (uint)powerOf2Size;
        for (int i = 0; i < source.Length; i++)
        {
            if (IsEmpty(source[i].Value))
            {
                continue;
            }
            uint slot = keyHandler.MapKeyToBucket(source[i].Key) & mask;
            uint limit = slot + window;
            while (slot != limit && !IsEmpty(target[slot].Value))
            {
                slot++;
            }
            if (slot == limit)
            {
                return null;
            }
            target[slot] = source[i];
        }
        return target;
    }

    /// <summary>Smallest exponent (at least 1, at most 30) whose bucket count covers <paramref name="initialSize"/>.</summary>
    private static int PowerOf2SizeFor(int initialSize)
    {
        if (initialSize < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(initialSize), initialSize, "Initial size must not be negative.");
        }
        int powerOf2Size = 1;
        while (powerOf2Size < MaxPowerOf2Size && (1 << powerOf2Size) < initialSize)
        {
            powerOf2Size++;
        }
        return powerOf2Size;
    }

    // Buckets plus an overflow area as long as the probe window (plus one spare slot), so probes never wrap.
    private static int TableLength(int powerOf2Size) => (1 << powerOf2Size) + powerOf2Size + 1;

    private void SetTable(int powerOf2Size, Entry[] entries)
    {
        _sizePowTwo = powerOf2Size;
        // The probe window grows with log2 of the bucket count.
        _maxDisplacement = (uint)powerOf2Size;
        _mask = (1u << powerOf2Size) - 1;
        _entries = entries;
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool IsEmpty(TValue value) => EqualityComparer<TValue>.Default.Equals(value, default);

    /// <summary>
    /// Enumerates slots whose value is not <c>default(TValue)</c>, in table order. It reads the parent's
    /// current table on every step and has no modification check.
    /// </summary>
    public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>
    {
        private readonly BaseDictionary<TKey, TValue> _parent;
        private uint _index;

        internal Enumerator(BaseDictionary<TKey, TValue> parent)
        {
            Current = default;
            _index = 0;
            _parent = parent;
        }

        public bool MoveNext()
        {
            Entry[] entries = _parent._entries;
            for (uint index = _index; index < entries.Length; index++)
            {
                if (!EqualityComparer<TValue>.Default.Equals(entries[index].Value, default))
                {
                    Current = new KeyValuePair<TKey, TValue>(entries[index].Key, entries[index].Value);
                    _index = index + 1;
                    return true;
                }
            }
            // Park past the end so further MoveNext calls keep returning false.
            _index = uint.MaxValue;
            Current = new KeyValuePair<TKey, TValue>(default, default);
            return false;
        }

        /// <summary>Restarts enumeration from the first slot.</summary>
        public void Reset()
        {
            _index = 0;
            Current = default;
        }

        public KeyValuePair<TKey, TValue> Current { get; private set; }
        object IEnumerator.Current => Current;
        public void Dispose() { }
    }
}
/// <summary>
/// Caches the result of transforming source values, keyed by a key the handler derives from each source.
/// </summary>
/// <remarks>
/// A transform result equal to <c>default(TValue)</c> is stored in a slot the base table treats as free,
/// so it will not be served from the cache on later calls.
/// </remarks>
public class TransformationCache<TKey, TValue> : BaseDictionary<TKey, TValue> where TKey : IEquatable<TKey>
{
    /// <summary>Creates a cache with room for 64 entries before growing.</summary>
    // Explicitly allocates the table; without a ctor the base parameterless ctor may leave it unallocated.
    public TransformationCache() : this(64) { }

    /// <summary>Creates a cache sized for roughly <paramref name="initialSize"/> entries.</summary>
    public TransformationCache(int initialSize) : base(initialSize) { }

    /// <summary>Derives the cache key for a source value and produces the cached result.</summary>
    public interface ITransformHandler<in TSource>
    {
        TKey KeyForValue(TSource value);
        TValue Transform(TSource source);
    }

    /// <summary>
    /// Returns the cached result for <paramref name="sourceValue"/>'s key, computing and storing it on a miss.
    /// </summary>
    /// <param name="sourceValue">Value to look up or transform.</param>
    /// <param name="transformer">Supplies the key and the transform.</param>
    /// <returns>The cached or freshly transformed value.</returns>
    public TValue GetOrAdd<TSource, TTransformHandler>(TSource sourceValue, TTransformHandler transformer) where TTransformHandler : ITransformHandler<TSource>
    {
        TKey key = transformer.KeyForValue(sourceValue);
        TValue cached = FindEntry(key, false, out bool found, default(EquatableKeyHandler));
        if (found)
        {
            return cached;
        }

        // Transform before reserving a slot: a throwing transform leaves no counted, empty entry behind,
        // and a transform that touches this cache cannot invalidate a slot reference we are holding.
        TValue transformed = transformer.Transform(sourceValue);
        FindEntry(key, true, out _, default(EquatableKeyHandler)) = transformed;
        return transformed;
    }
}
/// <summary>
/// Per-key integer counters. A count of zero means the key is absent: keys whose count reaches zero
/// are removed and are not enumerated.
/// </summary>
public class CountingDictionary<TKey> : BaseDictionary<TKey, int>, IEnumerable<KeyValuePair<TKey, int>> where TKey : IEquatable<TKey>
{
    IEnumerator<KeyValuePair<TKey, int>> IEnumerable<KeyValuePair<TKey, int>>.GetEnumerator() => GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Enumerates keys with a non-zero count.</summary>
    public Enumerator GetEnumerator() => new Enumerator(this);

    /// <summary>Creates a dictionary sized for roughly <paramref name="initialSize"/> keys.</summary>
    public CountingDictionary(int initialSize = 64) : base(initialSize) { }

    /// <summary>Adds one to <paramref name="key"/>'s count.</summary>
    /// <returns>The new count.</returns>
    public int Increment(TKey key)
    {
        ref int count = ref FindEntry(key, true, out _, default(EquatableKeyHandler));
        if (count == -1)
        {
            // Reaching zero goes through a real removal; merely writing 0 would leave a "free" slot
            // that breaks probe runs and keeps the key in the entry count.
            RemoveEntry(key, default(EquatableKeyHandler));
            return 0;
        }
        return ++count;
    }

    /// <summary>Subtracts one from <paramref name="key"/>'s count (absent keys start at zero).</summary>
    /// <returns>The new count.</returns>
    public int Decrement(TKey key)
    {
        ref int count = ref FindEntry(key, true, out _, default(EquatableKeyHandler));
        if (count == 1)
        {
            RemoveEntry(key, default(EquatableKeyHandler));
            return 0;
        }
        return --count;
    }

    /// <summary>Gets or sets the count for <paramref name="key"/>; 0 for absent keys, and setting 0 removes the key.</summary>
    public int this[TKey key]
    {
        get
        {
            int count = FindEntry(key, false, out bool found, default(EquatableKeyHandler));
            return found ? count : 0;
        }
        set
        {
            if (value == 0)
            {
                RemoveEntry(key, default(EquatableKeyHandler));
                return;
            }
            FindEntry(key, true, out _, default(EquatableKeyHandler)) = value;
        }
    }
}
/// <summary>
/// <see cref="IDictionary{TKey, TValue}"/> implementation on top of <see cref="BaseDictionary{TKey, TValue}"/>.
/// </summary>
/// <remarks>
/// The indexer getter returns <c>default(TValue)</c> for a missing key rather than throwing.
/// The base table treats a slot holding <c>default(TValue)</c> as free, so storing default values
/// does not reliably keep a key present. Value factories must not modify this dictionary.
/// </remarks>
public class BoringButFastDictionary<TKey, TValue> : BaseDictionary<TKey, TValue>, IDictionary<TKey, TValue> where TKey : IEquatable<TKey>
{
    /// <summary>Creates a dictionary sized for roughly <paramref name="initialSize"/> entries.</summary>
    public BoringButFastDictionary(int initialSize = 64) : base(initialSize) { }

    /// <summary>
    /// Stores <paramref name="addValueFactory"/>'s result for a new key, or replaces an existing value with
    /// <paramref name="updateValueFactory"/>'s result.
    /// </summary>
    /// <returns>The value now stored for <paramref name="key"/>.</returns>
    public TValue AddOrUpdate(TKey key, Func<TKey, TValue> addValueFactory, Func<TKey, TValue, TValue> updateValueFactory)
    {
        ref TValue existing = ref FindEntry(key, false, out bool found, default(EquatableKeyHandler));
        if (found)
        {
            TValue updated = updateValueFactory(key, existing);
            existing = updated;
            return updated;
        }

        // Run the factory before reserving a slot so a throwing factory leaves the dictionary unchanged.
        TValue added = addValueFactory(key);
        FindEntry(key, true, out _, default(EquatableKeyHandler)) = added;
        return added;
    }

    /// <summary>Returns the value for <paramref name="key"/>, storing <paramref name="valueFactory"/>'s result if absent.</summary>
    public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
    {
        TValue existing = FindEntry(key, false, out bool found, default(EquatableKeyHandler));
        if (found)
        {
            return existing;
        }

        // Same exception-safety reasoning as AddOrUpdate: no slot is claimed until the value exists.
        TValue created = valueFactory(key);
        FindEntry(key, true, out _, default(EquatableKeyHandler)) = created;
        return created;
    }

    IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() => GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Enumerates entries whose value is not <c>default(TValue)</c>.</summary>
    public Enumerator GetEnumerator() => new Enumerator(this);

    /// <summary>Adds <paramref name="item"/>; throws like <see cref="Add(TKey, TValue)"/> on a duplicate key.</summary>
    /// <exception cref="ArgumentException">The key is already present.</exception>
    public void Add(KeyValuePair<TKey, TValue> item) => Add(item.Key, item.Value);

    /// <summary>Removes all entries.</summary>
    public void Clear() => RemoveAll();

    /// <summary>Whether <paramref name="item"/>'s key is present with an equal value.</summary>
    public bool Contains(KeyValuePair<TKey, TValue> item)
    {
        TValue value = FindEntry(item.Key, false, out bool found, default(EquatableKeyHandler));
        return found && EqualityComparer<TValue>.Default.Equals(item.Value, value);
    }

    /// <summary>Copies the enumerated entries into <paramref name="array"/> starting at <paramref name="arrayIndex"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="arrayIndex"/> is outside the array.</exception>
    /// <exception cref="ArgumentException">The destination is too small for <see cref="Count"/> entries.</exception>
    public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }
        if (arrayIndex < 0 || arrayIndex > array.Length)
        {
            throw new ArgumentOutOfRangeException(nameof(arrayIndex));
        }
        if (array.Length - arrayIndex < Count)
        {
            throw new ArgumentException("Destination array is not long enough to copy all the items in the collection.", nameof(array));
        }
        foreach (var kvp in this)
        {
            array[arrayIndex++] = kvp;
        }
    }

    /// <summary>Removes <paramref name="item"/>'s key only if it maps to an equal value.</summary>
    public bool Remove(KeyValuePair<TKey, TValue> item)
    {
        if (Contains(item))
        {
            RemoveEntry(item.Key, default(EquatableKeyHandler));
            return true;
        }
        return false;
    }

    public int Count => (int)EntryCount;
    public bool IsReadOnly => false;

    public bool ContainsKey(TKey key)
    {
        FindEntry(key, false, out bool found, default(EquatableKeyHandler));
        return found;
    }

    /// <summary>Adds a new entry.</summary>
    /// <exception cref="ArgumentException"><paramref name="key"/> is already present.</exception>
    public void Add(TKey key, TValue value)
    {
        ref TValue valueRef = ref FindEntry(key, true, out bool found, default(EquatableKeyHandler));
        if (found) { throw new ArgumentException("Duplicate key"); }
        valueRef = value;
    }

    public bool Remove(TKey key) => RemoveEntry(key, default(EquatableKeyHandler));

    public bool TryGetValue(TKey key, out TValue value)
    {
        TValue candidate = FindEntry(key, false, out bool found, default(EquatableKeyHandler));
        // Never surface whatever the lookup reference points at on a miss.
        value = found ? candidate : default;
        return found;
    }

    /// <summary>Gets the value (<c>default(TValue)</c> if absent) or inserts/overwrites it.</summary>
    public TValue this[TKey key]
    {
        get
        {
            TValue value = FindEntry(key, false, out bool found, default(EquatableKeyHandler));
            return found ? value : default;
        }
        set
        {
            ref TValue entry = ref FindEntry(key, true, out _, default(EquatableKeyHandler));
            entry = value;
        }
    }

    /// <summary>Snapshot of the enumerated keys.</summary>
    public ICollection<TKey> Keys => this.Select(kvp => kvp.Key).ToArray();

    /// <summary>Snapshot of the enumerated values.</summary>
    public ICollection<TValue> Values => this.Select(kvp => kvp.Value).ToArray();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment