use core::cmp;
use core::usize;

#[cfg(target_pointer_width = "32")]
const USIZE_BYTES: usize = 4;
#[cfg(target_pointer_width = "64")]
const USIZE_BYTES: usize = 8;

// The number of bytes to process in one iteration of byte_count's main loop.
const LOOP_SIZE: usize = 1 * USIZE_BYTES;

const LO_U64: u64 = 0x0101010101010101;
const HI_U64: u64 = 0x8080808080808080;
const LO_USIZE: usize = LO_U64 as usize;
const HI_USIZE: usize = HI_U64 as usize;

// Copied from byteset/scalar.rs (see it for documentation).
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    (b as usize) * (usize::MAX / 255)
}
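
// For example, on a 64-bit target `usize::MAX / 255` is
// 0x0101010101010101, so `repeat_byte(0x61)` yields 0x6161616161616161:
// the multiplication broadcasts `b` into every byte lane of the word.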

// Each byte in the result has the value 0x01 if the corresponding byte in
// `x` was zero, and 0x00 otherwise.
//
// The idea is similar to the `contains_zero_byte` function in memchr's
// fallback code:
// https://github.com/BurntSushi/rust-memchr/blob/b5c5cfe37207b00494597ac14a23e826fe8e59b1/src/fallback.rs#L26
// However, that `x - LO` formulation only needs a boolean answer, and its
// subtraction can borrow across byte lanes, which would also flag a 0x01
// lane sitting directly above a zero lane. Since we *sum* these flags, we
// need an exact per-lane test, so we use a carry-safe formulation instead:
// within each lane, `(x & 0x7F..) + 0x7F..` carries into bit 7 iff the low
// seven bits are nonzero, and OR-ing `x` back in also covers lanes whose
// own bit 7 was set. Bit 7 of each lane thus ends up set iff the lane was
// nonzero, and the addition cannot carry out of a lane (0x7F + 0x7F < 0x100).
#[inline(always)]
fn flag_zero_bytes(x: usize) -> usize {
    !((((x & !HI_USIZE) + !HI_USIZE) | x) >> 7) & LO_USIZE
}
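
// For example, on 64-bit, `flag_zero_bytes(0x0061_0061_0061_0061)` returns
// 0x0100_0100_0100_0100: a 0x01 flag in each of the four zero lanes and
// nothing in the `0x61` lanes.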

pub(crate) fn byte_count(n1: u8, haystack: &[u8]) -> usize {
    let vn1 = repeat_byte(n1);
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();
    let end_ptr = haystack[haystack.len()..].as_ptr();
    let mut ptr = start_ptr;
    let mut count = 0;
    unsafe {
        if haystack.len() < USIZE_BYTES {
            return count_individual(start_ptr, end_ptr, ptr, n1);
        }
        // If needed, realign, counting any relevant bytes we skip over when
        // doing so.
        if (start_ptr as usize & align) != 0 {
            ptr = start_ptr.add(USIZE_BYTES - (start_ptr as usize & align));
            debug_assert!(ptr > start_ptr);
            debug_assert!(ptr <= end_ptr);
            debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
            let chunk: usize = (start_ptr as *const usize).read_unaligned();
            if flag_zero_bytes(chunk ^ vn1) != 0 {
                // Note: there's a missed optimization here: we could reuse
                // the return value of flag_zero_bytes, so long as we only
                // counted the flags for the lanes between `start_ptr` and
                // `ptr`.
                count += count_individual(start_ptr, ptr, start_ptr, n1);
            }
        }
        debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            // Essentially this is an array of `[u8; USIZE_BYTES]` which
            // stores that many separate counts. Each lane gains at most 1
            // per word processed, so to avoid overflowing any lane into
            // its neighbor, every so often (when the MSB of any of the
            // bytes in this value is set, i.e. after at most 128 words) we
            // accumulate what we have so far and reset it.
            let mut wide_accum = 0;
            while (wide_accum & HI_USIZE) == 0 && ptr <= end_ptr.sub(loop_size) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
                debug_assert!(ptr <= end_ptr);
                let word1 = *(ptr as *const usize);
                // Disabled 2x unroll; enabling it also requires
                // LOOP_SIZE = 2 * USIZE_BYTES above.
                // let word2 = *(ptr.add(USIZE_BYTES) as *const usize);
                wide_accum += flag_zero_bytes(word1 ^ vn1);
                // wide_accum += flag_zero_bytes(word2 ^ vn1);
                ptr = ptr.add(LOOP_SIZE);
            }
            count += horizontal_add(wide_accum);
        }
        count_individual(start_ptr, end_ptr, ptr, n1) + count
    }
}

#[inline]
fn horizontal_add(wide_counts: usize) -> usize {
    let p = &wide_counts as *const usize as *const u8;
    let mut c = 0;
    unsafe {
        for i in 0..USIZE_BYTES {
            c += *p.add(i) as usize;
        }
    }
    c
}
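
// Note: a safe spelling of the same reduction (Rust 1.32+) would be
// `wide_counts.to_ne_bytes().iter().map(|&b| b as usize).sum::<usize>()`;
// the raw-pointer walk above is kept to match the style of the rest of
// this file, and both should optimize similarly. (The usual
// `wrapping_mul(LO_USIZE) >> (8 * (USIZE_BYTES - 1))` horizontal-sum trick
// is *not* usable here: the lane values can sum past 255.)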

// Count occurrences of `byte` one at a time, from `ptr` (inclusive) up to
// `end_ptr` (exclusive). `start_ptr` is only used for the sanity checks.
// Equivalent to `slice.iter().filter(|&&b| b == byte).count()` over that
// region.
#[inline]
unsafe fn count_individual(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    byte: u8,
) -> usize {
    debug_assert!(ptr >= start_ptr);
    debug_assert!(ptr <= end_ptr);
    let mut c = 0;
    while ptr != end_ptr {
        c += (*ptr == byte) as usize;
        ptr = ptr.add(1);
    }
    c
}

#[cfg(test)]
mod tests {
    use super::byte_count;

    #[test]
    fn test_bytecount_sizes_and_aligns() {
        // Check many sizes and alignments of slices going into byte_count.
        let r = b"aaaazzza".iter().copied().cycle().take(515).collect::<Vec<_>>();
        for i in 0..r.len() {
            for j in i..r.len() {
                let s = &r[i..j];
                let na = s.iter().filter(|&&b| b == b'a').count();
                let nz = s.len() - na;
                assert_eq!(byte_count(b'a', s), na);
                assert_eq!(byte_count(b'z', s), nz);
                assert_eq!(byte_count(b'x', s), 0);
            }
        }
    }
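
    // Directly pin down the per-lane contract `byte_count` relies on:
    // `flag_zero_bytes` must flag exactly the zero lanes, with no false
    // positive in the lane directly above a zero byte (the case a
    // borrow-based formulation gets wrong).
    #[test]
    fn test_flag_zero_bytes_exact() {
        use super::{flag_zero_bytes, USIZE_BYTES};
        // Exhaustively check every zero/nonzero lane pattern of one word,
        // using 0x01 as the nonzero value since it is the value most easily
        // corrupted by a borrow out of the lane below it.
        for pat in 0u32..(1u32 << USIZE_BYTES) {
            let mut bytes = [0u8; USIZE_BYTES];
            for (i, b) in bytes.iter_mut().enumerate() {
                *b = ((pat >> i) & 1) as u8;
            }
            let word = usize::from_ne_bytes(bytes);
            let flags = flag_zero_bytes(word);
            for (i, &b) in bytes.iter().enumerate() {
                assert_eq!((flags >> (8 * i)) & 0xFF, (b == 0) as usize);
            }
        }
    }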

    #[test]
    fn test_bytecount_no_wrong_answers() {
        for byte in 0u8..=255u8 {
            // should be big enough that if overflow were an issue we'd hit it
            const SIZE: usize = 8192;
            let fill = [byte; SIZE];
            assert_eq!(byte_count(byte, &fill), SIZE);
            for byte2 in 0u8..=255u8 {
                if byte2 != byte {
                    assert_eq!(byte_count(byte2, &fill), 0);
                }
            }
        }
    }
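
    // A small literal example, checked against the naive scalar count.
    #[test]
    fn test_bytecount_example() {
        let haystack = b"the quick brown fox";
        assert_eq!(byte_count(b'o', haystack), 2);
        assert_eq!(
            byte_count(b'o', haystack),
            haystack.iter().filter(|&&b| b == b'o').count()
        );
    }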

    #[test]
    fn test_bytecount_no_overflow() {
        // Due to the algorithm used, we want to be sure that if there are
        // going to be problems due to overflowing our accumulators, we'd
        // catch it. Exhaustively test the patterns that could occupy an
        // (8 byte) usize to check this in all cases.
        let pats = (0..256).flat_map(|pat| {
            (0..8).map(move |bit| {
                let mask = 1 << bit;
                if (pat & mask) == 0 { b'0' } else { b'1' }
            })
        }).collect::<Vec<_>>();
        let mut tests = vec![];
        for i in 0..=256 {
            tests.push(
                std::iter::repeat(b'0')
                    .take(i)
                    .chain(pats.iter().cloned())
                    .collect::<Vec<_>>());
            tests.push(
                pats.iter()
                    .cloned()
                    .chain(std::iter::repeat(b'0').take(i))
                    .collect::<Vec<_>>());
            tests.push(
                std::iter::repeat(b'0')
                    .take(i)
                    .chain(pats.iter().cloned())
                    .chain(std::iter::repeat(b'0').take(i))
                    .collect::<Vec<_>>());
        }
        let num_ones = pats.iter().filter(|&&b| b == b'1').count();
        for test in &tests {
            let num_zeros = test.len() - num_ones;
            let bc1s = byte_count(b'1', test);
            let bc0s = byte_count(b'0', test);
            assert_eq!(num_ones, bc1s);
            assert_eq!(num_zeros, bc0s);
            for offset in 0..130 {
                let (pre, post) = test.split_at(offset);
                let pre_ones = pre.iter().filter(|&&b| b == b'1').count();
                let pre_zeros = pre.len() - pre_ones;
                let post_ones = num_ones - pre_ones;
                let post_zeros = num_zeros - pre_zeros;
                let bc_pre1s = byte_count(b'1', pre);
                let bc_pre0s = byte_count(b'0', pre);
                assert_eq!(pre_ones, bc_pre1s);
                assert_eq!(pre_zeros, bc_pre0s);
                let bc_post1s = byte_count(b'1', post);
                let bc_post0s = byte_count(b'0', post);
                assert_eq!(post_ones, bc_post1s);
                assert_eq!(post_zeros, bc_post0s);
            }
        }
    }
}