Skip to content

Instantly share code, notes, and snippets.

@XMPPwocky
Created August 28, 2015 21:12
Show Gist options
  • Save XMPPwocky/ef521cea192b98fa5210 to your computer and use it in GitHub Desktop.
Save XMPPwocky/ef521cea192b98fa5210 to your computer and use it in GitHub Desktop.
use std::cmp::min;
use self::BitBufError::*;
/// A growable buffer of bits with a single read/write cursor.
///
/// Bits are addressed LSB-first inside each byte and values are laid out
/// little-endian: bit 0 of a value is stored at the cursor position.
pub struct BitBuf {
    /// Backing storage; holds at least `ceil(bit_count / 8)` bytes.
    contents: Vec<u8>,
    /// Cursor: index of the next bit to be read or written.
    current_bit: u32,
    /// Number of valid bits stored in `contents`.
    bit_count: u32,
}

/// Errors reported by `BitBuf` read and write operations.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum BitBufError {
    /// The requested bit range ends past the valid bits of the buffer, or
    /// past what a `u32` bit index can address.
    OutOfRoom,
    /// More than 32 bits were requested in a single operation.
    BadNumberOfBits,
}

/// Result type returned by fallible `BitBuf` operations.
pub type BitBufResult<T> = Result<T, BitBufError>;

impl BitBuf {
    /// Number of bytes needed to store `bits` bits (rounded up).
    fn byte_len(bits: u32) -> usize {
        // Computed in u32 without `bits + 7` so it cannot overflow.
        (bits / 8 + (bits % 8 != 0) as u32) as usize
    }

    /// Wraps existing bytes, of which the first `bit_count` bits are valid.
    /// The cursor starts at bit 0.
    ///
    /// # Panics
    ///
    /// Panics if `contents` is too short to hold `bit_count` bits.
    pub fn new(contents: Vec<u8>, bit_count: u32) -> BitBuf {
        // Round up: a trailing partial byte still needs backing storage.
        assert!(Self::byte_len(bit_count) <= contents.len());
        BitBuf {
            contents,
            current_bit: 0,
            bit_count,
        }
    }

    /// Creates a buffer with no bits and the cursor at bit 0.
    pub fn empty() -> BitBuf {
        BitBuf {
            contents: Vec::new(),
            current_bit: 0,
            bit_count: 0,
        }
    }

    /// Returns the cursor position, in bits.
    pub fn tell(&self) -> u32 {
        self.current_bit
    }

    /// Moves the cursor to bit `pos`. No bounds check is performed here;
    /// a later read past the valid bits fails with `OutOfRoom`, and a later
    /// write past the end grows the buffer with zero bits.
    pub fn seek(&mut self, pos: u32) {
        self.current_bit = pos;
    }

    /// Reads `bit_count` bits at the cursor as an unsigned value and advances
    /// the cursor past them.
    ///
    /// # Errors
    ///
    /// * `BadNumberOfBits` if `bit_count > 32`.
    /// * `OutOfRoom` if the range ends past the valid bits; the cursor is
    ///   left unchanged in that case.
    pub fn read_bits_as_u32(&mut self, bit_count: u32) -> BitBufResult<u32> {
        if bit_count > 32 {
            return Err(BitBufError::BadNumberOfBits);
        }
        let endbit = match self.current_bit.checked_add(bit_count) {
            Some(end) if end <= self.bit_count => end,
            _ => return Err(BitBufError::OutOfRoom),
        };

        let mut val = 0u32;
        let mut bits_read = 0;
        // Consume one byte-sized (or smaller) chunk per iteration. The loop
        // never touches a byte it takes no bits from, and `bits_read` stays
        // below 32 wherever it is used as a shift amount.
        while bits_read < bit_count {
            let pos = self.current_bit + bits_read;
            let offset = pos % 8;
            let take = min(8 - offset, bit_count - bits_read);
            let byte = self.contents[(pos / 8) as usize] as u32;
            let chunk = (byte & make_bitmask(offset, offset + take)) >> offset;
            val |= chunk << bits_read;
            bits_read += take;
        }
        self.current_bit = endbit;
        Ok(val)
    }

    /// Writes the low `bit_count` bits of `val` at the cursor and advances
    /// the cursor past them. The buffer grows (zero-filled) as needed; bits
    /// outside the written range are left untouched.
    ///
    /// # Errors
    ///
    /// * `BadNumberOfBits` if `bit_count > 32`.
    /// * `OutOfRoom` if the end position does not fit in a `u32`.
    pub fn write_u32_as_bits(&mut self, val: u32, bit_count: u32) -> BitBufResult<()> {
        if bit_count > 32 {
            return Err(BitBufError::BadNumberOfBits);
        }
        let endbit = match self.current_bit.checked_add(bit_count) {
            Some(end) => end,
            None => return Err(BitBufError::OutOfRoom),
        };

        // Grow only when the write actually extends past the current storage;
        // overwriting earlier data after a `seek` must not shrink or underflow.
        let needed = Self::byte_len(endbit);
        if needed > self.contents.len() {
            self.contents.resize(needed, 0);
        }
        if endbit > self.bit_count {
            self.bit_count = endbit;
        }

        let mut bits_written = 0;
        while bits_written < bit_count {
            let pos = self.current_bit + bits_written;
            let offset = pos % 8;
            let take = min(8 - offset, bit_count - bits_written);
            let mask = make_bitmask(offset, offset + take);
            let chunk = ((val >> bits_written) << offset) & mask;
            // Read-modify-write so neighbouring bits in the same byte survive.
            let slot = &mut self.contents[(pos / 8) as usize];
            *slot = (*slot & !(mask as u8)) | chunk as u8;
            bits_written += take;
        }
        self.current_bit = endbit;
        Ok(())
    }

    /// Reads `bit_count` bits at the cursor as a two's-complement signed
    /// value (sign-extended from bit `bit_count - 1`). Reading zero bits
    /// yields `0`.
    ///
    /// # Errors
    ///
    /// Same as `read_bits_as_u32`.
    pub fn read_bits_as_i32(&mut self, bit_count: u32) -> BitBufResult<i32> {
        let raw = self.read_bits_as_u32(bit_count)?;
        if bit_count == 0 {
            return Ok(0);
        }
        // Move the field's sign bit to bit 31, then arithmetic-shift back so
        // the sign is replicated into the upper bits.
        let shift = 32 - bit_count;
        Ok(((raw << shift) as i32) >> shift)
    }

    /// Writes the low `bit_count` bits of the two's-complement representation
    /// of `val` at the cursor.
    ///
    /// # Errors
    ///
    /// Same as `write_u32_as_bits`.
    pub fn write_i32_as_bits(&mut self, val: i32, bit_count: u32) -> BitBufResult<()> {
        // FIXME: there is no range check; a value that does not fit in
        // `bit_count` bits is silently truncated to its low bits.
        self.write_u32_as_bits(val as u32, bit_count)
    }

    /// Returns every byte that holds at least one valid bit, including a
    /// trailing partially-filled byte.
    pub fn bytes(&self) -> &[u8] {
        &self.contents[..Self::byte_len(self.bit_count)]
    }
}

/// Builds a mask with bits `startbit..endbit` set (LSB = bit 0).
/// Bounds of 32 or more are treated as "past the top bit" instead of
/// overflowing the shift.
fn make_bitmask(startbit: u32, endbit: u32) -> u32 {
    let all: u32 = 0xFFFF_FFFF;
    let from_start = all.checked_shl(startbit).unwrap_or(0); // clear LSBs below startbit
    let from_end = all.checked_shl(endbit).unwrap_or(0);
    from_start & !from_end // clear MSBs at and above endbit
}
// Two back-to-back unaligned writes (19 + 12 bits) must pack into four bytes.
#[test]
fn smoke_write() {
    let mut writer = BitBuf::empty();

    let first = writer.write_u32_as_bits(89, 19);
    first.unwrap();
    let second = writer.write_u32_as_bits(42, 12);
    second.unwrap();

    let packed = writer.bytes();
    assert_eq!(packed, [0x59, 0x00, 0x50, 0x01]);
}
// Reading 19 then 12 bits from a known byte pattern must yield 89 and 42.
#[test]
fn smoke_read() {
    let mut reader = BitBuf::new(vec![0x59, 0x00, 0x50, 0x01], 31);

    let first = reader.read_bits_as_u32(19);
    assert_eq!(first, Ok(89));

    let second = reader.read_bits_as_u32(12);
    assert_eq!(second, Ok(42));
}
// Values of mixed widths and signedness must read back exactly as written.
#[test]
fn roundtrip() {
    let mut stream = BitBuf::empty();

    // Write phase: unsigned, single bit, signed, unsigned.
    stream.write_u32_as_bits(532, 23).unwrap();
    stream.write_u32_as_bits(1, 1).unwrap();
    stream.write_i32_as_bits(-98, 8).unwrap();
    stream.write_u32_as_bits(9, 6).unwrap();

    // Rewind and read the same fields back in the same order.
    stream.seek(0);

    let wide = stream.read_bits_as_u32(23);
    assert_eq!(wide, Ok(532));

    let flag = stream.read_bits_as_u32(1);
    assert_eq!(flag, Ok(1));

    let signed = stream.read_bits_as_i32(8);
    assert_eq!(signed, Ok(-98));

    let tail = stream.read_bits_as_u32(6);
    assert_eq!(tail, Ok(9));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment