Skip to content

Instantly share code, notes, and snippets.

@sauravmishra1710
Last active May 31, 2021 06:21
Show Gist options
  • Save sauravmishra1710/dae46f15a024bc06c79e0049e6dabcf4 to your computer and use it in GitHub Desktop.
Save sauravmishra1710/dae46f15a024bc06c79e0049e6dabcf4 to your computer and use it in GitHub Desktop.
Visualize a Binary Tree.
class DisplayBinaryTree:
    """
    Utility class to display a binary tree as ASCII art.

    Tree nodes are expected to expose ``value``, ``left`` and ``right``
    attributes, where a missing child is ``None``.

    Reference:
        https://stackoverflow.com/a/54074933/599456
    """

    def displayTree(self, node):
        """
        Print a binary tree to stdout, one text row per line.

        Args:
            node: the root node of the tree. ``None`` (an empty tree)
                prints nothing instead of raising ``AttributeError``.
        """
        if node is None:
            return
        lines, *_ = self.__treeDisplayHelper(node)
        for line in lines:
            print(line)

    def __treeDisplayHelper(self, node):
        """
        Recursively lay out the subtree rooted at ``node``.

        Args:
            node: a tree node (must not be ``None``).

        Returns:
            A tuple ``(lines, width, height, middle)``: ``lines`` is a list
            of ``height`` strings, each exactly ``width`` characters wide,
            and ``middle`` is the column of the node's label centre, i.e.
            where a parent's connector should point.

        Note:
            Recursion depth equals the tree height, so very deep
            (degenerate) trees can hit Python's recursion limit.
        """
        # str() instead of '%s' % value: the % operator treats a tuple value
        # as its argument list (TypeError, or a 1-tuple rendered unwrapped).
        label = str(node.value)
        label_w = len(label)

        # No child: a single row holding just the label.
        if node.left is None and node.right is None:
            return [label], label_w, 1, label_w // 2

        # Only left child: the label sits to the right of the left subtree,
        # joined by an underscore run and a '/' connector.
        if node.right is None:
            sub_lines, sub_w, sub_h, sub_mid = self.__treeDisplayHelper(node.left)
            first_line = (sub_mid + 1) * ' ' + (sub_w - sub_mid - 1) * '_' + label
            second_line = sub_mid * ' ' + '/' + (sub_w - sub_mid - 1 + label_w) * ' '
            # Pad every subtree row on the right so all rows share one width.
            shifted = [line + label_w * ' ' for line in sub_lines]
            return ([first_line, second_line] + shifted,
                    sub_w + label_w, sub_h + 2, sub_w + label_w // 2)

        # Only right child: the label sits to the left of the right subtree,
        # joined by an underscore run and a '\' connector.
        if node.left is None:
            sub_lines, sub_w, sub_h, sub_mid = self.__treeDisplayHelper(node.right)
            first_line = label + sub_mid * '_' + (sub_w - sub_mid) * ' '
            second_line = (label_w + sub_mid) * ' ' + '\\' + (sub_w - sub_mid - 1) * ' '
            # Pad every subtree row on the left so it lines up under the label.
            shifted = [label_w * ' ' + line for line in sub_lines]
            return ([first_line, second_line] + shifted,
                    sub_w + label_w, sub_h + 2, label_w // 2)

        # Two children: left subtree, label-width gap, right subtree.
        left_lines, left_w, left_h, left_mid = self.__treeDisplayHelper(node.left)
        right_lines, right_w, right_h, right_mid = self.__treeDisplayHelper(node.right)
        first_line = ((left_mid + 1) * ' ' + (left_w - left_mid - 1) * '_' + label
                      + right_mid * '_' + (right_w - right_mid) * ' ')
        second_line = (left_mid * ' ' + '/'
                       + (left_w - left_mid - 1 + label_w + right_mid) * ' '
                       + '\\' + (right_w - right_mid - 1) * ' ')
        # Pad the shorter subtree with blank rows so both zip row-for-row
        # (a negative repeat count yields an empty list, so only one side grows).
        left_lines = left_lines + [left_w * ' '] * (right_h - left_h)
        right_lines = right_lines + [right_w * ' '] * (left_h - right_h)
        rows = [a + label_w * ' ' + b for a, b in zip(left_lines, right_lines)]
        return ([first_line, second_line] + rows,
                left_w + right_w + label_w, max(left_h, right_h) + 2,
                left_w + label_w // 2)
from DisplayBinaryTree import DisplayBinaryTree
bTreeViz = DisplayBinaryTree()

# Build the sample tree: ten nodes holding the values 1..10, created in order.
# NOTE(review): BinaryTreeNode is neither defined nor imported in this
# snippet; it is assumed to come from elsewhere and to expose writable
# left/right attributes — TODO add the missing import.
root, node1, node2, node3, node4, node5, node6, node7, node8, node9 = (
    BinaryTreeNode(value) for value in range(1, 11)
)

# Wire up the children, level by level.
root.left, root.right = node1, node2
node1.left, node1.right = node3, node4
node2.left, node2.right = node5, node6
node3.left, node3.right = node7, node8
node4.left = node9

# Render the finished tree to stdout.
bTreeViz.displayTree(root)
# The final display of the tree created above.
#     ___1_
#    /     \
#   _2__   3
#  /    \ / \
#  4    5 6 7
# / \  /
# 8 9 10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment