Skip to content

Instantly share code, notes, and snippets.

@gorakhargosh
Created September 29, 2015 19:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gorakhargosh/7602accaf0b31380839b to your computer and use it in GitHub Desktop.
Save gorakhargosh/7602accaf0b31380839b to your computer and use it in GitHub Desktop.
package main
import (
"fmt"
"log"
"os"
)
// Partition is a disjoint-set (union-find) structure over integer elements:
// every element belongs to exactly one set, and sets can be merged but never
// split. Implementations may panic for elements outside their range.
type Partition interface {
	// Union merges the set containing x with the set containing y. It is a
	// no-op when x and y are already in the same set.
	Union(x, y int)
	// FindSet returns the representative element of the set containing x.
	// Two elements are in the same set iff their representatives are equal.
	FindSet(x int) int
	// Connected reports whether x and y belong to the same set.
	Connected(x, y int) bool
	// Weight returns the number of elements in the set containing x.
	Weight(x int) uint64
}

// pathShortenedPartition is a weighted quick-union disjoint-set partitioner
// with full path compression.
//
// id[i] is the parent of element i; element i is a root iff id[i] == i.
// weight[r] is the number of elements in the tree rooted at r and is only
// kept up to date for roots.
type pathShortenedPartition struct {
	id     []int
	weight []uint64
}

// NewPathShortenedPartition returns a Partition over the elements
// 0..size-1, each starting in its own singleton set of weight 1.
// It panics if size is negative.
func NewPathShortenedPartition(size int) Partition {
	p := &pathShortenedPartition{
		id:     make([]int, size),
		weight: make([]uint64, size),
	}
	for i := range p.id {
		p.id[i] = i
		p.weight[i] = 1
	}
	return p
}

// Union merges the sets containing x and y, hanging the lighter tree under
// the heavier root (ties attach y's root under x's root).
func (p *pathShortenedPartition) Union(x, y int) {
	a, b := p.FindSet(x), p.FindSet(y)
	if a == b {
		// Already in the same set. Falling through would add the set's
		// weight to itself and double it.
		return
	}
	// Make a the heavier root so the smaller tree is attached beneath it;
	// this keeps tree height logarithmic in the set size.
	if p.weight[a] < p.weight[b] {
		a, b = b, a
	}
	p.id[b] = a
	p.weight[a] += p.weight[b]
}

// FindSet returns the root of the tree containing x and, as a side effect,
// re-points every element on the path from x directly at that root.
// (A one-pass "path halving" variant, id[x] = id[id[x]], would also work.)
func (p *pathShortenedPartition) FindSet(x int) int {
	// First pass: follow parent links up to the root.
	root := x
	for root != p.id[root] {
		root = p.id[root]
	}
	// Second pass: read the old parent before overwriting the link, so
	// every element on the path (including x itself) ends up pointing at
	// the root.
	for x != root {
		next := p.id[x]
		p.id[x] = root
		x = next
	}
	return root
}

// Weight returns the number of elements in the set containing x.
func (p *pathShortenedPartition) Weight(x int) uint64 {
	return p.weight[p.FindSet(x)]
}

// Connected reports whether x and y are in the same set.
func (p *pathShortenedPartition) Connected(x, y int) bool {
	return p.FindSet(x) == p.FindSet(y)
}
// readInts scans up to n whitespace-separated integers from the current line
// of standard input using fmt.Scanln. Any scan error (too few values,
// malformed numbers, EOF) terminates the program via log.Fatal; otherwise the
// successfully scanned values are returned.
func readInts(n int) []int {
	values := make([]int, n)
	targets := make([]interface{}, 0, n)
	for i := range values {
		targets = append(targets, &values[i])
	}
	count, err := fmt.Scanln(targets...)
	if err != nil {
		log.Fatal(err)
	}
	return values[:count]
}
// readInt scans one integer from standard input with fmt.Scan, which skips
// leading whitespace including newlines. A scan error terminates the program
// via log.Fatal.
func readInt() int {
	var value int
	if _, err := fmt.Scan(&value); err != nil {
		log.Fatal(err)
	}
	return value
}
func main() {
ints := readInts(2)
n, q := ints[0], ints[1]
p := NewPathShortenedPartition(n + 1)
for ; q > 0; q-- {
c := " "
fmt.Scan(&c)
switch c {
case "M":
i := 0
j := 0
fmt.Scanln(&i, &j)
// Without this connected check, you'd still end up
// with an unbalanced tree.
if !p.Connected(i, j) {
p.Union(i, j)
}
case "Q":
fmt.Fprintf(os.Stdout, "%d\n", p.Weight(readInt()))
default:
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment