Create a gist now

Instantly share code, notes, and snippets.

Embed
What would you like to do?
class BitGroupTests extends FlatSpec with Matchers with PropertyChecks {

  /**
   * Generates indices in `[0, max)`.
   *
   * The upper bound is exclusive because `BitGroup.filled(max)` only sets indices below `max`.
   * Generating `max` itself made the "set zeros" property fail intermittently.
   */
  private def indicesBelow(max: Int): Arbitrary[Int] = Arbitrary(Gen.choose(0, max - 1))

  "BitGroup" should "set values" in {
    val max = 100
    val group = BitGroup.zero(max)
    // A fresh group must start with every bit cleared; the check has to be asserted,
    // a bare `forall` result is silently discarded.
    assert((0 until max).map(group.valueAt).forall(_ == Bit.Zero))
    val updated = group.setValues(List(57), Bit.One)
    assert(updated.valueAt(57) == Bit.One)
    assert(updated.setValues(List(57), Bit.Zero).valueAt(57) == Bit.Zero)
  }

  it should "treat 0 as an item" in {
    val group = BitGroup.zero(100)
    val updated = group.setValues(List(0), Bit.One)
    assert(updated.valueAt(0) == Bit.One)
  }

  it should "set ones" in {
    val max = 10000
    implicit val arbIndex: Arbitrary[Int] = indicesBelow(max)
    forAll { (data: Set[Int]) =>
      // NOTE(review): presumably guards against shrunk values, which need not stay inside
      // the generator's range — confirm before removing.
      val items = data.filter(_ >= 0)
      val zero = BitGroup.zero(max)
      assert(zero.count == 0)
      val group = zero.setValues(items.toList, Bit.One)
      assert(items.map(group.valueAt).forall(_ == Bit.One))
      assert(group.count == items.size)
    }
  }

  it should "set zeros" in {
    val max = 10000
    implicit val arbIndex: Arbitrary[Int] = indicesBelow(max)
    forAll { (data: Set[Int]) =>
      // NOTE(review): presumably guards against shrunk values, which need not stay inside
      // the generator's range — confirm before removing.
      val items = data.filter(_ >= 0)
      val filled = BitGroup.filled(max)
      assert(filled.count == max)
      val group = filled.setValues(items.toList, Bit.Zero)
      assert(items.map(group.valueAt).forall(_ == Bit.Zero))
      assert((max - group.count) == items.size)
    }
  }
}
/** Factories for [[BitGroup]] instances whose backing array has room for indices `0` to `max`. */
object BitGroup {

  /**
   * Creates a group with every bit cleared.
   *
   * @param max the bound the group is sized for; `max / 8 + 1` bytes are allocated
   * @return a group whose backing bytes are all zero
   * @throws IllegalArgumentException if `max` is negative
   */
  def zero(max: Int): BitGroup = {
    require(max >= 0, s"max must be non-negative, got $max")
    new BitGroup(new Array[Byte](max / 8 + 1), max)
  }

  /**
   * Creates a group with indices `0 until max` set to one and every other bit of the backing
   * array cleared, so that `count == max`.
   *
   * @param max the number of leading bits to set; `max / 8 + 1` bytes are allocated
   * @return a group with exactly `max` bits set
   * @throws IllegalArgumentException if `max` is negative
   */
  def filled(max: Int): BitGroup = {
    require(max >= 0, s"max must be non-negative, got $max")
    val array = new Array[Byte](max / 8 + 1)
    // Every byte before the last one lies entirely below `max`, so all of its bits are set.
    java.util.Arrays.fill(array, 0, array.length - 1, 0xFF.toByte)
    // The last byte only holds the low `max % 8` bits; the rest must stay 0 so `count`
    // does not include positions at or beyond `max`.
    array(array.length - 1) = ((1 << (max % 8)) - 1).toByte
    new BitGroup(array, max)
  }
}
/**
 * A set of bits backed by a byte array. Bit `i` lives in byte `i / 8` at bit position `i % 8`,
 * where position 0 is the least-significant bit.
 *
 * `setValues` returns a new group and never modifies the receiver. `data` is still a public
 * array, though, so callers holding it can mutate it directly.
 *
 * NOTE(review): `setValues` accepts index `max` itself, but `BitGroup.filled(max)` only sets
 * indices below `max`. Confirm whether the upper bound is meant to be exclusive.
 *
 * @param data the backing bytes
 * @param max  the largest index `setValues` accepts
 */
class BitGroup(val data: Array[Byte], max: Int) {

  /** Number of bits set to one across the whole backing array. */
  def count: Long =
    // `& 0xFF` drops the sign extension so only the byte's own 8 bits are counted.
    data.foldLeft(0L)((total, byte) => total + Integer.bitCount(byte & 0xFF))

  /**
   * Returns a new group in which every index in `values` is set to `bit`. This group is left
   * unchanged.
   *
   * @param values indices to update; duplicates and indices sharing a byte are fine
   * @param bit    the value to store at each index
   * @return a new group with the updates applied
   * @throws IllegalArgumentException if any index is greater than `max` or negative
   */
  def setValues(values: List[Int], bit: Bit): BitGroup = {
    require(values.forall(_ <= max), s"Cannot set values above max $max")
    require(values.forall(_ >= 0), "Cannot set negative values")
    // Work on a copy. Updating `data` in place would also change this group, and any other
    // group sharing the array.
    val updated = data.clone()
    values.foreach { v =>
      val bytePosition = v / 8
      val mask = 1 << (v % 8)
      // Read the current byte from the copy, not `data`, so several indices that land in the
      // same byte accumulate instead of overwriting each other.
      val current = updated(bytePosition)
      val next = bit match {
        case Bit.One  => current | mask
        case Bit.Zero => current & ~mask
      }
      updated(bytePosition) = next.toByte
    }
    new BitGroup(updated, max)
  }

  /**
   * Returns the bit stored at index `idx`.
   *
   * @throws IllegalArgumentException if `idx` is negative
   * @throws ArrayIndexOutOfBoundsException if `idx / 8` is past the end of `data`
   */
  def valueAt(idx: Int): Bit = {
    // A negative index would yield a negative shift, which the JVM masks into a bogus bit.
    require(idx >= 0, s"Index must be non-negative, got $idx")
    val position = byteAt(idx)
    bitOf(position.byte, position.bitPosition)
  }

  /** Extracts the bit at `bitPosition` (0 = least significant) of `byte`. */
  private def bitOf(byte: Byte, bitPosition: Int): Bit =
    if (((byte >> bitPosition) & 1) == 1) Bit.One else Bit.Zero

  /** Locates index `idx` in `data`: the containing byte, its array index and the bit offset. */
  private def byteAt(idx: Int): BytePosition = {
    val bytePosition = idx / 8
    BytePosition(data(bytePosition), bytePosition, idx % 8)
  }
}
/**
 * Where a single bit sits inside a byte array.
 *
 * @param byte         the value of the byte that contains the bit, copied at lookup time
 * @param bytePosition index of that byte within the array
 * @param bitPosition  offset of the bit within the byte
 */
case class BytePosition(
  byte: Byte,
  bytePosition: Int,
  bitPosition: Int
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment