Last active
June 1, 2020 06:57
-
-
Save chez-shanpu/2a9f09f59b0cf42356e03c2d73acbcdf to your computer and use it in GitHub Desktop.
AVL木 (AVL tree)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"math" | |
) | |
type (
	// binaryTree is an AVL-balanced binary search tree of ints. The zero value
	// is an empty tree (root == nil) and is ready to use.
	binaryTree struct {
		root *node
	}

	// node is a single tree node. height caches the height of the subtree
	// rooted here (a freshly created leaf has height 0); it is maintained by
	// add/remove and the rotation helpers.
	node struct {
		value       int
		height      int
		left, right *node
	}
)
func (b *binaryTree) add(val int) { | |
if b.root == nil { | |
b.root = &node{value: val} | |
} else { | |
b.root = b.root.add(val) | |
} | |
} | |
func (b *binaryTree) contains(t int) bool { | |
n := b.root | |
for { | |
if t == n.value { | |
return true | |
} else if t < n.value { | |
if n.left == nil { | |
return false | |
} | |
n = n.left | |
} else if t > n.value { | |
if n.right == nil { | |
return false | |
} | |
n = n.right | |
} | |
} | |
} | |
func (n *node) computeHeight() int { | |
h := -1 | |
if n.left != nil { | |
h = int(math.Max(float64(h), float64(n.left.height))) | |
} else if n.right != nil { | |
h = int(math.Max(float64(h), float64(n.right.height))) | |
} | |
return h + 1 | |
} | |
func (n *node) heightDiff() int { | |
l := 0 | |
r := 0 | |
if n.left != nil { | |
l = n.left.height | |
} else if n.right != nil { | |
r = n.right.height | |
} | |
return int(math.Abs(float64(l - r))) | |
} | |
func (n *node) add(val int) *node { | |
root := n | |
if val < n.value { | |
n.left = n.addToSubTree(n.left, val) | |
if n.heightDiff() > 1 { | |
if val <= n.left.value { | |
root = n.rotateRight() | |
} else { | |
root = n.rotateRightLeft() | |
} | |
} | |
} else { | |
n.right = n.addToSubTree(n.right, val) | |
if n.heightDiff() > 1 { | |
if val > n.right.value { | |
root = n.rotateLeft() | |
} else { | |
root = n.rotateRightLeft() | |
} | |
} | |
} | |
root.height = root.computeHeight() | |
return root | |
} | |
func (n *node) addToSubTree(p *node, val int) *node { | |
if p == nil { | |
return &node{value: val} | |
} | |
p = p.add(val) | |
return p | |
} | |
func (n *node) rotateRight() *node { | |
root := n.left | |
grandson := root.right | |
n.left = grandson | |
root.right = n | |
n.height = n.computeHeight() | |
return root | |
} | |
// right - left pattern | |
func (n *node) rotateRightLeft() *node { | |
child := n.right | |
root := child.left | |
grand1 := root.left | |
grand2 := root.right | |
child.left = grand2 | |
n.right = grand1 | |
root.left = n | |
root.right = child | |
child.height = child.computeHeight() | |
n.height = n.computeHeight() | |
return root | |
} | |
func (n *node) rotateLeft() *node { | |
root := n.right | |
grandson := root.left | |
n.right = grandson | |
root.left = n | |
n.height = n.computeHeight() | |
return root | |
} | |
// left-right pattern | |
func (n *node) rotateLeftRight() *node { | |
child := n.left | |
root := child.right | |
grand1 := root.left | |
grand2 := root.right | |
child.right = grand2 | |
n.left = grand1 | |
root.left = child | |
root.right = n | |
child.height = child.computeHeight() | |
n.height = n.computeHeight() | |
return root | |
} | |
func (n *node) removeFromParent(p *node, val int) *node { | |
if p != nil { | |
return p.remove(val) | |
} | |
return nil | |
} | |
func (n *node) remove(val int) *node { | |
root := n | |
if val == n.value { | |
if n.left == nil { | |
return n.right | |
} | |
child := n.left | |
for child.right != nil { | |
child = child.right | |
} | |
n.left = n.removeFromParent(n.left, child.value) | |
n.value = child.value | |
if n.heightDiff() > 1 { | |
if n.right.heightDiff() <= 1 { | |
root = n.rotateLeft() | |
} else { | |
root = n.rotateRightLeft() | |
} | |
} | |
} else if val < n.value { | |
n.left = n.removeFromParent(n.left, val) | |
if n.heightDiff() > 1 { | |
if n.right.heightDiff() <= 1 { | |
root = n.rotateLeft() | |
} else { | |
root = n.rotateRightLeft() | |
} | |
} | |
} else { | |
n.right = n.removeFromParent(n.right, val) | |
if n.heightDiff() > 1 { | |
if n.left.heightDiff() <= 1 { | |
root = n.rotateRight() | |
} else { | |
root = n.rotateLeftRight() | |
} | |
} | |
} | |
root.height = root.computeHeight() | |
return root | |
} | |
func (n *node) printInorder() { | |
if n.left != nil { | |
n.left.printInorder() | |
} | |
fmt.Printf("%d ",n.value) | |
if n.right != nil { | |
n.right.printInorder() | |
} | |
} | |
func main() { | |
b := binaryTree{} | |
b.add(5) | |
b.add(13) | |
b.add(3) | |
b.add(9) | |
b.add(34) | |
fmt.Println(b.contains(13)) | |
fmt.Println(b.contains(4)) | |
fmt.Println("-------------------") | |
b.root.printInorder() | |
fmt.Println("\n-------------------") | |
b.root.remove(13) | |
fmt.Println("-------------------") | |
b.root.printInorder() | |
fmt.Println("\n-------------------") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment