Skip to content

Instantly share code, notes, and snippets.

@sdcondon
Last active November 22, 2020 17:11
Show Gist options
  • Save sdcondon/43cfb8ed30f873817d4688c9664a9059 to your computer and use it in GitHub Desktop.
Save sdcondon/43cfb8ed30f873817d4688c9664a9059 to your computer and use it in GitHub Desktop.
Generic prefix tree (trie), implementing IDictionary<IEnumerable<TKey>, TValue>
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
/// <summary>
/// Prefix tree (trie) implementation. Keys are sequences of <typeparamref name="TKey"/> elements;
/// the value for a key is stored on the node reached by following the key's elements down from
/// the root. The empty sequence is a valid key, whose value is stored on the root node itself.
/// </summary>
/// <typeparam name="TKey">The type of the individual elements of the keys in the trie.</typeparam>
/// <typeparam name="TValue">The type of values in the trie.</typeparam>
public sealed class Trie<TKey, TValue> : IDictionary<IEnumerable<TKey>, TValue>
{
    private readonly EqualityComparer<TKey> _equalityComparer;

    /// <summary>
    /// Initializes a new instance of the <see cref="Trie{TKey, TValue}"/> class that uses the default equality comparer for the key elements.
    /// </summary>
    public Trie()
        : this(EqualityComparer<TKey>.Default)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="Trie{TKey, TValue}"/> class that uses a given equality comparer for the key elements.
    /// </summary>
    /// <param name="equalityComparer">The equality comparer to use. If null, the default comparer for <typeparamref name="TKey"/> is used.</param>
    public Trie(EqualityComparer<TKey> equalityComparer)
    {
        _equalityComparer = equalityComparer ?? EqualityComparer<TKey>.Default;
        Root = new Node(_equalityComparer, null, default(TKey));
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    /// <exception cref="KeyNotFoundException">The getter was called with a key that has no value.</exception>
    public TValue this[IEnumerable<TKey> key]
    {
        get
        {
            if (!TryGetValue(key, out TValue value))
            {
                throw new KeyNotFoundException();
            }

            return value;
        }

        set
        {
            if (key == null)
            {
                throw new ArgumentNullException(nameof(key));
            }

            // Walk (creating nodes where needed) in a single pass so that the key
            // sequence is enumerated exactly once, then add or overwrite the value.
            GetOrAddNode(key).Value = value;
        }
    }

    /// <summary>
    /// Gets the root node of the trie.
    /// </summary>
    internal Node Root
    {
        get;
    }

    /// <inheritdoc />
    public int Count
    {
        get
        {
            // NB: walks the whole trie on every call. A cached counter is deliberately not
            // kept here because nodes can also be modified directly via the internal Root.
            return GetValuedNodes().Count;
        }
    }

    /// <inheritdoc />
    public bool IsReadOnly
    {
        get
        {
            return false;
        }
    }

    /// <inheritdoc />
    public ICollection<IEnumerable<TKey>> Keys
    {
        get
        {
            return GetValuedNodes().Select(n => n.Key).ToArray();
        }
    }

    /// <inheritdoc />
    public ICollection<TValue> Values
    {
        get
        {
            return GetValuedNodes().Select(n => n.Value).ToArray();
        }
    }

    /// <inheritdoc />
    public void Add(KeyValuePair<IEnumerable<TKey>, TValue> item)
    {
        Add(item.Key, item.Value);
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    /// <exception cref="ArgumentException">A value already exists for <paramref name="key"/>.</exception>
    public void Add(IEnumerable<TKey> key, TValue value)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        Node node = GetOrAddNode(key);
        if (node.HasValue)
        {
            throw new ArgumentException("An element with the same key already exists", nameof(key));
        }

        node.Value = value;
    }

    /// <inheritdoc />
    public void Clear()
    {
        // The empty key's value lives on the root itself, so dropping the children
        // alone is not enough to empty the trie.
        Root.RemoveValue();
        Root.RemoveChildren();
    }

    /// <inheritdoc />
    public bool Contains(KeyValuePair<IEnumerable<TKey>, TValue> item)
    {
        // Compare via EqualityComparer so that null stored values do not cause a NullReferenceException.
        return TryGetValue(item.Key, out TValue value)
            && EqualityComparer<TValue>.Default.Equals(value, item.Value);
    }

    /// <inheritdoc />
    public bool ContainsKey(IEnumerable<TKey> key)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        return TryGetValue(key, out _);
    }

    /// <inheritdoc />
    public void CopyTo(KeyValuePair<IEnumerable<TKey>, TValue>[] array, int arrayIndex)
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }

        KeyValuePair<IEnumerable<TKey>, TValue>[] sourceArray = GetValuedNodes().Select(n => ToPair(n)).ToArray();

        // Use the snapshot's own length rather than Count, which would walk the trie a second time.
        Array.Copy(sourceArray, 0, array, arrayIndex, sourceArray.Length);
    }

    /// <inheritdoc />
    public IEnumerator<KeyValuePair<IEnumerable<TKey>, TValue>> GetEnumerator()
    {
        // Enumerates a snapshot (GetValuedNodes builds a list up front).
        return GetValuedNodes().Select(n => ToPair(n)).GetEnumerator();
    }

    /// <inheritdoc />
    public bool Remove(KeyValuePair<IEnumerable<TKey>, TValue> item)
    {
        if (item.Key == null)
        {
            throw new ArgumentNullException(nameof(item));
        }

        // Only remove when both the key is present and its value matches (null-safe comparison).
        if (!TryGetNode(item.Key, true, out Node node)
            || !EqualityComparer<TValue>.Default.Equals(node.Value, item.Value))
        {
            return false;
        }

        node.RemoveValue();
        return true;
    }

    /// <inheritdoc />
    public bool Remove(IEnumerable<TKey> key)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        if (!TryGetNode(key, true, out Node node))
        {
            return false;
        }

        node.RemoveValue();
        return true;
    }

    /// <inheritdoc />
    public bool TryGetValue(IEnumerable<TKey> key, out TValue value)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        if (!TryGetNode(key, true, out Node node))
        {
            value = default(TValue);
            return false;
        }

        value = node.Value;
        return true;
    }

    /// <inheritdoc />
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>
    /// Builds the key/value pair exposed to callers for a node that holds a value.
    /// </summary>
    private static KeyValuePair<IEnumerable<TKey>, TValue> ToPair(Node node)
    {
        return new KeyValuePair<IEnumerable<TKey>, TValue>(node.Key, node.Value);
    }

    /// <summary>
    /// Collects every node in the trie that currently holds a value into a new list.
    /// </summary>
    private List<Node> GetValuedNodes()
    {
        List<Node> nodes = new List<Node>();
        Root.CopyValuedDescendentsTo(nodes);
        return nodes;
    }

    /// <summary>
    /// Follows <paramref name="key"/> down from the root, creating any missing nodes on the way,
    /// and returns the node for the final element (the root itself for an empty key).
    /// </summary>
    private Node GetOrAddNode(IEnumerable<TKey> key)
    {
        Node node = Root;
        foreach (TKey element in key)
        {
            node = node.AddChild(_equalityComparer, element);
        }

        return node;
    }

    /// <summary>
    /// Follows <paramref name="key"/> down from the root without creating nodes.
    /// </summary>
    /// <param name="key">The key to look up. Must not be null.</param>
    /// <param name="mustHaveValue">True to succeed only if the node found holds a value.</param>
    /// <param name="node">The node found. Null if the path does not exist.</param>
    /// <returns>True if the path exists (and, when requested, the node holds a value).</returns>
    private bool TryGetNode(IEnumerable<TKey> key, bool mustHaveValue, out Node node)
    {
        node = Root;
        foreach (TKey element in key)
        {
            if (!node.TryGetChild(element, out node))
            {
                return false;
            }
        }

        return !mustHaveValue || node.HasValue;
    }

    /// <summary>
    /// A single node of the trie: one key element, an optional value and the child nodes keyed by their element.
    /// </summary>
    internal class Node
    {
        private readonly Node parent;
        private readonly TKey keyElement;
        private readonly Dictionary<TKey, Node> children;
        private TValue value;

        /// <summary>
        /// Initializes a new instance of the <see cref="Node"/> class with no value and no children.
        /// </summary>
        /// <param name="equalityComparer">The comparer used to key this node's children.</param>
        /// <param name="parent">The parent node, or null for the root.</param>
        /// <param name="keyElement">The key element this node represents (unused for the root).</param>
        public Node(EqualityComparer<TKey> equalityComparer, Node parent, TKey keyElement)
        {
            this.parent = parent;
            this.keyElement = keyElement;
            children = new Dictionary<TKey, Node>(equalityComparer);
            value = default(TValue);
            HasValue = false;
        }

        /// <summary>
        /// Gets the full key of this node: the key elements on the path from the root down to this node.
        /// </summary>
        public IEnumerable<TKey> Key
        {
            get
            {
                // Elements are pushed leaf-first; Stack<T>.ToArray yields them in pop order,
                // i.e. root-first, which is the order of the original key.
                var stack = new Stack<TKey>();
                for (Node node = this; node.parent != null; node = node.parent)
                {
                    stack.Push(node.keyElement);
                }

                return stack.ToArray();
            }
        }

        /// <summary>
        /// Gets a value indicating whether a value is stored on this node.
        /// </summary>
        public bool HasValue
        {
            get;
            private set;
        }

        /// <summary>
        /// Gets or sets the value stored on this node. Setting it also marks the node as having a value.
        /// </summary>
        public TValue Value
        {
            get
            {
                return this.value;
            }

            set
            {
                HasValue = true;
                this.value = value;
            }
        }

        /// <summary>
        /// Removes this node's value and prunes any nodes that are left with neither a value nor children.
        /// </summary>
        public void RemoveValue()
        {
            HasValue = false;

            // Drop the stored value so that it is not kept reachable by a node that
            // survives only because it still has children.
            value = default(TValue);

            // Prune the tree as necessary - traverse up the tree removing nodes
            // until we hit one that has either a value or other children.
            for (Node node = this; !node.HasValue && node.children.Count == 0 && node.parent != null; node = node.parent)
            {
                node.parent.children.Remove(node.keyElement);
            }
        }

        /// <summary>
        /// Adds this node (if it holds a value) and all of its descendants that hold a value to a collection.
        /// </summary>
        /// <param name="collection">The collection to add the nodes to.</param>
        public void CopyValuedDescendentsTo(ICollection<Node> collection)
        {
            // Explicit stack rather than recursion so that very long keys (deep tries)
            // cannot overflow the call stack.
            var pending = new Stack<Node>();
            pending.Push(this);
            while (pending.Count > 0)
            {
                Node current = pending.Pop();
                if (current.HasValue)
                {
                    collection.Add(current);
                }

                // Push children in reverse so that they are popped - and so visited - in the
                // order the dictionary enumerates them, matching a recursive pre-order walk.
                foreach (Node child in current.children.Values.Reverse())
                {
                    pending.Push(child);
                }
            }
        }

        /// <summary>
        /// Looks up the direct child for a given key element.
        /// </summary>
        /// <param name="key">The key element of the child.</param>
        /// <param name="node">The child node, if found.</param>
        /// <returns>True if a child exists for the element.</returns>
        public bool TryGetChild(TKey key, out Node node)
        {
            return children.TryGetValue(key, out node);
        }

        /// <summary>
        /// Gets the direct child for a given key element, creating it if it does not exist yet.
        /// </summary>
        /// <param name="equalityComparer">The comparer a newly created child uses for its own children.</param>
        /// <param name="key">The key element of the child.</param>
        /// <returns>The existing or newly created child node.</returns>
        public Node AddChild(EqualityComparer<TKey> equalityComparer, TKey key)
        {
            if (!children.TryGetValue(key, out Node child))
            {
                child = new Node(equalityComparer, this, key);
                children.Add(key, child);
            }

            return child;
        }

        /// <summary>
        /// Removes all children of this node. The node's own value is not affected.
        /// </summary>
        public void RemoveChildren()
        {
            children.Clear();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment