Skip to content

Instantly share code, notes, and snippets.

@quatrix
Last active August 29, 2015 14:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save quatrix/02c94e8bc94d166dfa7c to your computer and use it in GitHub Desktop.
Save quatrix/02c94e8bc94d166dfa7c to your computer and use it in GitHub Desktop.
"""
Write code that, given a tree with weighted nodes, prints, for each node,
the weight of the subtree rooted at that node (each node only knows its own ID,
its parent's ID and its own weight).

For example, you get a CSV with the following format:
id, parent, weight
where a root node has parent -1, and need to return a CSV with each id and the
total weight of the subtree rooted at that node.
"""
from __future__ import print_function
from random import shuffle
from unittest import TestCase
class Node(object):
    """One tree node built from a single ``(id, parent, weight)`` record.

    A node only knows its own id, its parent's id and its own weight;
    ``children`` starts empty and is filled in by the caller once every node
    has been created.
    """

    # Parent id that marks a node as the root of a tree.
    ROOT_PARENT_ID = -1

    def __init__(self, node_id, parent_id, node_weight):
        self.node_id = node_id
        self.parent_id = parent_id
        self.node_weight = node_weight
        self.children = []

    def __repr__(self):
        return "%s(%r, %r, %r)" % (
            type(self).__name__, self.node_id, self.parent_id, self.node_weight
        )

    @property
    def is_root(self):
        """True when this node's parent id is ``ROOT_PARENT_ID``."""
        return self.parent_id == self.ROOT_PARENT_ID

    @property
    def is_leaf(self):
        """True when no children have been attached to this node."""
        return not self.children

    def _get_subtree_weight(self):
        """Return this node's weight plus the weights of all its descendants.

        Walks the subtree with an explicit stack (post-order) instead of
        recursion, so arbitrarily deep trees do not hit Python's recursion
        limit. Every non-leaf node visited gets its total cached in
        ``_subtree_weight``.

        :raises ValueError: if the ``children`` links contain a cycle.
        """
        if self.is_leaf:
            return self.node_weight

        stack = [(self, False)]
        # Nodes whose children are still being summed; meeting one of them
        # again before it finishes means it is its own descendant.
        in_progress = set()
        while stack:
            node, children_done = stack.pop()
            if children_done:
                in_progress.discard(id(node))
                # Every child was pushed above this entry, so all of them have
                # been resolved by now. Leaves are not cached (matching the
                # ``subtree_weight`` fast path), so read their weight directly.
                node._subtree_weight = sum(
                    (
                        child.node_weight if child.is_leaf
                        else child._subtree_weight
                        for child in node.children
                    ),
                    node.node_weight,
                )
                continue
            if node.is_leaf or hasattr(node, "_subtree_weight"):
                continue
            if id(node) in in_progress:
                raise ValueError(
                    "cycle detected in tree at node %r" % (node.node_id,)
                )
            in_progress.add(id(node))
            stack.append((node, True))
            stack.extend((child, False) for child in node.children)
        return self._subtree_weight

    @property
    def subtree_weight(self):
        """Total weight of the subtree rooted at this node (memoised)."""
        if self.is_leaf:
            return self.node_weight
        if not hasattr(self, "_subtree_weight"):
            self._subtree_weight = self._get_subtree_weight()
        return self._subtree_weight
def get_subtree_weights(relations):
    """Link ``(node_id, parent_id, node_weight)`` rows into ``Node`` trees.

    :param relations: iterable of 3-item sequences
        ``(node_id, parent_id, node_weight)`` in any order. It may describe
        several trees; a row is a root when ``Node.is_root`` says so (parent
        id -1 in this file).
    :returns: dict mapping each node id to its linked ``Node``. Read
        ``result[node_id].subtree_weight`` for the weight of the subtree
        rooted at that node.
    :raises ValueError: if the same node id appears more than once.
    :raises KeyError: if a non-root row names a parent id that has no row.
    """
    tree = {}
    for relation in relations:
        node = Node(*relation)
        if node.node_id in tree:
            # Overwriting would silently drop a node and its weight.
            raise ValueError("duplicate node id %r" % (node.node_id,))
        tree[node.node_id] = node

    # Link only after every node exists: rows are unordered, so a child can
    # appear before its parent.
    for node in tree.values():
        if node.is_root:
            continue
        try:
            parent = tree[node.parent_id]
        except KeyError:
            raise KeyError(
                "node %r refers to unknown parent %r"
                % (node.node_id, node.parent_id)
            )
        parent.children.append(node)
    return tree
"""
Test code starts here:
"""
class Tree(object):
    """Test-side tree builder that also serves as the expected-weight oracle.

    Build trees fluently with ``Tree(w).add_children(...)``. A node's id is
    ``id(self)``, which stays unique while the whole tree is alive.
    """

    def __init__(self, weight):
        self.weight = weight
        self.parent = -1  # -1 marks a root until add_children re-parents it
        self.children = []

    @property
    def id(self):
        return id(self)

    @property
    def subtree_weight(self):
        """This node's weight plus the weights of all its descendants.

        Computed on demand, not accumulated in ``add_children``. Children can
        therefore be attached in any order (even after this node has joined
        a parent) without leaving an ancestor's total stale.
        """
        total = 0
        pending = [self]
        while pending:
            node = pending.pop()
            total += node.weight
            pending.extend(node.children)
        return total

    def add_children(self, *children):
        """Attach ``children`` under this node and return ``self``."""
        for child in children:
            child.parent = self.id
            self.children.append(child)
        return self
def create_workset(trees, rng=None):
    """Flatten test trees into shuffled relation rows plus expected weights.

    :param trees: iterable of root ``Tree`` objects (a forest).
    :param rng: optional object with a ``shuffle(list)`` method, such as
        ``random.Random(seed)``. Pass one to make the row order (and
        therefore a failing test) reproducible. Defaults to the module-level
        ``random.shuffle``.
    :returns: ``(relations, subtree_weights)``. ``relations`` is a list of
        ``[id, parent, weight]`` rows in shuffled order; ``subtree_weights``
        maps each id to that node's expected subtree weight.
    """
    relations = []
    subtree_weights = {}
    # An explicit stack keeps this traversal itself off the recursion limit.
    # Children are pushed reversed so nodes are still visited in pre-order.
    pending = list(reversed(list(trees)))
    while pending:
        node = pending.pop()
        relations.append([node.id, node.parent, node.weight])
        subtree_weights[node.id] = node.subtree_weight
        pending.extend(reversed(node.children))
    (shuffle if rng is None else rng.shuffle)(relations)
    return relations, subtree_weights
class SubTreeWeightTestCase(TestCase):
    """End-to-end checks of ``get_subtree_weights`` against ``Tree`` oracles."""

    def setUp(self):
        # Each test assigns the forest to check. Starting empty makes a test
        # that forgets fail with a clear message, not an AttributeError.
        self.trees = []

    def assertSubTreeWeights(self):
        """Check every node's computed subtree weight for ``self.trees``."""
        self.assertTrue(
            any(self.trees), "test must set self.trees to a non-empty forest"
        )
        relations, expected = create_workset(self.trees)
        actual = get_subtree_weights(relations)
        # No node may be lost or invented on the way through.
        self.assertEqual(set(expected), set(actual))
        for node_id, expected_weight in expected.items():
            self.assertEqual(
                expected_weight,
                actual[node_id].subtree_weight,
                "wrong subtree weight for node %r" % (node_id,),
            )

    def test_single_node(self):
        self.trees = [Tree(42)]
        self.assertSubTreeWeights()

    def test_basic_tree(self):
        self.trees = [
            Tree(5).add_children(
                Tree(10).add_children(Tree(7), Tree(5)),
                Tree(3),
            ),
        ]
        self.assertSubTreeWeights()

    def test_multiple_trees(self):
        self.trees = [
            Tree(5).add_children(
                Tree(10).add_children(Tree(7), Tree(5)),
                Tree(3),
            ),
            Tree(13).add_children(
                Tree(11).add_children(Tree(88), Tree(12)),
                Tree(4).add_children(Tree(3)),
            ),
        ]
        self.assertSubTreeWeights()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment