Last active
November 22, 2020 17:11
-
-
Save sdcondon/43cfb8ed30f873817d4688c9664a9059 to your computer and use it in GitHub Desktop.
Generic prefix tree (trie), implementing IDictionary<IEnumerable<TKey>, TValue>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Linq; | |
/// <summary>
/// Prefix tree (trie) implementation of <see cref="IDictionary{TKey, TValue}"/>. Each key is a sequence of
/// elements, and keys that share a common prefix share the nodes that represent that prefix.
/// </summary>
/// <typeparam name="TKey">The type of the individual elements of the keys in the trie.</typeparam>
/// <typeparam name="TValue">The type of values in the trie.</typeparam>
/// <remarks>
/// <see cref="Count"/>, <see cref="Keys"/>, <see cref="Values"/>, <see cref="CopyTo"/> and enumeration all walk
/// the whole tree. Enumeration works over a list of the valued nodes captured when
/// <see cref="GetEnumerator"/> is called.
/// </remarks>
public sealed class Trie<TKey, TValue> : IDictionary<IEnumerable<TKey>, TValue>
{
    private readonly IEqualityComparer<TKey> _equalityComparer;

    /// <summary>
    /// Initializes a new instance of the <see cref="Trie{TKey, TValue}"/> class that uses the default equality
    /// comparer for key elements.
    /// </summary>
    public Trie()
        : this(EqualityComparer<TKey>.Default)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="Trie{TKey, TValue}"/> class that uses a given equality
    /// comparer for key elements.
    /// </summary>
    /// <param name="equalityComparer">
    /// The comparer used to match individual key elements, or <see langword="null"/> to use
    /// <see cref="EqualityComparer{T}.Default"/>.
    /// </param>
    public Trie(IEqualityComparer<TKey> equalityComparer)
    {
        // Resolve null once so every node in the tree is built with the same comparer instance.
        _equalityComparer = equalityComparer ?? EqualityComparer<TKey>.Default;
        Root = new Node(_equalityComparer, null, default(TKey));
    }

    /// <summary>
    /// Gets or sets the value associated with the given key. Setting a key that is not present adds it.
    /// </summary>
    /// <param name="key">The key (sequence of key elements).</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
    /// <exception cref="KeyNotFoundException">The key is being read and is not present.</exception>
    public TValue this[IEnumerable<TKey> key]
    {
        get
        {
            if (!TryGetValue(key, out TValue value))
            {
                throw new KeyNotFoundException();
            }

            return value;
        }

        set
        {
            if (key == null)
            {
                throw new ArgumentNullException(nameof(key));
            }

            // Walk (creating as needed) in a single pass so the key is enumerated only once; an existing
            // node - valued or purely interior - simply has its value (re)assigned.
            GetOrAddNode(key).Value = value;
        }
    }

    /// <summary>
    /// Gets the root node of the trie. The root represents the empty key.
    /// </summary>
    internal Node Root
    {
        get;
    }

    /// <inheritdoc />
    /// <remarks>Computed by walking the whole tree, so this is O(number of nodes).</remarks>
    public int Count
    {
        get
        {
            return GetValuedNodes().Count;
        }
    }

    /// <inheritdoc />
    public bool IsReadOnly
    {
        get
        {
            return false;
        }
    }

    /// <inheritdoc />
    /// <remarks>Returns a snapshot array; it does not track later changes to the trie.</remarks>
    public ICollection<IEnumerable<TKey>> Keys
    {
        get
        {
            return GetValuedNodes().Select(n => n.Key).ToArray();
        }
    }

    /// <inheritdoc />
    /// <remarks>Returns a snapshot array; it does not track later changes to the trie.</remarks>
    public ICollection<TValue> Values
    {
        get
        {
            return GetValuedNodes().Select(n => n.Value).ToArray();
        }
    }

    /// <inheritdoc />
    public void Add(KeyValuePair<IEnumerable<TKey>, TValue> item)
    {
        Add(item.Key, item.Value);
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentException">A value is already stored under <paramref name="key"/>.</exception>
    public void Add(IEnumerable<TKey> key, TValue value)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        Node node = GetOrAddNode(key);
        if (node.HasValue)
        {
            throw new ArgumentException("An element with the same key already exists", nameof(key));
        }

        node.Value = value;
    }

    /// <inheritdoc />
    public void Clear()
    {
        // The root holds the value for the empty key, so it has to be cleared as well as its children.
        Root.RemoveValue();
        Root.RemoveChildren();
    }

    /// <inheritdoc />
    /// <remarks>Values are compared with <see cref="EqualityComparer{T}.Default"/>, so null values are supported.</remarks>
    /// <exception cref="ArgumentNullException">The key of <paramref name="item"/> is <see langword="null"/>.</exception>
    public bool Contains(KeyValuePair<IEnumerable<TKey>, TValue> item)
    {
        if (item.Key == null)
        {
            throw new ArgumentNullException(nameof(item), "The key of the item cannot be null.");
        }

        return TryGetNode(item.Key, true, out Node node)
            && EqualityComparer<TValue>.Default.Equals(node.Value, item.Value);
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
    public bool ContainsKey(IEnumerable<TKey> key)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        return TryGetNode(key, true, out _);
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="array"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="arrayIndex"/> is outside the array.</exception>
    /// <exception cref="ArgumentException">The array has too little space after <paramref name="arrayIndex"/>.</exception>
    public void CopyTo(KeyValuePair<IEnumerable<TKey>, TValue>[] array, int arrayIndex)
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }

        if (arrayIndex < 0 || arrayIndex > array.Length)
        {
            throw new ArgumentOutOfRangeException(nameof(arrayIndex));
        }

        // Walk the tree once and reuse the result for both the space check and the copy.
        List<Node> nodes = GetValuedNodes();
        if (array.Length - arrayIndex < nodes.Count)
        {
            throw new ArgumentException("The destination array does not have enough space after the given index.", nameof(array));
        }

        foreach (Node node in nodes)
        {
            array[arrayIndex++] = ToPair(node);
        }
    }

    /// <inheritdoc />
    public IEnumerator<KeyValuePair<IEnumerable<TKey>, TValue>> GetEnumerator()
    {
        return GetValuedNodes().Select(ToPair).GetEnumerator();
    }

    /// <inheritdoc />
    /// <remarks>Values are compared with <see cref="EqualityComparer{T}.Default"/>, so null values are supported.</remarks>
    /// <exception cref="ArgumentNullException">The key of <paramref name="item"/> is <see langword="null"/>.</exception>
    public bool Remove(KeyValuePair<IEnumerable<TKey>, TValue> item)
    {
        if (item.Key == null)
        {
            throw new ArgumentNullException(nameof(item), "The key of the item cannot be null.");
        }

        if (!TryGetNode(item.Key, true, out Node node) || !EqualityComparer<TValue>.Default.Equals(node.Value, item.Value))
        {
            return false;
        }

        node.RemoveValue();
        return true;
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
    public bool Remove(IEnumerable<TKey> key)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        if (!TryGetNode(key, true, out Node node))
        {
            return false;
        }

        node.RemoveValue();
        return true;
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
    public bool TryGetValue(IEnumerable<TKey> key, out TValue value)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        if (!TryGetNode(key, true, out Node node))
        {
            value = default(TValue);
            return false;
        }

        value = node.Value;
        return true;
    }

    /// <inheritdoc />
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>Builds the dictionary entry for a valued node.</summary>
    private static KeyValuePair<IEnumerable<TKey>, TValue> ToPair(Node node)
    {
        return new KeyValuePair<IEnumerable<TKey>, TValue>(node.Key, node.Value);
    }

    /// <summary>Collects every node that currently holds a value, in depth-first (pre-order) order.</summary>
    private List<Node> GetValuedNodes()
    {
        var nodes = new List<Node>();
        Root.CopyValuedDescendentsTo(nodes);
        return nodes;
    }

    /// <summary>Walks the path for <paramref name="key"/>, creating any missing nodes, and returns its last node.</summary>
    private Node GetOrAddNode(IEnumerable<TKey> key)
    {
        Node node = Root;
        foreach (TKey element in key)
        {
            node = node.AddChild(_equalityComparer, element);
        }

        return node;
    }

    /// <summary>
    /// Looks up the node for <paramref name="key"/> without creating anything.
    /// </summary>
    /// <param name="key">The key to look up; must not be null.</param>
    /// <param name="mustHaveValue">When true, a node that exists only as an interior node counts as "not found".</param>
    /// <param name="node">The node found, or null if the path does not exist.</param>
    private bool TryGetNode(IEnumerable<TKey> key, bool mustHaveValue, out Node node)
    {
        node = Root;
        foreach (TKey element in key)
        {
            if (!node.TryGetChild(element, out node))
            {
                return false;
            }
        }

        return !mustHaveValue || node.HasValue;
    }

    /// <summary>
    /// A single node of the trie: one key element on the path from the root, plus an optional value.
    /// </summary>
    internal class Node
    {
        private readonly Node parent;
        private readonly TKey keyElement;
        private readonly Dictionary<TKey, Node> children;
        private TValue value;

        /// <summary>Initializes a new node with no value and no children.</summary>
        /// <param name="equalityComparer">Comparer for this node's child key elements (null means the default comparer).</param>
        /// <param name="parent">The parent node, or null for the root.</param>
        /// <param name="keyElement">The key element on the edge from <paramref name="parent"/> to this node.</param>
        public Node(IEqualityComparer<TKey> equalityComparer, Node parent, TKey keyElement)
        {
            this.parent = parent;
            this.keyElement = keyElement;
            children = new Dictionary<TKey, Node>(equalityComparer);
            value = default(TValue);
            HasValue = false;
        }

        /// <summary>
        /// Gets the full key of this node: the key elements on the path from the root down to this node,
        /// in root-to-node order. The root's key is empty.
        /// </summary>
        public IEnumerable<TKey> Key
        {
            get
            {
                var stack = new Stack<TKey>();
                for (Node node = this; node.parent != null; node = node.parent)
                {
                    stack.Push(node.keyElement);
                }

                // Stack<T>.ToArray yields pop order, i.e. the element nearest the root comes first.
                return stack.ToArray();
            }
        }

        /// <summary>Gets a value indicating whether a value is stored at this node.</summary>
        public bool HasValue
        {
            get;
            private set;
        }

        /// <summary>Gets or sets the stored value. Setting marks the node as holding a value.</summary>
        public TValue Value
        {
            get
            {
                return value;
            }

            set
            {
                HasValue = true;
                this.value = value;
            }
        }

        /// <summary>
        /// Removes this node's value and prunes the now-useless part of the tree: the highest chain of
        /// ancestors left with neither a value nor other children is detached from its parent.
        /// </summary>
        public void RemoveValue()
        {
            HasValue = false;

            // Release the reference so the old value can be collected even if this node survives pruning.
            value = default(TValue);

            if (children.Count != 0 || parent == null)
            {
                return;
            }

            // Climb while the parent is a non-root node whose only purpose is to lead to the current node,
            // then make a single snip above the highest such node.
            Node top = this;
            while (top.parent.parent != null && !top.parent.HasValue && top.parent.children.Count == 1)
            {
                top = top.parent;
            }

            top.parent.children.Remove(top.keyElement);
        }

        /// <summary>
        /// Adds this node (if it has a value) and then all valued descendants to <paramref name="collection"/>,
        /// depth-first in pre-order.
        /// </summary>
        public void CopyValuedDescendentsTo(ICollection<Node> collection)
        {
            // NOTE: recursion depth equals key length; very long keys could exhaust the call stack.
            if (HasValue)
            {
                collection.Add(this);
            }

            foreach (Node child in children.Values)
            {
                child.CopyValuedDescendentsTo(collection);
            }
        }

        /// <summary>Looks up the direct child reached via key element <paramref name="key"/>.</summary>
        public bool TryGetChild(TKey key, out Node node)
        {
            return children.TryGetValue(key, out node);
        }

        /// <summary>Returns the direct child for <paramref name="key"/>, creating it if it does not exist.</summary>
        public Node AddChild(IEqualityComparer<TKey> equalityComparer, TKey key)
        {
            if (!children.TryGetValue(key, out Node child))
            {
                child = new Node(equalityComparer, this, key);
                children.Add(key, child);
            }

            return child;
        }

        /// <summary>Detaches all children of this node. This node's own value is left untouched.</summary>
        public void RemoveChildren()
        {
            children.Clear();
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment