Skip to content

Instantly share code, notes, and snippets.

@nilium
Created March 5, 2014 02:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nilium/9359868 to your computer and use it in GitHub Desktop.
Save nilium/9359868 to your computer and use it in GitHub Desktop.
Quick bitset written in Go. No use for it yet, so in a gist it goes.
package main
import "log"
import "fmt"
import "bytes"
// BitSet is a growable set of bits packed into 64-bit words. Bit n is stored
// in word n/64 at bit position n%64.
type BitSet struct {
	// bits is the backing storage: word 0 holds bits 0-63, word 1 holds bits
	// 64-127, and so on.
	bits []uint64
	// length is the number of words in bits (not the number of bits); the
	// constructors, ensureLength and Clone all keep it equal to len(bits).
	length uint64
}
// storageBitLength is the number of bits held by each uint64 storage word.
const storageBitLength = 64

// indexAndMaskOfBit locates bit in word storage. It returns the index of the
// word that holds the bit and a mask in which only that bit is set.
func indexAndMaskOfBit(bit uint64) (uint64, uint64) {
	word, shift := bit/storageBitLength, bit%storageBitLength
	return word, uint64(1) << shift
}
// String formats the set as a binary numeral. The highest-index word comes
// first, each word is rendered as 64 binary digits, and leading zeros are then
// stripped. A set with no bits set (or no storage at all) prints as "0".
func (bits *BitSet) String() string {
	var digits bytes.Buffer
	for word := bits.length; word > 0; word-- {
		fmt.Fprintf(&digits, "%064b", bits.bits[word-1])
	}
	if trimmed := bytes.TrimLeft(digits.Bytes(), "0"); len(trimmed) != 0 {
		return string(trimmed)
	}
	return "0"
}
// NewBitSet returns an empty BitSet with no storage allocated. Storage is
// added on demand by the mutating methods.
func NewBitSet() *BitSet {
	return new(BitSet)
}
// NewBitSetWithLength returns a BitSet with zeroed storage preallocated for
// length 64-bit words. Note that length counts words, not bits.
func NewBitSetWithLength(length uint64) *BitSet {
	set := &BitSet{length: length}
	set.bits = make([]uint64, length)
	return set
}
// NewBitSetWithBits returns a BitSet whose storage words are a copy of bits,
// least-significant word first. The caller's slice is not retained.
func NewBitSetWithBits(bits ...uint64) *BitSet {
	words := make([]uint64, len(bits))
	copy(words, bits)
	return &BitSet{bits: words, length: uint64(len(words))}
}
// ensureLength grows the storage so that it holds at least length words.
// New words are zero, and storage is never shrunk.
//
// Growth goes through append, so the slice's capacity grows geometrically.
// Setting bits in increasing order therefore no longer reallocates and copies
// the entire set each time another word is needed. len(bits.bits) stays equal
// to bits.length; any spare capacity is never read.
func (bits *BitSet) ensureLength(length uint64) {
	if length <= bits.length {
		return
	}
	bits.bits = append(bits.bits, make([]uint64, length-bits.length)...)
	bits.length = length
}
// Equal reports whether bits and other contain exactly the same set bits.
// Sets with different storage lengths still compare equal when the extra
// words of the longer one are all zero.
//
// Comparison is done word by word rather than bit by bit, so the cost is
// proportional to the number of storage words.
func (bits *BitSet) Equal(other *BitSet) bool {
	if bits == other {
		return true
	}
	shorter, longer := bits.bits, other.bits
	if len(shorter) > len(longer) {
		shorter, longer = longer, shorter
	}
	for i, word := range shorter {
		if word != longer[i] {
			return false
		}
	}
	// Words present only in the longer set must be empty for the sets to match.
	for _, word := range longer[len(shorter):] {
		if word != 0 {
			return false
		}
	}
	return true
}
// Set turns bit on, growing storage if necessary, and returns the receiver so
// calls can be chained.
func (bits *BitSet) Set(bit uint64) *BitSet {
	word, mask := indexAndMaskOfBit(bit)
	bits.ensureLength(word + 1)
	bits.bits[word] = bits.bits[word] | mask
	return bits
}
// Unset turns bit off and returns the receiver so calls can be chained.
//
// A bit beyond the current storage is already clear, so no storage is
// allocated for it. Growing here would only waste memory: unsetting a very
// large bit index would otherwise allocate every word up to it. Unlike Set and
// Flip, Unset therefore never extends the set's capacity.
func (bits *BitSet) Unset(bit uint64) *BitSet {
	index, mask := indexAndMaskOfBit(bit)
	if index < bits.length {
		bits.bits[index] &^= mask
	}
	return bits
}
// Flip toggles bit, growing storage if necessary, and returns the receiver so
// calls can be chained.
func (bits *BitSet) Flip(bit uint64) *BitSet {
	word, mask := indexAndMaskOfBit(bit)
	bits.ensureLength(word + 1)
	bits.bits[word] = bits.bits[word] ^ mask
	return bits
}
// Union returns a new BitSet containing every bit set in either bits or other.
// Neither operand is modified. The result is as long as the longer operand.
func (bits *BitSet) Union(other *BitSet) *BitSet {
	longer, shorter := bits, other
	if shorter.length > longer.length {
		longer, shorter = shorter, longer
	}
	merged := longer.Clone()
	for i := range shorter.bits {
		merged.bits[i] |= shorter.bits[i]
	}
	return merged
}
// Intersection returns a new BitSet holding only the bits set in both bits and
// other. Neither operand is modified. The result is as long as the shorter
// operand, because no bit past it can be set in both.
func (bits *BitSet) Intersection(other *BitSet) *BitSet {
	shorter, longer := bits, other
	if longer.length < shorter.length {
		shorter, longer = longer, shorter
	}
	common := shorter.Clone()
	for i := range common.bits {
		common.bits[i] &= longer.bits[i]
	}
	return common
}
// Difference returns a new BitSet holding the bits of bits that are not set in
// other. Neither operand is modified. The result keeps the length of bits;
// words of other beyond that length are ignored.
func (bits *BitSet) Difference(other *BitSet) *BitSet {
	remaining := bits.Clone()
	for i, word := range other.bits {
		if uint64(i) >= remaining.length {
			break
		}
		remaining.bits[i] &^= word
	}
	return remaining
}
// Xor returns a new BitSet holding the symmetric difference of bits and other:
// the bits set in exactly one of the two. Neither operand is modified. The
// result is as long as the longer operand.
func (bits *BitSet) Xor(other *BitSet) *BitSet {
	if other.length > bits.length {
		// XOR is commutative. Swap the operands so that the longer one is
		// cloned and the shorter one is folded into it. This must recurse into
		// Xor, not Union, or the wrong operation is applied.
		return other.Xor(bits)
	}
	dup := bits.Clone()
	for index, word := range other.bits {
		dup.bits[index] ^= word
	}
	return dup
}
// Test reports whether bit is set. Bits beyond the current storage are
// reported as unset.
func (bits *BitSet) Test(bit uint64) bool {
	word, mask := indexAndMaskOfBit(bit)
	if word >= bits.length {
		return false
	}
	return bits.bits[word]&mask == mask
}
// Clone returns an independent copy of the set with its own storage, so
// mutating either set does not affect the other.
func (bits *BitSet) Clone() *BitSet {
	words := make([]uint64, bits.length)
	copy(words, bits.bits)
	return &BitSet{bits: words, length: bits.length}
}
// Complement inverts, in place, every bit within the BitSet's current storage
// and returns the receiver.
//
// Bits beyond that storage are not touched and stay clear. Growing the set
// afterwards (for example with Set or Flip) therefore adds words whose bits
// are all unset, not set.
func (bits *BitSet) Complement() *BitSet {
	for i, word := range bits.bits {
		bits.bits[i] = ^word
	}
	return bits
}
// NextBit returns the index of the lowest set bit at or after from, together
// with true. If no set bit exists at or after from, it returns 0 and false.
// Typical iteration: for b, ok := s.NextBit(0); ok; b, ok = s.NextBit(b + 1).
func (bits *BitSet) NextBit(from uint64) (uint64, bool) {
	start := from / storageBitLength
	for word := start; word < bits.length; word++ {
		// pos is the bit index being examined. In the first word the scan
		// begins at from; in later words it begins at the word's lowest bit.
		pos := word * storageBitLength
		if word == start {
			pos = from
		}
		for rest := bits.bits[word] >> (pos % storageBitLength); rest != 0; rest >>= 1 {
			if rest&1 == 1 {
				return pos, true
			}
			pos++
		}
	}
	return 0, false
}
// assert panics with an error built from format and a when cond is false.
// When cond is true it does nothing.
func assert(cond bool, format string, a ...interface{}) {
	if cond {
		return
	}
	panic(fmt.Errorf(format, a...))
}
func main() {
	bits := NewBitSetWithBits(0x1)
	log.Println(bits)

	// checked wraps a mutating BitSet method. Each call records the bit's state
	// before and after, asserts the transition with valid (panicking with
	// failFormat otherwise), and logs it with logFormat.
	checked := func(mutate func(uint64) *BitSet, valid func(before, after bool) bool, failFormat, logFormat string) func(uint64) {
		return func(b uint64) {
			before := bits.Test(b)
			mutate(b)
			after := bits.Test(b)
			assert(valid(before, after), failFormat, b)
			log.Printf(logFormat, b, before, after)
		}
	}
	set := checked(bits.Set, func(_, after bool) bool { return after },
		"Bit %d not set", "Setting %d (previous: %t, current: %t)\n")
	unset := checked(bits.Unset, func(_, after bool) bool { return !after },
		"Bit %d is set", "Unsetting %d (previous: %t, current: %t)\n")
	flip := checked(bits.Flip, func(before, after bool) bool { return after == !before },
		"Bit %d not flipped", "Flipping %d (previous: %t, current: %t)\n")

	assert(bits.Test(0), "First bit is not set")

	// Drive every probe bit through set, flip, unset and set again. Each
	// operation is applied twice in a row to the same bit before moving on.
	probes := []uint64{0, 1, 2, 2049, 4000}
	for _, op := range []func(uint64){set, flip, unset, set} {
		for _, b := range probes {
			op(b)
			op(b)
		}
	}

	for b, ok := bits.NextBit(0); ok; b, ok = bits.NextBit(b + 1) {
		log.Printf("%d: %t\n", b, ok)
	}

	lhs := NewBitSetWithBits(0xFF, 0xF0)
	rhs := NewBitSetWithBits(0xF, 0xFF)
	lhsClone := lhs.Clone()
	assert(lhs.Equal(lhs), "Same instances not equal")
	assert(!rhs.Equal(lhs), "Different inequal instances compared equal")
	assert(lhs.Equal(lhsClone), "Different equal instances compared inequal")

	log.Printf("lhs: %080s\n", lhs)
	log.Printf("rhs: %080s\n", rhs)
	log.Printf("Union: %080s\n", lhs.Union(rhs))
	log.Printf("Intersection: %080s\n", lhs.Intersection(rhs))
	log.Printf("Difference: %080s\n", lhs.Difference(rhs))
	log.Printf("Xor: %080s\n", lhs.Xor(rhs))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment