Skip to content

Instantly share code, notes, and snippets.

@devshorts devshorts/BitGroup wrapper
Last active Feb 15, 2018

Embed
What would you like to do?
class BitGroupTests extends FlatSpec with Matchers with PropertyChecks {
  "BitGroup" should "set values" in {
    val max = 100
    val group = BitGroup.zero(max)
    // A freshly created group must have every addressable bit cleared.
    // (Previously this expression's result was discarded, so it checked nothing.)
    assert((0 until max).map(group.valueAt).forall(_ == Bit.Zero))
    val updated = group.setValues(List(57), Bit.One)
    assert(updated.valueAt(57) == Bit.One)
    assert(updated.setValues(List(57), Bit.Zero).valueAt(57) == Bit.Zero)
  }

  it should "treat 0 as an item" in {
    val group = BitGroup.zero(100)
    val updated = group.setValues(List(0), Bit.One)
    assert(updated.valueAt(0) == Bit.One)
  }

  it should "set ones" in {
    val max = 10000
    // Gen.choose is inclusive on both ends; addressable indices are 0 until max,
    // so the upper bound is max - 1 (a group of `max` bits has no bit `max`).
    implicit def arbInterval: Arbitrary[Int] = Arbitrary(Gen.choose(0, max - 1))
    forAll { (data: Set[Int]) =>
      // Keep only non-negative indices; shrinking of failing cases may produce negatives.
      val items = data.filter(_ >= 0)
      val zero = BitGroup.zero(max)
      assert(zero.count == 0)
      val group = zero.setValues(items.toList, Bit.One)
      assert(items.map(group.valueAt).forall(_ == Bit.One))
      assert(group.count == items.size)
    }
  }

  it should "set zeros" in {
    val max = 10000
    // Same bound as above: generating `max` would target a padding bit that
    // `filled` never sets, making the count check fail spuriously.
    implicit def arbInterval: Arbitrary[Int] = Arbitrary(Gen.choose(0, max - 1))
    forAll { (data: Set[Int]) =>
      // Keep only non-negative indices; shrinking of failing cases may produce negatives.
      val items = data.filter(_ >= 0)
      val filled = BitGroup.filled(max)
      assert(filled.count == max)
      val group = filled.setValues(items.toList, Bit.Zero)
      assert(items.map(group.valueAt).forall(_ == Bit.Zero))
      assert((max - group.count) == items.size)
    }
  }
}
object BitGroup {

  /** Size of the backing array used for a group of `max` bits: always `max / 8 + 1` bytes. */
  private def byteCount(max: Int): Int = max / 8 + 1

  /**
   * Builds a group of `max` bits with every bit cleared.
   *
   * @param max number of bits the group holds
   */
  def zero(max: Int): BitGroup =
    new BitGroup(new Array[Byte](byteCount(max)), max)

  /**
   * Builds a group whose first `max` bits are set.
   *
   * Bits in the final byte that lie beyond `max` are padding and are left cleared,
   * so `count` on the result equals `max`.
   *
   * @param max number of bits the group holds
   */
  def filled(max: Int): BitGroup = {
    val size = byteCount(max)
    val allOnes = 0xFF
    val bytes = new Array[Byte](size)
    for (i <- bytes.indices) bytes(i) = allOnes.toByte

    // How many high-order bits of the final byte fall outside the group.
    val paddingBits = size * 8 - max
    bytes(size - 1) = (allOnes >> paddingBits).toByte

    new BitGroup(bytes, max)
  }
}
/**
 * A fixed-size group of `max` bits (valid indices `0 until max`) backed by a byte array.
 *
 * Bit `i` lives in byte `i / 8` at bit offset `i % 8`, where offset 0 is the least
 * significant bit of that byte.
 *
 * @param data backing bytes; `setValues` copies this array and never writes into it
 * @param max  number of addressable bits; valid indices are `0 until max`
 */
class BitGroup(val data: Array[Byte], max: Int) {

  /** Total number of set bits in the backing array. */
  def count: Long =
    data.foldLeft(0L)((acc, b) => acc + java.lang.Integer.bitCount(b & 0xFF))

  /**
   * Returns a new group in which every index in `values` is set to `bit`.
   *
   * The receiver is left unchanged. Previously the update was written into the shared
   * `data` array, so the "old" group was silently modified as well.
   *
   * @param values indices to update; each must be in `0 until max`
   * @param bit    value to store at each index
   * @return a new group containing the updated bits
   * @throws IllegalArgumentException if any index is outside `0 until max`
   */
  def setValues(values: List[Int], bit: Bit): BitGroup = {
    require(values.forall(isInRange), s"Cannot set values outside [0, $max)")
    val updated = data.clone()
    values.foreach { v =>
      val bytePosition = v / 8
      val mask = 1 << (v % 8)
      // Read from the copy, not `data`, so that several indices falling in the same
      // byte accumulate instead of overwriting one another.
      val current = updated(bytePosition)
      val next = bit match {
        case Bit.One  => current | mask
        case Bit.Zero => current & ~mask
      }
      updated(bytePosition) = next.toByte
    }
    new BitGroup(updated, max)
  }

  /**
   * Returns the value of the bit at `idx`.
   *
   * @throws IllegalArgumentException if `idx` is outside `0 until max`
   */
  def valueAt(idx: Int): Bit = {
    require(isInRange(idx), s"Index $idx outside [0, $max)")
    val position = byteAt(idx)
    valueAt(position.byte, position.bitPosition)
  }

  /** True when `idx` addresses one of this group's `max` bits. */
  private def isInRange(idx: Int): Boolean = idx >= 0 && idx < max

  /** Extracts the bit at offset `bitPosition` (0 = least significant) of `byte`. */
  private def valueAt(byte: Byte, bitPosition: Int): Bit =
    if (((byte >> bitPosition) & 1) == 1) Bit.One else Bit.Zero

  /** Locates the byte holding bit `idx`, together with that byte's index and the bit offset. */
  private def byteAt(idx: Int): BytePosition = {
    val bytePosition = idx / 8
    BytePosition(data(bytePosition), bytePosition, idx % 8)
  }
}
/**
 * Location of a single bit within a byte array.
 *
 * @param byte         the byte value that contains the bit
 * @param bytePosition index of that byte in the array
 * @param bitPosition  offset of the bit within the byte (0 = least significant)
 */
case class BytePosition(
  byte: Byte,
  bytePosition: Int,
  bitPosition: Int
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.
You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session.