Skip to content

Instantly share code, notes, and snippets.

@fbparis
Last active March 12, 2019 01:25
Show Gist options
  • Save fbparis/e114958a4a9be84146a1a579cf677e8d to your computer and use it in GitHub Desktop.
Save fbparis/e114958a4a9be84146a1a579cf677e8d to your computer and use it in GitHub Desktop.
A very powerful and memory safe data-structure to replace Python's dict: Indexed Radix Trie
"""
A Python3 indexed trie class.

An indexed trie's key can be any sliceable sequence supporting len() and +
(e.g. str, bytes, tuple).
Keys of the indexed trie are stored using a "radix trie", a space-optimized
data-structure which has many advantages (see https://en.wikipedia.org/wiki/Radix_tree).
Also, each key in the indexed trie is associated with a unique index which is
built dynamically.
An indexed trie is used like a Python dictionary (and even like a
collections.defaultdict if you want to), but its values can also be accessed
or updated (but not created) like a list!

Example:

>>> t = indexedtrie()
>>> t["abc"] = "hello"
>>> t[0]
'hello'
>>> t["abc"]
'hello'
>>> t.index2key(0)
'abc'
>>> t.key2index("abc")
0
>>> t[:]
['hello']
>>> print(t)
{(0, abc): hello}
"""
__author__ = "@fbparis"

# Private "argument not given" marker, so that None stays usable as a real
# value / default value.
_SENTINEL = object()
class _Node(object):
"""
A single node in the trie.
"""
__slots__ = "_children", "_parent", "_index", "_key"
def __init__(self, key, parent, index=None):
self._children = set()
self._key = key
self._parent = parent
self._index = index
self._parent._children.add(self)
class IndexedtrieKey(object):
    """
    An (index, key) pair identifying one entry of an indexed trie.

    Attributes:
        index -- integer position of the entry
        key   -- the full key of the entry
    """

    __slots__ = ("index", "key")

    def __init__(self, index, key):
        self.index, self.key = index, key

    def __repr__(self):
        # %s (not %r) is used for the key, so a str key appears without quotes
        pair = (self.index, self.key)
        return "(%d, %s)" % pair
class indexedtrie(object):
    """
    The indexed trie data-structure.

    Keys live in a radix trie of _Node objects whose top-level nodes are kept
    in ``self._children``. Every stored key also owns a dense integer index in
    ``range(len(self))``: ``self._indexes[i]`` is the node ending the i-th key
    and ``self._values[i]`` is its value. Deleting an item moves the last item
    into the freed slot, so indexes stay dense but may change on deletion.

    The empty key is supported: it is kept as a childless top-level node that
    never takes part in prefix matching (being a prefix of every key, it would
    otherwise capture its siblings).
    """

    __slots__ = "_children", "_indexes", "_values", "_nodescount", "_default_factory"

    def __init__(self, items=None, default_factory=_SENTINEL):
        """
        Create an indexed trie.

        items: optional iterable of (key, value) pairs used to initialize the
            trie; a key may be an IndexedtrieKey, in which case only its
            ``key`` part is used.
        default_factory: optional factory (or plain default value) used for
            missing keys, see setdefault().
        """
        self._children = set()
        self.setdefault(default_factory)
        self._indexes = []
        self._values = []
        self._nodescount = 0  # keeping track of nodes count is purely informational
        if items is not None:
            for k, v in items:
                if isinstance(k, IndexedtrieKey):
                    k = k.key
                self.__setitem__(k, v)

    @classmethod
    def fromkeys(cls, keys, value=_SENTINEL, default_factory=_SENTINEL):
        """
        Build a new indexedtrie from an iterable of keys.

        Each key is mapped to ``value`` when given; otherwise to a fresh result
        of ``default_factory`` when one is given; otherwise to None.
        """
        obj = cls(default_factory=default_factory)
        for key in keys:
            if value is not _SENTINEL:
                obj[key] = value
            elif default_factory is not _SENTINEL:
                obj[key] = obj._default_factory()
            else:
                obj[key] = None
        return obj

    @classmethod
    def fromsplit(cls, keys, value=_SENTINEL, default_factory=_SENTINEL):
        """
        Build a new indexedtrie from ``keys.split()`` (e.g. the whitespace
        separated words of a str). Values are assigned as in fromkeys().
        """
        return cls.fromkeys(keys.split(), value, default_factory)

    def setdefault(self, factory=_SENTINEL):
        """
        Set the default used for missing keys, or remove it when called
        without argument.

        factory is normally a zero-argument callable (e.g. list, int) called on
        every lookup miss. It is called once here as a probe; if that raises
        TypeError, factory is taken as a plain default value instead (even
        None) and that same object is returned on every miss -- so a mutable
        default value is shared between misses.
        """
        if factory is not _SENTINEL:
            # indexed trie will act like a collections.defaultdict except in some cases because the __missing__
            # method is not implemented here (on purpose).
            # That means that simple lookups on a non existing key will return a default value without adding
            # the key, which is the more logical way to do.
            # Also means that if your default_factory is for example "list", you won't be able to create new
            # items with "append" or "extend" methods which are updating the list itself.
            # Instead you have to do something like trie["newkey"] += [...]
            try:
                _ = factory()
            except TypeError:
                # a default value is also accepted as default_factory, even "None"
                self._default_factory = lambda: factory
            else:
                self._default_factory = factory
        else:
            self._default_factory = _SENTINEL

    def copy(self):
        """
        Return a pseudo-shallow copy of the indexedtrie.
        Keys and nodes are rebuilt, but if you store some referenced objects in
        values, only the references are copied. Indexes are preserved.
        """
        return self.__class__(self.items(), default_factory=self._default_factory)

    def __len__(self):
        return len(self._indexes)

    def __repr__(self):
        if self._default_factory is not _SENTINEL:
            default = ", default_value=%s" % self._default_factory()
        else:
            default = ""
        return "<%s object at %s: %d items, %d nodes%s>" % (
            self.__class__.__name__, hex(id(self)), len(self), self._nodescount, default)

    def __str__(self):
        return "{%s}" % ", ".join("%s: %s" % (k, v) for k, v in self.items())

    def __iter__(self):
        """Iterate over IndexedtrieKey objects, in index order."""
        return self.keys()

    def __contains__(self, key_or_index):
        """
        Return True if the key or index exists in the indexed trie.

        An int or IndexedtrieKey is checked as an index (negative indexes are
        reported as absent); anything sliceable is looked up as a key.
        Raises TypeError for any other type.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            return 0 <= key_or_index.index < len(self)
        if isinstance(key_or_index, int):
            return 0 <= key_or_index < len(self)
        if self._seems_valid_key(key_or_index):
            try:
                node = self._get_node(key_or_index)
            except KeyError:
                return False
            return node._index is not None
        raise TypeError("invalid key type")

    def __getitem__(self, key_or_index):
        """
        Return the value for a key, an index, an IndexedtrieKey or a slice of
        indexes.

        A missing key returns a default value when a default factory is set
        (without inserting the key), otherwise raises KeyError(key).
        Raises IndexError for a bad index and TypeError for an invalid key type.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            return self._values[key_or_index.index]
        if isinstance(key_or_index, (int, slice)):
            return self._values[key_or_index]
        if not self._seems_valid_key(key_or_index):
            raise TypeError("invalid key type")
        try:
            node = self._get_node(key_or_index)
        except KeyError:
            node = None
        if node is None or node._index is None:
            # either no such node, or only an internal branching node
            if self._default_factory is _SENTINEL:
                raise KeyError(key_or_index)
            return self._default_factory()
        return self._values[node._index]

    def __setitem__(self, key_or_index, value):
        """
        Set the value for a key (creating it if needed), or update the value
        at an existing index / IndexedtrieKey (indexes cannot create items).

        Raises NotImplementedError for slices and TypeError for an invalid key
        type.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            self._values[key_or_index.index] = value
        elif isinstance(key_or_index, int):
            self._values[key_or_index] = value
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            try:
                node = self._get_node(key_or_index)
            except KeyError:
                # create a new node
                self._add_node(key_or_index, value)
            else:
                if node._index is None:
                    # if node exists but not indexed, we index it and update the value
                    self._add_to_index(node, value)
                else:
                    # else we update its value
                    self._values[node._index] = value
        else:
            raise TypeError("invalid key type")

    def __delitem__(self, key_or_index):
        """
        Delete the item designated by a key, an index or an IndexedtrieKey.

        The last item is moved into the freed index slot, and the trie is
        compacted so that no unindexed node is left with fewer than 2 children.
        Raises KeyError / IndexError when the item does not exist,
        NotImplementedError for slices and TypeError for an invalid key type.
        """
        if isinstance(key_or_index, IndexedtrieKey):
            node = self._indexes[key_or_index.index]
        elif isinstance(key_or_index, int):
            node = self._indexes[key_or_index]
        elif isinstance(key_or_index, slice):
            raise NotImplementedError
        elif self._seems_valid_key(key_or_index):
            node = self._get_node(key_or_index)
            if node._index is None:
                raise KeyError(key_or_index)
        else:
            raise TypeError("invalid key type")
        # switch last index with deleted index (except if deleted index is last index)
        last_node, last_value = self._indexes.pop(), self._values.pop()
        if node._index != last_node._index:
            last_node._index = node._index
            self._indexes[node._index] = last_node
            self._values[node._index] = last_value
        if len(node._children) > 1:
            # case 1: node is still a branching point, only turn its index off
            node._index = None
        elif len(node._children) == 1:
            # case 2: node has 1 child, which takes its place
            self._splice_out(node)
        else:
            # case 3: leaf node, unlink it then compact its parent if the parent
            # became an unindexed node with a single child
            parent = node._parent
            parent._children.remove(node)
            self._nodescount -= 1
            if parent is not self and parent._index is None and len(parent._children) == 1:
                self._splice_out(parent)

    def _splice_out(self, node):
        """
        Replace ``node`` (which must have exactly one child) by that child,
        prepending node's key fragment to the child's fragment.
        """
        child = node._children.pop()
        child._key = node._key + child._key
        child._parent = node._parent
        node._parent._children.add(child)
        node._parent._children.remove(node)
        self._nodescount -= 1

    @staticmethod
    def _seems_valid_key(key):
        """
        Return True if "key" can be a valid key (must support slicing).
        """
        try:
            key[:0]
        except (TypeError, KeyError):
            # KeyError: mappings on Python >= 3.12, where slice objects became
            # hashable (earlier versions raise TypeError for them)
            return False
        return True

    def keys(self, prefix=None):
        """
        Yield keys stored in the indexedtrie as IndexedtrieKey objects.

        Without prefix, keys are yielded in index order. If prefix is given,
        yield only keys starting with prefix, in trie traversal order.
        Raises ValueError (when iterated) if prefix is not sliceable.
        """
        if prefix is None:
            for i, node in enumerate(self._indexes):
                yield IndexedtrieKey(i, self._get_key(node))
            return
        if not self._seems_valid_key(prefix):
            raise ValueError("invalid prefix type")
        empty = prefix[:0]
        # breadth-first walk over (key built so far, part of prefix still to
        # match, node to visit) triples
        frontier = [(empty, prefix, child) for child in self._children]
        while frontier:
            next_frontier = []
            for built, rest, child in frontier:
                if rest == child._key[:len(rest)]:
                    # prefix fully matched: child and everything below match
                    full = built + child._key
                    next_frontier.extend((full, empty, grandchild) for grandchild in child._children)
                    if child._index is not None:
                        yield IndexedtrieKey(child._index, full)
                elif rest[:len(child._key)] == child._key:
                    # child's fragment is a prefix of what remains to match
                    size = len(child._key)
                    next_frontier.extend(
                        (built + rest[:size], rest[size:], grandchild)
                        for grandchild in child._children)
            frontier = next_frontier

    def values(self, prefix=None):
        """
        Yield values stored in the indexedtrie.
        If prefix is given, yield only values of items with key matching the prefix.
        """
        if prefix is None:
            for value in self._values:
                yield value
        else:
            for key in self.keys(prefix):
                yield self._values[key.index]

    def items(self, prefix=None):
        """
        Yield (key, value) pairs stored in the indexedtrie where key is a IndexedtrieKey object.
        If prefix is given, yield only (key, value) pairs of items with key matching the prefix.
        """
        for key in self.keys(prefix):
            yield key, self._values[key.index]

    def show_tree(self, node=None, level=0):
        """
        Pretty print the internal trie (recursive function).
        """
        if node is None:
            node = self
        for child in node._children:
            print("-" * level + "<key=%s, index=%s>" % (child._key, child._index))
            if len(child._children):
                self.show_tree(child, level + 1)

    def _get_node(self, key):
        """
        Return the node associated to key (indexed or not) or raise KeyError(key).
        """
        children = self._children
        remaining = key
        while children:
            for child in children:
                if remaining == child._key:
                    return child
                size = len(child._key)
                # a zero-length fragment (the empty-key node) is a prefix of
                # everything: never descend into it or it shadows its siblings
                if size and child._key == remaining[:size]:
                    children = child._children
                    remaining = remaining[size:]
                    break
            else:
                # no child matches
                break
        raise KeyError(key)

    def _add_node(self, key, value):
        """
        Add a new key (not already present in the trie) and update indexes and values.
        """
        children = self._children
        parent = self
        moved = None
        # The empty key must not take part in prefix matching: it is stored as
        # a childless direct child of the root.
        done = len(key) == 0 or len(children) == 0
        # we want to insert key="abc"
        while not done:
            done = True
            for child in children:
                # assert child._key != key # uncomment if you don't trust me
                size = len(child._key)
                if size == 0:
                    # skip the empty-key node (see above)
                    continue
                if child._key == key[:size]:
                    # case 1: child's key is "ab", insert "c" in child's children
                    parent = child
                    children = child._children
                    key = key[size:]
                    done = len(children) == 0
                    break
                elif key == child._key[:len(key)]:
                    # case 2: child's key is "abcd", we insert "abc" in place of the child
                    # child's parent will be the inserted node and child's key is now "d"
                    parent = child._parent
                    moved = child
                    parent._children.remove(moved)
                    moved._key = moved._key[len(key):]
                    break
                elif type(key) is type(child._key):  # don't mess it up
                    # find longest common prefix
                    prefix = key[:0]
                    for i, c in enumerate(key):
                        if child._key[i] != c:
                            prefix = key[:i]
                            break
                    if prefix:
                        # case 3: child's key is "abd", we spawn a new node with key "ab"
                        # to replace child ; child's key is now "d" and child's parent is
                        # the new created node.
                        # the new node will also be inserted as a child of this node
                        # with key "c"
                        node = _Node(prefix, child._parent)
                        self._nodescount += 1
                        child._parent._children.remove(child)
                        child._key = child._key[len(prefix):]
                        child._parent = node
                        node._children.add(child)
                        key = key[len(prefix):]
                        parent = node
                        break
        # create the new node
        node = _Node(key, parent)
        self._nodescount += 1
        if moved is not None:
            # if we have moved an existing node, update it
            moved._parent = node
            node._children.add(moved)
        self._add_to_index(node, value)

    def _get_key(self, node):
        """
        Rebuild the full key of a node by concatenating fragments up to the root.
        """
        key = node._key
        while node._parent is not self:
            node = node._parent
            key = node._key + key
        return key

    def _add_to_index(self, node, value):
        """
        Give node the next free index and record its value.
        """
        node._index = len(self)
        self._indexes.append(node)
        self._values.append(value)

    def key2index(self, key):
        """
        key -> index.
        Raises KeyError(key) if the key is absent, TypeError for an invalid key type.
        """
        if not self._seems_valid_key(key):
            raise TypeError("invalid key type")
        node = self._get_node(key)
        if node._index is None:
            raise KeyError(key)
        return node._index

    def index2key(self, index):
        """
        index or IndexedtrieKey -> key.
        Raises TypeError for a non-int index, IndexError if out of range
        (negative indexes are not accepted).
        """
        if isinstance(index, IndexedtrieKey):
            index = index.index
        elif not isinstance(index, int):
            raise TypeError("index must be an int")
        if index < 0 or index >= len(self._indexes):
            raise IndexError(index)
        return self._get_key(self._indexes[index])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment