Skip to content

Instantly share code, notes, and snippets.

@bdw
Last active December 10, 2015 19:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bdw/4485184 to your computer and use it in GitHub Desktop.
Save bdw/4485184 to your computer and use it in GitHub Desktop.
Binary tree in Go
package main
import (
"fmt"
"strconv"
"io"
"encoding/binary"
)
// BinaryTree is a node of a binary search tree; every node is also the root
// of its own subtree. Add ignores values that are already present.
type BinaryTree struct {
	// Value is the key stored at this node.
	Value int
	// depth caches the height of this subtree counted in edges (0 for a
	// leaf). recalculateDepth refreshes it from the children's cached values.
	depth int
	// Left holds values smaller than Value, Right holds larger ones.
	Left, Right *BinaryTree
}
// QueueElement is one link in the singly linked list that backs Queue.
// Push builds it with a positional literal, so the field order matters.
type QueueElement struct {
// Payload is the tree node stored in this slot.
Payload *BinaryTree
// Next points toward the back of the queue; Push creates each new element
// with Next == nil and links it after the current tail.
Next *QueueElement
}
// Queue is a FIFO of tree nodes implemented as a linked list of
// QueueElements. The zero value is an empty queue.
type Queue struct {
	// First is the element Shift will return next; nil when empty.
	First *QueueElement
	// Last is the element Push appends after.
	Last *QueueElement
}
// Push appends b at the back of the queue. A queue whose First is nil is
// treated as empty and restarted, regardless of what Last holds.
func (q *Queue) Push(b *BinaryTree) {
	node := &QueueElement{Payload: b}
	if q.First != nil {
		q.Last.Next = node
	} else {
		q.First = node
	}
	q.Last = node
}
// Shift removes the element at the front of the queue and returns its
// payload, or nil if the queue is empty. Because nil also signals
// emptiness, a nil payload pushed earlier is indistinguishable from an
// empty queue.
func (q *Queue) Shift() *BinaryTree {
	head := q.First
	if head == nil {
		return nil
	}
	q.First = head.Next
	if q.First == nil {
		// The queue is now empty: drop the tail pointer as well, otherwise it
		// keeps the removed element (and the tree it points to) reachable.
		q.Last = nil
	}
	// Detach the removed element so it no longer references the queue.
	head.Next = nil
	return head.Payload
}
// New returns a single-node tree (a leaf with depth 0) holding v.
func New(v int) *BinaryTree {
	return &BinaryTree{Value: v}
}
// recalculateDepth refreshes n.depth from the cached depths of n's direct
// children (it does not recurse) and returns the new value. A leaf has
// depth 0; any other node is one deeper than its deeper child.
func (n *BinaryTree) recalculateDepth() int {
	switch {
	case n.Left == nil && n.Right == nil:
		n.depth = 0
	case n.Left == nil:
		n.depth = n.Right.depth + 1
	case n.Right == nil:
		n.depth = n.Left.depth + 1
	case n.Left.depth >= n.Right.depth:
		n.depth = n.Left.depth + 1
	default:
		n.depth = n.Right.depth + 1
	}
	return n.depth
}
// Add inserts v into the search tree rooted at n: smaller values go left,
// larger values go right, and a value equal to an existing node's is
// ignored. Every node on the path back up refreshes its cached depth.
func (n *BinaryTree) Add(v int) {
	var slot **BinaryTree
	switch {
	case v < n.Value:
		slot = &n.Left
	case v > n.Value:
		slot = &n.Right
	}
	if slot != nil {
		if *slot == nil {
			*slot = New(v)
		} else {
			(*slot).Add(v)
		}
	}
	n.recalculateDepth()
}
// Find reports whether v is stored in the search tree rooted at n, walking
// a single root-to-leaf path.
func (n *BinaryTree) Find(v int) bool {
	cur := n
	for {
		var next *BinaryTree
		switch {
		case v == cur.Value:
			return true
		case v < cur.Value:
			next = cur.Left
		default:
			next = cur.Right
		}
		if next == nil {
			return false
		}
		cur = next
	}
}
// String renders the tree in pre-order (node, then left subtree, then right
// subtree), concatenating the decimal values with no separator. For example,
// a root 2 with children 1 and 3 prints as "213". Multi-digit values make
// the output ambiguous; the format is kept unchanged for compatibility.
//
// All digits are appended to one growing buffer. The previous version built
// a string per subtree and re-copied it at every level, which is quadratic
// on list-shaped trees.
func (n *BinaryTree) String() string {
	return string(n.appendPreorder(nil))
}

// appendPreorder appends the pre-order rendering of the subtree rooted at n
// to buf and returns the extended buffer.
func (n *BinaryTree) appendPreorder(buf []byte) []byte {
	// AppendInt(int64(x), 10) yields the same digits as strconv.Itoa(x).
	buf = strconv.AppendInt(buf, int64(n.Value), 10)
	if n.Left != nil {
		buf = n.Left.appendPreorder(buf)
	}
	if n.Right != nil {
		buf = n.Right.appendPreorder(buf)
	}
	return buf
}
// Depth computes the height of the subtree rooted at n, counted in edges
// (0 for a leaf). It walks the whole subtree instead of reading the cached
// depth field.
func (n *BinaryTree) Depth() int {
	deepest := 0
	for _, child := range []*BinaryTree{n.Left, n.Right} {
		if child == nil {
			continue
		}
		if d := child.Depth() + 1; d > deepest {
			deepest = d
		}
	}
	return deepest
}
// Size counts the nodes in the subtree rooted at n, including n itself.
func (n *BinaryTree) Size() int {
	total := 1
	for _, child := range []*BinaryTree{n.Left, n.Right} {
		if child != nil {
			total += child.Size()
		}
	}
	return total
}
// height reports how many levels the subtree rooted at n has, based on the
// cached depth field: 0 for a nil subtree, depth+1 otherwise. It is safe to
// call on a nil receiver, which keeps the balance checks free of nil tests.
func (n *BinaryTree) height() int {
	if n == nil {
		return 0
	}
	return n.depth + 1
}

// rotateRight lifts n.Left into n's position while keeping n itself as the
// subtree root, because callers hold pointers to it. It swaps the values of
// n and its left child, and the old left node becomes n's right child.
// Sorted order is preserved. Requires n.Left != nil.
func (n *BinaryTree) rotateRight() {
	moved := n.Left
	n.Left = moved.Left
	moved.Left = moved.Right
	moved.Right = n.Right
	n.Right = moved
	n.Value, moved.Value = moved.Value, n.Value
	// moved now sits below n, so its depth must be refreshed first.
	moved.recalculateDepth()
	n.recalculateDepth()
}

// rotateLeft is the mirror image of rotateRight. It lifts n.Right into n's
// position, keeping n as the subtree root. Requires n.Right != nil.
func (n *BinaryTree) rotateLeft() {
	moved := n.Right
	n.Right = moved.Right
	moved.Right = moved.Left
	moved.Left = n.Left
	n.Left = moved
	n.Value, moved.Value = moved.Value, n.Value
	moved.recalculateDepth()
	n.recalculateDepth()
}

// Rebalance reshapes the tree in place to reduce its height. It keeps n as
// the root node, since callers hold a pointer to it, and preserves the
// sorted (in-order) sequence of values.
//
// At each node it rotates while one side is more than one level taller than
// the other, then recurses into both children, and finally refreshes n's
// cached depth. It relies on the cached depths being accurate on entry (Add
// maintains them) and leaves every depth in the tree accurate on return.
//
// This is a best-effort pass rather than a full AVL rebuild: rebalancing the
// children after the parent can leave the parent lopsided again.
func (n *BinaryTree) Rebalance() {
	// Heights count a nil subtree as 0 and a leaf as 1. Comparing the raw
	// depth fields, as before, made the two indistinguishable. Each iteration
	// strictly lowers the heavy side's height, so both loops terminate.
	for n.Left.height() > n.Right.height()+1 {
		// If the left child leans right, a single rotation would only move
		// the excess height to the other side, so straighten the child first
		// (a double rotation).
		if n.Left.Right.height() > n.Left.Left.height() {
			n.Left.rotateLeft()
		}
		n.rotateRight()
	}
	for n.Right.height() > n.Left.height()+1 {
		if n.Right.Left.height() > n.Right.Right.height() {
			n.Right.rotateRight()
		}
		n.rotateLeft()
	}
	// Children are rebalanced after the parent because the rotations above
	// move whole subtrees around.
	if n.Left != nil {
		n.Left.Rebalance()
	}
	if n.Right != nil {
		n.Right.Rebalance()
	}
	// The children's heights may have changed, so refresh the cache here.
	n.recalculateDepth()
}
// readNode decodes one little-endian value from rdr and wraps it in a new
// leaf node. It returns nil if the value cannot be read, including at end of
// input; ReadTree treats that nil as "no child here".
//
// binary.Read only accepts fixed-size types. The previous version decoded
// into a plain int, which binary.Read always rejects, so every call
// returned nil.
// NOTE(review): the 64-bit width is an assumption about whatever writes the
// stream; confirm against the producer.
func readNode(rdr io.Reader) *BinaryTree {
	var value int64
	if err := binary.Read(rdr, binary.LittleEndian, &value); err != nil {
		return nil
	}
	return New(int(value))
}
func ReadTree(rdr io.Reader) {
queue := Queue{nil,nil}
node := readNode(rdr)
for node != nil {
node.Left = readNode(rdr)
node.Right = readNode(rdr)
if node.Left != nil {
queue.Push(node.Left)
}
if node.Right != nil {
queue.Push(node.Right)
}
node = queue.Shift() // this totally works out
}
}
// SIZE is both the root value and the largest value main inserts; the demo
// tree ends up holding every value in 1..SIZE.
const SIZE int = 8

// main inserts SIZE..1 in descending order, which produces a tree where every
// node hangs off its parent's Left. It prints the tree with its cached depth
// and size, rebalances it, prints it again with the recomputed depth, and
// reports any inserted value that can no longer be found.
func main() {
	tree := New(SIZE)
	for i := SIZE; i > 0; i-- {
		tree.Add(i)
	}
	fmt.Println(tree)
	fmt.Println(tree.depth, tree.Size())
	tree.Rebalance()
	fmt.Println(tree, tree.Depth(), tree.Size())
	// Every value in 1..SIZE was inserted (SIZE itself as the root), so the
	// check covers the inclusive range; the old loop stopped at SIZE-1.
	for i := 1; i <= SIZE; i++ {
		if !tree.Find(i) {
			fmt.Println("ERROR: Could not find", i)
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment