Created
June 24, 2015 17:44
-
-
Save slowli/884fbc6ae07d44114ef3 to your computer and use it in GitHub Desktop.
Tree algorithms in Python with decorators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python
'''
Implementation of trees with some useful methods:
copying, parsing and creating Newick tree representation, getting splits,
testing for equality, changing the root node, etc.

Most recurrence relations on trees (e.g., tree splits) are implemented
using a custom function decorator.

Run as a script to test all methods.
'''
def leaves_to_root(fn):
    ''' Decorator for recurrence relations on trees.

    The relation fn must accept one argument: a tree node that provides
    a queue() method returning the nodes of its subtree in breadth-first
    order (every node after its parent). fn(v) may call the decorated
    function on the descendants of v; those calls are answered from
    a cache instead of recursing.

    On an outermost call the subtree is walked in reverse breadth-first
    order, so the value for every child is cached before fn reaches
    its parent. This keeps the Python call stack shallow no matter
    how deep the tree is.

    The cache is keyed by id(node) and is emptied when the outermost call
    finishes, whether it returns or raises, so no result survives between
    independent calls. '''
    import functools  # local import keeps the decorator self-contained

    cache = dict()

    @functools.wraps(fn)
    def ltr_fn(tree):
        key = id(tree)
        if key in cache:
            # Nested call issued by fn() for an already processed descendant.
            return cache[key]
        try:
            for node in reversed(tree.queue()):
                cache[id(node)] = fn(node)
            return cache[key]
        finally:
            # Always drop the cached values: id() keys are only meaningful
            # while the nodes are alive and unchanged, and entries left
            # behind by a failed call would otherwise be served as
            # stale results later.
            cache.clear()
    return ltr_fn
class Tree(object):
    ''' Generic tree with information in nodes.

    Every node is itself a Tree: it stores arbitrary data, a reference
    to its parent (None for the root) and an ordered list of children.
    If data is a string ending with ';', it is parsed as Newick notation
    and the tree is built from it.

    Trees compare as unrooted trees with labelled leaves: two trees are
    equal if they have the same set of leaf labels and induce the same
    set of non-trivial splits. Because of that, child lists are always
    manipulated by identity, never by equality. '''

    def __init__(self, data):
        self.parent = None
        self.data = data
        self.children = []
        if isinstance(data, str) and data.endswith(';'):
            self.parse_newick(data)

    def _detach(self):
        ''' Unlinks this node from its parent (does nothing for a root).
        The node is located by identity: list.remove() would compare
        siblings with the structural __eq__ and could drop another,
        equal-looking subtree instead. '''
        if self.parent is None:
            return
        siblings = self.parent.children
        for index, node in enumerate(siblings):
            if node is self:
                del siblings[index]
                break
        self.parent = None

    def _is_leaf(self):
        ''' Tells whether this node is a leaf: it has no children, or it is
        the root (no parent) with a single child, since a tree may be
        rooted at one of its leaves. '''
        if not self.children:
            return True
        return self.parent is None and len(self.children) == 1

    def prepend(self, child):
        ''' Adds a node as a first child of this node. '''
        child._detach()
        self.children.insert(0, child)
        child.parent = self

    def append(self, child):
        ''' Adds a node as a last child of this node. '''
        child._detach()
        self.children.append(child)
        child.parent = self

    @leaves_to_root
    def copy(self):
        ''' Returns a deep copy of this tree. Node data is shared
        with the original, not copied. '''
        # Assign data after construction so that a label ending with ';'
        # is not parsed as Newick notation again.
        duplicate = Tree(None)
        duplicate.data = self.data
        for child in self.children:
            duplicate.append(child.copy())
        return duplicate

    def queue(self):
        ''' Returns a queue used during breadth-first searches
        in this tree. Each node is placed in queue after its parent. '''
        queue = [self]
        ptr = 0
        while ptr < len(queue):
            queue.extend(queue[ptr].children)
            ptr += 1
        return queue

    def find(self, data):
        ''' Finds a node with the specified data attached among this node
        and its descendants (breadth-first). Returns None if there is
        no such node. '''
        for node in self.queue():
            if node.data == data:
                return node
        return None

    def makeroot(self):
        ''' Rebases the tree, making this node its new root.
        Returns this node. '''
        # Collect the path from this node up to (excluding) the old root.
        path = []
        node = self
        while node.parent is not None:
            path.append(node)
            node = node.parent
        # Reverse the parent-child links along the path, from the top down.
        for node in reversed(path):
            old_parent = node.parent
            node._detach()
            node.append(old_parent)
        return self

    @leaves_to_root
    def _rec_newick(self):
        ''' Returns Newick notation of this subtree without
        the trailing semicolon. '''
        s = ''
        if self.children:
            s += '(' + ','.join([child._rec_newick()
                                 for child in self.children]) + ')'
        return s + (str(self.data) if self.data is not None else '')

    def newick(self):
        ''' Returns Newick notation for this tree. '''
        return self._rec_newick() + ';'

    def __str__(self):
        return 'Tree(\'' + self.newick() + '\')'

    def __repr__(self):
        return 'node(' + (str(self.data) if self.data is not None else '') + ')'

    def parse_newick(self, text):
        ''' Parses Newick notation into a tree rooted at this node.

        The text is scanned from right to left: a node label follows
        the closing bracket of the node's children, so a parent is always
        met before its children. Unnamed nodes get None as data. '''
        delimiters = '(),;'
        last = len(text) - 1
        pos = last
        parent = None
        while pos >= 0:
            # A position directly followed by '(' cannot hold a node:
            # this is '((' or a delimiter in front of a group.
            before_group = pos < last and text[pos + 1] == '('
            if not before_group:
                init_pos = pos
                # Walk back over the node name (it may be empty).
                while pos >= 0 and text[pos] not in delimiters:
                    pos -= 1
                data = text[(pos + 1):(init_pos + 1)]
                if len(data) == 0:
                    data = None
                if parent is None:
                    # No bracket has been entered yet: the label is the root's.
                    self.data = data
                    node = self
                else:
                    node = Tree(data)
                    # Right-to-left scan: earlier siblings arrive later.
                    parent.prepend(node)
            if pos >= 0:
                if text[pos] == '(':
                    # The group is finished; go one level up. The opening
                    # bracket at position 0 belongs to the root itself.
                    if pos > 0:
                        parent = parent.parent
                elif text[pos] == ')':
                    # The node just read owns the group that closes here.
                    parent = node
            pos -= 1

    def leaves(self):
        ''' Returns leaves of this tree: nodes without children,
        plus the root if it has exactly one child. '''
        return [node for node in self.queue() if node._is_leaf()]

    def leaves_data(self):
        ''' Returns a list of data associated with each leaf of this tree. '''
        return [node.data for node in self.leaves()]

    @leaves_to_root
    def all_splits(self):
        ''' Returns a list of all splits induced by the edges of this tree.
        Each split is encoded by a list of leaves in one part of the split.
        The last split lists all leaves of the tree. '''
        if not self.children:
            return [[self]]
        splits = []
        all_leaves = []
        for child in self.children:
            child_splits = child.all_splits()
            splits += child_splits
            # the last split contains all leaves in the subtree
            # rooted in the child node
            all_leaves += child_splits[-1]
        if self._is_leaf():  # a leaf may be the root of the tree, too
            all_leaves.append(self)
        splits.append(all_leaves)
        return splits

    def splits(self):
        ''' Returns all non-trivial splits induced by the edges of this tree.
        Each split is encoded by a list of leaves in one part of the split. '''
        splits = self.all_splits()
        nleaves = len(splits[-1])  # number of leaves in the tree
        # We need only non-trivial splits, i.e. splits
        # with both parts consisting of more than one tree leaf
        return [split for split in splits
                if 1 < len(split) < nleaves - 1]

    def splits_data(self):
        ''' Returns all non-trivial splits induced by the edges of this tree.
        Each split is encoded by a list of data associated with leaves
        in one part of the split. '''
        # A real list (not map()) so that the result can be measured
        # and searched on Python 3 as well.
        return [[node.data for node in split] for split in self.splits()]

    def _split_set(self, leaves):
        ''' Returns the set of non-trivial splits of this tree, each
        as an unordered pair of its two parts, so that a split does not
        depend on which side of the edge holds the root.
        leaves is the frozenset of all leaf labels of the tree. '''
        result = set()
        for split in self.splits_data():
            part = frozenset(split)
            result.add(frozenset((part, leaves - part)))
        return result

    def __eq__(self, other):
        ''' Compares trees as unrooted trees with labelled leaves.
        Leaf labels must be hashable. '''
        if not isinstance(other, Tree):
            return False
        leaves = frozenset(self.leaves_data())
        if leaves != frozenset(other.leaves_data()):
            return False
        # Both split sets must coincide; checking inclusion in only one
        # direction would make a tree "equal" to any of its refinements.
        return self._split_set(leaves) == other._split_set(leaves)

    def __ne__(self, other):
        return not self.__eq__(other)

    __hash__ = None  # disable hashing (trees are mutable)
##### Global functions. ##### | |
@leaves_to_root
def nleaves(tree):
    ''' Counts leaves in the tree.
    A node without children is one leaf; any other node has as many
    leaves as its children have in total. '''
    if not tree.children:
        return 1
    return sum(nleaves(child) for child in tree.children)
def nleaves_im(tree):
    ''' Imperative implementation of counting leaves.
    Walks the breadth-first queue of the tree backwards, so the count
    of every child is known before its parent is reached. '''
    counts = {}
    for node in reversed(tree.queue()):
        kids = node.children
        if kids:
            counts[id(node)] = sum(counts[id(kid)] for kid in kids)
        else:
            # a node without children is a single leaf
            counts[id(node)] = 1
    return counts[id(tree)]
##### Testing ##### | |
import unittest | |
class Tester(unittest.TestCase):
    ''' Unit tests for Tree and for the leaf counting functions. '''

    @classmethod
    def setUpClass(cls):
        ''' Creates a huge unrooted binary tree to use in tests. '''
        N = 1000
        cls.max_notation = '(' * N + 'A,' + \
            ','.join([str(i) + ')' for i in range(N - 1)])
        cls.max_notation += ',' + str(N - 1) + ',' + str(N) + ');'
        cls.max_tree = Tree(cls.max_notation)
        cls.max_leaves = N + 2

    def test_parsing(self):
        ''' Tests parsing Newick format notation. '''
        tree = Tree('(A,B,C);')
        self.assertEqual(3, len(tree.children))
        self.assertEqual(None, tree.data)
        self.assertEqual('A', tree.children[0].data)
        self.assertEqual('B', tree.children[1].data)
        self.assertEqual('C', tree.children[2].data)

    def test_str(self):
        ''' Tests creating Newick format representation. '''
        notations = [
            'A;',
            '(A,B,C);',
            '(A,,C)D;',
            '((,A,)B,((C,D),E),(F,G))H;',
            '(,,((,),),);',
        ]
        for notation in notations:
            tree = Tree(notation)
            self.assertEqual(tree.newick(), notation)

    def test_max(self):
        ''' Tests creating Newick format representation for big trees
        (does not work if we define Tree.newick() recursively). '''
        tree = self.max_tree
        self.assertEqual(tree.newick(), self.max_notation)

    def test_leaves(self):
        ''' Tests Tree.leaves() and Tree.leaves_data() methods. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        leaves = tree.leaves()
        self.assertEqual(6, len(leaves))
        data = tree.leaves_data()
        self.assertEqual(6, len(data))
        for leaf in leaves:
            self.assertIn(leaf.data, data)
        for leaf in ['A', 'B', 'C', 'D', 'E', 'F']:
            self.assertIn(leaf, data)

    def test_splits(self):
        ''' Tests Tree.splits_data() method. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        # list() keeps the checks valid even if an iterator is returned
        splits = list(tree.splits_data())
        self.assertEqual(3, len(splits))
        self.assertIn(['A', 'B'], splits)
        self.assertIn(['C', 'D'], splits)
        self.assertIn(['E', 'F'], splits)

    def test_splits_big(self):
        ''' Tests the speed of Tree.splits() method. '''
        tree = self.max_tree
        self.assertEqual(self.max_leaves - 3, len(tree.splits()))

    def test_eq(self):
        ''' Tests equality of trees that differ in root and child order. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        other = Tree('(B,A,((E,F),(D,C)));')
        self.assertEqual(tree, other)
        # Tree rooted at the leaf node with additional markers for inner nodes
        other = Tree('((((B,A)0,(E,F)1)2,C)3)D;')
        self.assertEqual(tree, other)

    def test_makeroot(self):
        ''' Tests Tree.makeroot() method. '''
        tree = Tree('(((A,B)C,D)E,(G,H)I)F;')
        tree = tree.find('C').makeroot()
        self.assertEqual('(A,B,(D,((G,H)I)F)E)C;', tree.newick())

    def test_equal_copy(self):
        ''' Tests whether a copy of a tree is equal to it. '''
        tree = self.max_tree
        self.assertEqual(tree, tree.copy())

    def test_equal_rebased_copy(self):
        ''' Test equality between a tree and the rebased tree. '''
        tree = Tree('((A,(B,H)),(C,D),(E,(F,G)));')
        copy = tree.copy().find('D').parent.makeroot()
        self.assertEqual(tree, copy)

    def test_equal_rebased_copy_max(self):
        ''' Test equality between a huge tree and the rebased tree. '''
        tree = self.max_tree
        copy = tree.copy()
        # Floor division: leaf labels are integers written without
        # a fractional part, and '/' yields a float on Python 3.
        copy = copy.find(str(self.max_leaves // 2)).parent.makeroot()
        self.assertEqual(tree, copy)

    def test_neq(self):
        ''' Tests inequality of trees with different leaves or splits. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        # Trees with other labels
        other = Tree('((A,B),(C,D),(E,G));')
        self.assertNotEqual(tree, other)
        # Trees with different number of nodes
        other = Tree('((A,B),(C,D),E);')
        self.assertNotEqual(tree, other)
        other = Tree('((A,B),(C,(D,G)),(E,F));')
        self.assertNotEqual(tree, other)
        # Tree with swapped nodes
        other = Tree('((A,C),(B,D),(E,F));')
        self.assertNotEqual(tree, other)

    def test_nleaves(self):
        ''' Tests counting leaves in trees. '''
        tree = Tree('((A,B),(C,D),(E,F));')
        self.assertEqual(6, nleaves(tree))
        self.assertEqual(6, nleaves_im(tree))

    def test_nleaves_max(self):
        ''' Tests counting leaves in trees (big tree). '''
        tree = self.max_tree
        self.assertEqual(self.max_leaves, nleaves(tree))
        self.assertEqual(self.max_leaves, nleaves_im(tree))
if __name__ == '__main__':
    # Executed as a script: run every test case defined in this module.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment