Skip to content

Instantly share code, notes, and snippets.

@Trass3r
Created March 6, 2018 20:14
Show Gist options
  • Save Trass3r/b1e16c8814fba14d5226abd842a05b83 to your computer and use it in GitHub Desktop.
Save Trass3r/b1e16c8814fba14d5226abd842a05b83 to your computer and use it in GitHub Desktop.
efficient bitstream reader
#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#if defined(__AVX2__) || defined(__BMI2__)
#include <immintrin.h>
#endif
//! Modular bitstream reader (LSB-first bit order).
//!
//! Bits are taken starting at the least significant bit of each byte, and
//! bytes are taken in increasing address order. A word-sized bit buffer is
//! topped up by refill(); peek()/peeki()/advance() and the getBits*()
//! helpers operate only on bits that are already buffered.
//!
//! Usage: call refill(), then read at most `lowerBound` bits before calling
//! refill() again. The reader does not own the input; the memory passed to
//! the constructor must outlive the reader.
//!
//! NOTE: the fast refill path loads a native size_t. Its bit order matches
//! the byte-wise tail path only on little-endian hosts.
struct BitReader final
{
	// renormalization bounds: after refill() the buffer holds between
	// lowerBound and upperBound bits (fewer only once the input is exhausted)
	static constexpr uint32_t lowerBound = (sizeof(size_t) - 1) * 8;
	static constexpr uint32_t upperBound = sizeof(size_t) * 8 - 1;

	//! @param ptr    first input byte (not owned; may be null if length == 0)
	//! @param length number of readable bytes starting at ptr
	BitReader(const void* ptr, size_t length)
		: ptr(static_cast<const uint8_t*>(ptr))
		, end(static_cast<const uint8_t*>(ptr) + length)
	{}

	BitReader(const BitReader&) = delete;
	BitReader& operator=(const BitReader&) = delete;

	//! Load new bits from the stream to renormalize the buffer to at least
	//! lowerBound bits (or to whatever remains of the input).
	void refill()
	{
		// Tail path: fewer than one word of input remains, so read bytewise.
		// Compare the remaining byte count rather than computing
		// `end - sizeof(size_t)`, which is UB for inputs shorter than a word.
		// (This branch could be removed if the input array were padded.)
		if (static_cast<size_t>(end - ptr) < sizeof(size_t))
		{
			while (ptr < end && numBits < lowerBound)
			{
				buffer |= static_cast<size_t>(*ptr) << numBits;
				++ptr;
				numBits += 8;
			}
			return;
		}

		// Fast path: a single unaligned word load. memcpy avoids the
		// alignment and strict-aliasing UB of dereferencing a casted pointer
		// and still compiles to one load on mainstream compilers.
		// N.B.: the load address does not depend on numBits.
		size_t data;
		std::memcpy(&data, ptr, sizeof data);
		// Bits above numBits may already hold upcoming stream bits from an
		// earlier load; OR-ing the same bits in again is harmless.
		buffer |= data << numBits;
		// advance by the whole bytes that fit above the buffered bits
		ptr += (upperBound - numBits) / 8;
		// numBits <= upperBound, so this equals numBits += 8 * bytesAdvanced
		numBits |= lowerBound;
		assert(numBits >= lowerBound && numBits <= upperBound);
	}

	//! Get a signed (two's complement) integer from the next n bits without
	//! consuming them. Returns 0 for n == 0.
	intptr_t peeki(uint32_t n) const
	{
		assert(n <= numBits && n <= lowerBound);
		// Sign-extend by xor/subtract with the field's sign bit. Unlike the
		// shift-left/arithmetic-shift-right idiom this has no UB for negative
		// values or for n == 0 (sign is 0 then). The final size_t -> intptr_t
		// conversion is modular (guaranteed since C++20).
		const size_t value = peek(n);
		const size_t sign = (size_t(1) << n) >> 1;
		return static_cast<intptr_t>((value ^ sign) - sign);
	}

	//! Get an unsigned integer from the next n bits without consuming them.
	size_t peek(uint32_t n) const
	{
		assert(n <= lowerBound && n <= numBits);
		// BZHI is a BMI2 instruction: GCC/Clang reject _bzhi_* under -mavx2
		// alone. MSVC does not define __BMI2__, so use __AVX2__ there.
#if defined(__BMI2__) || (defined(_MSC_VER) && !defined(__clang__) && defined(__AVX2__))
#if defined(__x86_64__) || defined(_M_X64)
		return _bzhi_u64(buffer, n);
#else
		return _bzhi_u32(buffer, n);
#endif
#else
		return buffer & ~(~size_t(0) << n);
#endif
	}

	//! Shift out (discard) the next n buffered bits.
	void advance(uint32_t n)
	{
		assert(n <= numBits);
		buffer >>= n;
		numBits -= n;
	}

	//! Combines peek() and advance() for convenience.
	size_t getBits(uint32_t n)
	{
		const size_t r = peek(n);
		advance(n);
		return r;
	}

	//! Combines peeki() and advance() for convenience.
	intptr_t getBitsi(uint32_t n)
	{
		const intptr_t r = peeki(n);
		advance(n);
		return r;
	}

private:
	const uint8_t* ptr = nullptr; // next input byte not yet loaded into buffer
	const uint8_t* end = nullptr; // one past the last input byte
	size_t buffer = 0;            // buffered bits; next bit is the LSB
	uint32_t numBits = 0;         // number of valid bits at the LSB end of buffer
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment