Skip to content

Instantly share code, notes, and snippets.

@slowli
Created June 24, 2015 17:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slowli/884fbc6ae07d44114ef3 to your computer and use it in GitHub Desktop.
Save slowli/884fbc6ae07d44114ef3 to your computer and use it in GitHub Desktop.
Tree algorithms in Python with decorators
#!/usr/bin/python
'''
Implementation of trees with some useful methods:
copying, parsing and creating Newick tree representation, getting splits,
testing for equality, changing the root node, etc.
Most recursive computations on trees (e.g., tree splits) are implemented
with a custom function decorator that evaluates them from the leaves up to
the root.
Run as a script to test all methods.
'''
def leaves_to_root(fn):
    ''' Decorator for recurrent relations on trees.

    Relation fn must accept one argument of the type Tree (any object with
    a ``queue()`` method listing its subtree so that every node comes after
    its parent). fn(v) should use the values of the decorated function on
    the descendants of v.

    Instead of recursing, the wrapper evaluates fn on every node of the
    subtree in reverse queue order, so the values of a node's descendants
    are already cached (keyed by node id) when fn reaches that node; the
    recursive calls made inside fn are then cache hits, and the call stack
    stays shallow even for very deep trees.
    '''
    import functools

    cache = {}

    @functools.wraps(fn)
    def ltr_fn(tree):
        tree_id = id(tree)
        if tree_id in cache:
            # Called from inside fn on a descendant that is already evaluated.
            return cache[tree_id]
        added = []
        try:
            for node in reversed(tree.queue()):
                node_id = id(node)
                if node_id not in cache:
                    cache[node_id] = fn(node)
                    added.append(node_id)
            return cache[tree_id]
        finally:
            # Remove only the entries created by this call, even if fn raised.
            # Ids can be reused after nodes are garbage-collected, so stale
            # entries must not survive. An enclosing evaluation that is still
            # running must also keep its own entries.
            for node_id in added:
                del cache[node_id]

    return ltr_fn
class Tree(object):
    ''' Generic tree with information in nodes.

    Every node is itself a Tree: ``parent`` is the enclosing node (None for
    the root), ``children`` is the ordered list of child nodes and ``data``
    is the attached payload (a label string when parsed from Newick).
    '''

    def __init__(self, data):
        ''' Creates a single node holding data. A string ending in ';' is
        treated as Newick notation and parsed into a whole tree instead. '''
        self.parent = None
        self.data = data
        self.children = []
        if isinstance(data, str) and data.endswith(';'):
            self.parse_newick(data)

    def _detach(self):
        ''' Unlinks this node from its parent, if it has one.
        The sibling list is searched by identity. list.remove() would compare
        with ==, i.e. with the split-based __eq__ below, which is slow and
        could remove a different sibling that merely compares equal.
        Raises ValueError if the parent does not list this node. '''
        parent = self.parent
        if parent is None:
            return
        siblings = parent.children
        for index, sibling in enumerate(siblings):
            if sibling is self:
                del siblings[index]
                break
        else:
            raise ValueError('node is not among the children of its parent')
        self.parent = None

    def prepend(self, child):
        ''' Adds a node as the first child of this node, detaching it from
        its previous parent first. '''
        child._detach()
        self.children.insert(0, child)
        child.parent = self

    def append(self, child):
        ''' Adds a node as the last child of this node, detaching it from
        its previous parent first. '''
        child._detach()
        self.children.append(child)
        child.parent = self

    @leaves_to_root
    def copy(self):
        ''' Returns a deep copy of the subtree rooted at this node.
        The data objects themselves are shared, not copied. '''
        # Set data after construction: Tree(data) would re-parse a string
        # label that happens to end in ';'.
        node = Tree(None)
        node.data = self.data
        for child in self.children:
            node.append(child.copy())
        return node

    def queue(self):
        ''' Returns a queue used during breadth-first searches
        in this tree. Each node is placed in queue after its parent. '''
        order = [self]
        ptr = 0
        while ptr < len(order):
            order.extend(order[ptr].children)
            ptr += 1
        return order

    def find(self, data):
        ''' Returns the first node (in breadth-first order) of this subtree
        whose data equals data, or None if there is no such node. '''
        return next((node for node in self.queue() if node.data == data), None)

    def makeroot(self):
        ''' Rebases the tree, making this node its new root, and returns it.
        Each edge on the path from the old root to this node is reversed:
        a former ancestor becomes the last child of the node below it. '''
        path = []
        node = self
        while node.parent is not None:
            path.append(node)
            node = node.parent
        # Walk down from the old root, flipping one edge per step.
        for node in reversed(path):
            old_parent = node.parent
            node._detach()
            node.append(old_parent)
        return self

    @leaves_to_root
    def _rec_newick(self):
        ''' Newick notation of the subtree rooted here, without the final ';'.
        Calls on children are answered by the leaves_to_root cache. '''
        label = '' if self.data is None else str(self.data)
        if not self.children:
            return label
        inner = ','.join(child._rec_newick() for child in self.children)
        return '(' + inner + ')' + label

    def newick(self):
        ''' Returns Newick notation for this tree. '''
        return self._rec_newick() + ';'

    def __str__(self):
        return "Tree('" + self.newick() + "')"

    def __repr__(self):
        return 'node(' + ('' if self.data is None else str(self.data)) + ')'

    def parse_newick(self, text):
        ''' Parses Newick notation into a tree rooted at this node.

        The text is scanned right to left. Each label is read up to the
        nearest delimiter on its left, and that delimiter decides where the
        next label goes: ')' descends into the node just read, '(' climbs
        back to the parent, and ',' or ';' keep the current parent.
        New nodes are prepended, so the original child order is preserved.
        '''
        delimiters = '(),;'
        pos = len(text) - 1
        parent = None
        node = None
        while pos >= 0:
            # Directly left of a '(' there is no label, e.g. in '((A,B),C)'.
            if pos == len(text) - 1 or text[pos + 1] != '(':
                end = pos
                while pos >= 0 and text[pos] not in delimiters:
                    pos -= 1
                data = text[pos + 1:end + 1] or None
                if parent is None:
                    self.data = data
                    node = self
                else:
                    node = Tree(data)
                    parent.prepend(node)
                if pos < 0:
                    # The label reached the start of the text.
                    break
            char = text[pos]
            if char == '(':
                if pos > 0:
                    parent = parent.parent
            elif char == ')':
                parent = node
            pos -= 1

    def leaves(self):
        ''' Returns leaves of this tree (in breadth-first order). A node
        counts as a leaf if it has fewer than two children. This also covers
        a root with a single child, i.e. a tree rooted at a leaf. '''
        return [node for node in self.queue() if len(node.children) < 2]

    def leaves_data(self):
        ''' Returns a list of data associated with each leaf of this tree. '''
        return [node.data for node in self.leaves()]

    @leaves_to_root
    def all_splits(self):
        ''' Returns a list of all splits induced by the edges of this tree.
        Each split is encoded by a list of leaves in one part of the split.
        The last split always lists every leaf of the subtree. '''
        if not self.children:
            return [[self]]
        splits = []
        subtree_leaves = []
        for child in self.children:
            child_splits = child.all_splits()
            splits.extend(child_splits)
            # The last split of a child holds all leaves below that child.
            subtree_leaves.extend(child_splits[-1])
        if len(self.children) == 1:
            # A node with a single child counts as a leaf (e.g. a leaf root).
            subtree_leaves.append(self)
        splits.append(subtree_leaves)
        return splits

    def splits(self):
        ''' Returns all non-trivial splits induced by the edges of this tree,
        i.e. splits where both parts contain more than one leaf.
        Each split is encoded by a list of leaves in one part of the split. '''
        splits = self.all_splits()
        total = len(splits[-1])  # number of leaves in the tree
        return [split for split in splits if 1 < len(split) < total - 1]

    def splits_data(self):
        ''' Returns all non-trivial splits induced by the edges of this tree.
        Each split is encoded by a list of data associated with leaves
        in one part of the split. Always returns a list, so the result can
        be measured and iterated more than once. '''
        return [[node.data for node in split] for split in self.splits()]

    def __eq__(self, other):
        ''' Trees are equal when they have the same set of leaf data and
        induce the same non-trivial bipartitions of it. The root position,
        child order and inner-node data are ignored. '''
        if not isinstance(other, Tree):
            return False
        leaves = frozenset(self.leaves_data())
        if leaves != frozenset(other.leaves_data()):
            return False

        def bipartitions(tree):
            # A split and its complement describe the same edge.
            return {frozenset((part, leaves - part))
                    for part in map(frozenset, tree.splits_data())}

        return bipartitions(self) == bipartitions(other)

    def __ne__(self, other):
        return not self.__eq__(other)

    __hash__ = None  # disable hashing (trees are mutable)
##### Global functions. #####
@leaves_to_root
def nleaves(tree):
    ''' Counts leaves in the tree, i.e. nodes that have no children.
    This is written as a plain recursion. leaves_to_root evaluates it from
    the leaves upwards, so the recursive calls below only read values that
    were already computed.
    NOTE: unlike Tree.leaves(), a root with a single child is not counted. '''
    kids = tree.children
    return sum(nleaves(kid) for kid in kids) if kids else 1
def nleaves_im(tree):
    ''' Imperative implementation of counting leaves (nodes without children).
    Nodes are visited in reverse breadth-first order, so every child is
    counted before its parent needs the result. '''
    counts = {}
    for node in reversed(tree.queue()):
        kids = node.children
        counts[id(node)] = sum(counts[id(kid)] for kid in kids) if kids else 1
    return counts[id(tree)]
##### Testing #####
import unittest
class Tester(unittest.TestCase):
    ''' Unit tests for Tree and the module-level leaf counters. '''

    @classmethod
    def setUpClass(cls):
        ''' Creates a huge unrooted binary tree to use in tests. Its notation
        is '((...((A,0),1),...,N-2),N-1,N);', nested N parentheses deep,
        with N + 2 leaves. '''
        N = 1000
        cls.max_notation = ('(' * N + 'A,'
                            + ','.join([str(i) + ')' for i in range(N - 1)])
                            + ',' + str(N - 1) + ',' + str(N) + ');')
        cls.max_tree = Tree(cls.max_notation)
        cls.max_leaves = N + 2

    def test_parsing(self):
        ''' Tests parsing Newick format notation. '''
        tree = Tree('(A,B,C);')
        self.assertEqual(3, len(tree.children))
        self.assertIsNone(tree.data)
        self.assertEqual('A', tree.children[0].data)
        self.assertEqual('B', tree.children[1].data)
        self.assertEqual('C', tree.children[2].data)

    def test_str(self):
        ''' Tests that Newick notation survives a parse / print round trip. '''
        notations = [
            'A;',
            '(A,B,C);',
            '(A,,C)D;',
            '((,A,)B,((C,D),E),(F,G))H;',
            '(,,((,),),);',
        ]
        for notation in notations:
            tree = Tree(notation)
            self.assertEqual(tree.newick(), notation)

    def test_max(self):
        ''' Tests creating Newick format representation for big trees.
        A plain recursive Tree.newick() would exceed Python's default
        recursion limit on this 1000-level-deep tree. '''
        tree = self.max_tree
        self.assertEqual(tree.newick(), self.max_notation)

    def test_leaves(self):
        ''' Tests Tree.leaves() and Tree.leaves_data() methods. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        leaves = tree.leaves()
        self.assertEqual(6, len(leaves))
        data = tree.leaves_data()
        self.assertEqual(6, len(data))
        for leaf in leaves:
            self.assertIn(leaf.data, data)
        for leaf in ['A', 'B', 'C', 'D', 'E', 'F']:
            self.assertIn(leaf, data)

    def test_splits(self):
        ''' Tests Tree.splits_data() method. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        # list() keeps the test valid whether a list or an iterator is returned
        splits = list(tree.splits_data())
        self.assertEqual(3, len(splits))
        self.assertIn(['A', 'B'], splits)
        self.assertIn(['C', 'D'], splits)
        self.assertIn(['E', 'F'], splits)

    def test_splits_big(self):
        ''' Tests Tree.splits() on the huge tree: an unrooted binary tree
        with L leaves has L - 3 non-trivial splits. '''
        tree = self.max_tree
        self.assertEqual(self.max_leaves - 3, len(tree.splits()))

    def test_eq(self):
        ''' Tests that equality ignores child order, rooting and the data
        of inner nodes. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        other = Tree('(B,A,((E,F),(D,C)));')
        self.assertEqual(tree, other)
        # Tree rooted at the leaf node with additional markers for inner nodes
        other = Tree('((((B,A)0,(E,F)1)2,C)3)D;')
        self.assertEqual(tree, other)

    def test_makeroot(self):
        ''' Tests Tree.makeroot() method. '''
        tree = Tree('(((A,B)C,D)E,(G,H)I)F;')
        tree = tree.find('C').makeroot()
        self.assertEqual('(A,B,(D,((G,H)I)F)E)C;', tree.newick())

    def test_equal_copy(self):
        ''' Tests whether a copy of a tree is equal to it. '''
        tree = self.max_tree
        self.assertEqual(tree, tree.copy())

    def test_equal_rebased_copy(self):
        ''' Test equality between a tree and the rebased tree. '''
        tree = Tree('((A,(B,H)),(C,D),(E,(F,G)));')
        copy = tree.copy().find('D').parent.makeroot()
        self.assertEqual(tree, copy)

    def test_equal_rebased_copy_max(self):
        ''' Test equality between a huge tree and the rebased tree. '''
        tree = self.max_tree
        copy = tree.copy()
        # Integer division: labels are integer strings such as '501'.
        copy = copy.find(str(self.max_leaves // 2)).parent.makeroot()
        self.assertEqual(tree, copy)

    def test_neq(self):
        ''' Tests that trees with different leaves or splits are unequal. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        # Trees with other labels
        other = Tree('((A,B),(C,D),(E,G));')
        self.assertNotEqual(tree, other)
        # Trees with different number of nodes
        other = Tree('((A,B),(C,D),E);')
        self.assertNotEqual(tree, other)
        other = Tree('((A,B),(C,(D,G)),(E,F));')
        self.assertNotEqual(tree, other)
        # Tree with swapped nodes
        other = Tree('((A,C),(B,D),(E,F));')
        self.assertNotEqual(tree, other)

    def test_nleaves(self):
        ''' Tests counting leaves in trees. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        self.assertEqual(6, nleaves(tree))
        self.assertEqual(6, nleaves_im(tree))

    def test_nleaves_max(self):
        ''' Tests counting leaves in trees (big tree). '''
        tree = self.max_tree
        self.assertEqual(self.max_leaves, nleaves(tree))
        self.assertEqual(self.max_leaves, nleaves_im(tree))
if __name__ == '__main__':
    # Executed as a script: run every test case of the Tester class.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment