Skip to content

Instantly share code, notes, and snippets.

@thomcc
Created November 9, 2019 17:53
Show Gist options
  • Save thomcc/ab4c9d4509e58912feaf0c538a071b81 to your computer and use it in GitHub Desktop.
Save thomcc/ab4c9d4509e58912feaf0c538a071b81 to your computer and use it in GitHub Desktop.
use core::cmp;
use core::usize;
// Size in bytes of a `usize` on the current target, selected at compile time.
#[cfg(target_pointer_width = "32")]
const USIZE_BYTES: usize = 4;
#[cfg(target_pointer_width = "64")]
const USIZE_BYTES: usize = 8;
// The number of bytes processed per iteration of the main word loop in
// `byte_count`. (The `1 *` factor makes the unroll count easy to tweak;
// the commented-out `word2` lines in `byte_count` show the 2x variant.)
const LOOP_SIZE: usize = 1 * USIZE_BYTES;
// 0x01 repeated in every byte, and 0x80 (the per-byte high bit) repeated
// in every byte, as 64-bit patterns.
const LO_U64: u64 = 0x0101010101010101;
const HI_U64: u64 = 0x8080808080808080;
// The same patterns at native word width. The `as usize` cast keeps the
// low 4 bytes on 32-bit targets, which preserves the repeating pattern.
const LO_USIZE: usize = LO_U64 as usize;
const HI_USIZE: usize = HI_U64 as usize;
// Copypasted from byteset/scalar.rs (see it for documentation).
//
// Broadcasts `b` into every byte lane of a `usize`, e.g. 0xAB becomes
// 0xABAB...AB.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    // `usize::MAX / 255` is 0x01 repeated in every byte, so the multiply
    // copies `b` into each lane. This cannot overflow: even for b == 255
    // the product is exactly `usize::MAX`.
    let one_in_every_byte = usize::MAX / 255;
    usize::from(b) * one_in_every_byte
}
// Each byte in the result will have a value of 0x01 if the corresponding byte in
// `x` was zero, and a value of 0 otherwise.
//
// The idea is the same as the `contains_zero_byte` function in memchr's
// fallback code:
// https://github.com/BurntSushi/rust-memchr/blob/b5c5cfe37207b00494597ac14a23e826fe8e59b1/src/fallback.rs#L26
#[inline(always)]
fn flag_zero_bytes(x: usize) -> usize {
    // 0x01 repeated in every byte, and 0x80 repeated in every byte.
    // These are value-identical to LO_USIZE / HI_USIZE, but computed
    // portably for whatever the native word width is.
    let lo = usize::MAX / 255;
    let hi = lo << 7;
    // A zero byte produces a borrow when `lo` is subtracted, setting its
    // high bit; `& !x` then discards bytes whose high bit was already set
    // in `x` (i.e. bytes that were not zero). Shifting the surviving 0x80
    // flags down by 7 turns them into 0x01 per originally-zero byte.
    let borrow_flags = x.wrapping_sub(lo) & !x & hi;
    borrow_flags >> 7
}
/// Counts the occurrences of the byte `n1` in `haystack`.
///
/// Strategy: handle short slices byte-by-byte; otherwise align `ptr` to a
/// word boundary (counting any skipped prefix separately), then scan one
/// machine word at a time. `flag_zero_bytes(word ^ vn1)` yields a word
/// holding a 0/1 match flag in each byte lane; these are accumulated as
/// parallel byte-wide counters and summed with `horizontal_add` before
/// any lane can overflow. The sub-word tail is counted byte-by-byte.
pub(crate) fn byte_count(n1: u8, haystack: &[u8]) -> usize {
    // `n1` splatted into every byte, so `word ^ vn1` zeroes matching bytes.
    let vn1 = repeat_byte(n1);
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();
    // One past the last byte of the haystack.
    let end_ptr = haystack[haystack.len()..].as_ptr();
    let mut ptr = start_ptr;
    let mut count = 0;
    unsafe {
        // Too short for even a single word: pure scalar fallback.
        if haystack.len() < USIZE_BYTES {
            return count_individual(start_ptr, end_ptr, ptr, n1);
        }
        // If needed, realign, recording any relevant bytes we skip over when
        // doing so.
        if (start_ptr as usize & align) != 0 {
            // Advance `ptr` to the next word boundary (1..USIZE_BYTES-1 bytes).
            ptr = start_ptr.add(USIZE_BYTES - (start_ptr as usize & align));
            debug_assert!(ptr > start_ptr);
            debug_assert!(ptr <= end_ptr);
            // In bounds because haystack.len() >= USIZE_BYTES here.
            debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
            // Read one full (unaligned) word from the very start. It extends
            // past `ptr`, so a hit may be a false positive for the prefix;
            // `count_individual` below only counts matches in
            // [start_ptr, ptr), which is also why a miss is conclusive.
            let chunk: usize = (start_ptr as *const usize).read_unaligned();
            if flag_zero_bytes(chunk ^ vn1) != 0 {
                // Note: there's a missed optimization here: we could use the return
                // value of flag_zero_bytes here, so long as we only use the parts
                // between `start_ptr` and `ptr`.
                count += count_individual(start_ptr, ptr, start_ptr, n1);
            }
        }
        debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
        // `loop_size == LOOP_SIZE` always holds here (len >= USIZE_BYTES);
        // the guard mirrors the memchr code this was derived from.
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            // Essentially this is an array of `[u8; USIZE_BYTES]` which stores
            // that many separate counts. To avoid overflowing any of these,
            // every so often (when the MSB of any of the bytes in this value is
            // set) we accumulate what we have so far and reset it.
            let mut wide_accum = 0;
            while (wide_accum & HI_USIZE) == 0 && ptr <= end_ptr.sub(loop_size) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
                debug_assert!(ptr <= end_ptr);
                // Aligned in-bounds read: `ptr + LOOP_SIZE <= end_ptr` by the
                // loop condition, and alignment is asserted above.
                let word1 = *(ptr as *const usize);
                // let word2 = *(ptr.add(USIZE_BYTES) as *const usize);
                // Adds 0 or 1 to each byte lane; a lane stops the inner loop
                // at 0x80, so no lane can carry into its neighbour.
                wide_accum += flag_zero_bytes(word1 ^ vn1);
                // wide_accum += flag_zero_bytes(word2 ^ vn1);
                ptr = ptr.add(LOOP_SIZE);
            }
            // Fold the per-lane counters into `count` and restart with a
            // fresh accumulator.
            count += horizontal_add(wide_accum);
        }
        // Fewer than LOOP_SIZE bytes remain; finish byte-by-byte.
        count_individual(start_ptr, end_ptr, ptr, n1) + count
    }
}
/// Sums the byte-wide counters packed into `wide_counts`.
///
/// Treats the word as `[u8; size_of::<usize>()]` and adds up the bytes.
/// Byte order does not matter since addition is commutative, so the
/// native-endian view is fine. Using `to_ne_bytes` replaces the previous
/// raw-pointer walk over the bytes, removing the `unsafe` block with no
/// behavior change (`to_ne_bytes` is available since Rust 1.32).
#[inline]
fn horizontal_add(wide_counts: usize) -> usize {
    wide_counts
        .to_ne_bytes()
        .iter()
        .map(|&b| usize::from(b))
        .sum()
}
/// Scalar fallback: counts occurrences of `byte` in the half-open range
/// `[ptr, end_ptr)`, one byte at a time.
///
/// # Safety
///
/// `start_ptr <= ptr <= end_ptr` must hold, all three pointers must point
/// into (or one past the end of) the same allocation, and every byte in
/// `[ptr, end_ptr)` must be readable.
#[inline]
unsafe fn count_individual(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    byte: u8,
) -> usize {
    debug_assert!(ptr >= start_ptr);
    debug_assert!(ptr <= end_ptr);
    let mut matches = 0;
    while ptr != end_ptr {
        if *ptr == byte {
            matches += 1;
        }
        ptr = ptr.add(1);
    }
    matches
}
#[cfg(test)]
mod tests {
    use super::byte_count;
    // NOTE(review): this alias is never used; consider removing it.
    type TestCase = (Vec<u8>, usize);
    #[test]
    fn test_bytecount_sizes_and_aligns() {
        // Check many sizes and alignments of slices going into byte count.
        // Taking every sub-slice of `r` exercises every combination of
        // starting alignment and length, including the short (< one word)
        // path and the unaligned-prefix path.
        let r = b"aaaazzza".iter().copied().cycle().take(515).collect::<Vec<_>>();
        for i in 0..r.len() {
            for j in i..r.len() {
                let s = &r[i..j];
                // Expected counts computed with a trivially-correct scan.
                let na = s.iter().filter(|&&b| b == b'a').count();
                let nz = s.len() - na;
                assert_eq!(byte_count(b'a', s), na);
                assert_eq!(byte_count(b'z', s), nz);
                // A byte that never occurs must count as zero.
                assert_eq!(byte_count(b'x', s), 0);
            }
        }
    }
    #[test]
    fn test_bytecount_no_wrong_answers() {
        // Every byte value, both as a full-buffer match and as a guaranteed
        // non-match.
        for byte in 0u8..=255u8 {
            // should be big enough that if overflow were an issue we'd hit it
            const SIZE: usize = 8192;
            let fill = [byte; SIZE];
            assert_eq!(byte_count(byte, &fill), SIZE);
            for byte2 in 0u8..=255u8 {
                if byte2 != byte {
                    assert_eq!(byte_count(byte2, &fill), 0);
                }
            }
        }
    }
    #[test]
    fn test_bytecount_no_overflow() {
        // Due to the algorithm used, we want to be sure that if there are going
        // to be problems due to overflowing our accumulators, we'd catch it.
        // Exhaustively test the patterns that could occupy an (8 byte) usize to
        // check this in all cases.
        // `pats` is 2048 bytes: every 8-bit pattern rendered as eight
        // '0'/'1' bytes (one per bit).
        let pats = (0..256).flat_map(|pat| {
            (0..8).map(move |bit| {
                let mask = 1 << bit;
                if (pat & mask) == 0 { b'0' } else { b'1' }
            })
        }).collect::<Vec<_>>();
        // Pad the pattern buffer with 0..=256 '0' bytes before, after, and
        // on both sides, to vary alignment, prefix, and tail handling.
        let mut tests = vec![];
        for i in 0..=256 {
            tests.push(
                std::iter::repeat(b'0')
                    .take(i)
                    .chain(pats.iter().cloned())
                    .collect::<Vec<_>>());
            tests.push(
                pats.iter()
                    .cloned()
                    .chain(std::iter::repeat(b'0').take(i))
                    .collect::<Vec<_>>());
            tests.push(
                std::iter::repeat(b'0')
                    .take(i)
                    .chain(pats.iter().cloned())
                    .chain(std::iter::repeat(b'0').take(i))
                    .collect::<Vec<_>>());
        }
        // Padding only adds '0' bytes, so the '1' count is the same for
        // every test buffer.
        let num_ones = pats.iter().filter(|&&b| b == b'1').count();
        for test in &tests {
            let num_zeros = test.len() - num_ones;
            let bc1s = byte_count(b'1', &test);
            let bc0s = byte_count(b'0', &test);
            assert_eq!(num_ones, bc1s);
            assert_eq!(num_zeros, bc0s);
            // Also split each buffer at the first 130 offsets; counts on
            // the two halves must be consistent with the whole.
            for offset in 0..130 {
                let (pre, post) = test.split_at(offset);
                let pre_ones = pre.iter().filter(|&&b| b == b'1').count();
                let pre_zeros = pre.len() - pre_ones;
                let post_ones = num_ones - pre_ones;
                let post_zeros = num_zeros - pre_zeros;
                let bc_pre1s = byte_count(b'1', pre);
                let bc_pre0s = byte_count(b'0', pre);
                assert_eq!(pre_ones, bc_pre1s);
                assert_eq!(pre_zeros, bc_pre0s);
                let bc_post1s = byte_count(b'1', post);
                let bc_post0s = byte_count(b'0', post);
                assert_eq!(post_ones, bc_post1s);
                assert_eq!(post_zeros, bc_post0s);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment