Skip to content

Instantly share code, notes, and snippets.

@cmyr
Created September 11, 2018 23:17
Show Gist options
  • Save cmyr/8b05bcd024934d1ea37c67ac64252e0a to your computer and use it in GitHub Desktop.
Save cmyr/8b05bcd024934d1ea37c67ac64252e0a to your computer and use it in GitHub Desktop.
Using SIMD in Rust for suffix/prefix finding
/// Control byte passed to `_mm_cmpestrm` by `fast_cmpestr_mask`.
///
/// Together these flags ask for a per-byte "not equal" mask of two blocks.
const SSID_OPTS: i32 = _SIDD_UBYTE_OPS // elements are unsigned 8-bit values
    | _SIDD_CMP_EQUAL_EACH // compare element i of one operand with element i of the other
    | _SIDD_NEGATIVE_POLARITY // invert the result, so a set bit means "differs"
    | _SIDD_UNIT_MASK // expand each result bit to a whole byte so `movemask` can collect it
    | _SIDD_LEAST_SIGNIFICANT; // zero-valued flag; only relevant to the index-returning variants
/// Compares two 16-byte blocks and returns a bitmask of the positions that differ.
///
/// Bit `i` of the result (counting from the least significant bit) is set
/// exactly when `one[i] != two[i]`. Only the low 16 bits are ever set, so a
/// return value of `0` means the two blocks are identical.
///
/// # Panics
///
/// Panics if either slice is not exactly 16 bytes long.
///
/// # Examples
///
/// ```
/// let one = "aaaaaaaaaaaaaaaa";
/// let two = "aa3aaaaa9aaaEaaa";
/// // NOTE: reversed -- bit 0 (byte 0) is the last character printed.
/// let exp = "0001000100000100";
/// let result = format!("{:016b}", fast_cmpestr_mask(one.as_bytes(), two.as_bytes()));
/// assert_eq!(result.as_str(), exp);
/// ```
///
#[inline(always)]
fn fast_cmpestr_mask(one: &[u8], two: &[u8]) -> i32 {
    // These must be hard asserts: the loads below read 16 bytes regardless of
    // the slice length, so a shorter slice would be an out-of-bounds read in
    // release builds if this were only a `debug_assert`.
    assert_eq!(one.len(), 16);
    assert_eq!(two.len(), 16);
    // SAFETY: both slices were just checked to be exactly 16 bytes long, so
    // each unaligned 128-bit load stays inside its slice.
    // NOTE(review): `_mm_cmpestrm` needs SSE4.2 and nothing here checks for it;
    // presumably the crate is only built/run where it is available -- confirm.
    unsafe {
        let onev = _mm_loadu_si128(one.as_ptr() as *const _);
        let twov = _mm_loadu_si128(two.as_ptr() as *const _);
        // Yields 0xFF in every byte lane where the inputs differ (see SSID_OPTS).
        let mask = _mm_cmpestrm(onev, 16, twov, 16, SSID_OPTS);
        // Collapse the byte lanes to one bit each.
        _mm_movemask_epi8(mask)
    }
}
/// Number of bytes handed to `fast_cmpestr_mask` per step; `ne_idx` and
/// `ne_idx_rev` advance through their inputs by this amount.
const SSE_STRIDE: usize = 16;
/// Finds the first position at which `one` and `two` differ.
///
/// Only the first `min(one.len(), two.len())` bytes are compared. Returns the
/// smallest `i` in that range with `one[i] != two[i]`, or `None` if there is
/// no such position.
fn ne_idx(one: &[u8], two: &[u8]) -> Option<usize> {
    let common = one.len().min(two.len());
    let mut start = 0;
    while start < common {
        let end = start + SSE_STRIDE;
        let diff = if end < common {
            fast_cmpestr_mask(&one[start..end], &two[start..end])
        } else {
            // Last chunk: copy what is left into zero-filled scratch blocks.
            // The padding is identical on both sides, so it never shows up
            // as a difference.
            let tail = common - start;
            let mut lhs = [0u8; 16];
            let mut rhs = [0u8; 16];
            lhs[..tail].copy_from_slice(&one[start..common]);
            rhs[..tail].copy_from_slice(&two[start..common]);
            fast_cmpestr_mask(&lhs, &rhs)
        };
        if diff != 0 {
            // The lowest set bit is the first differing byte of this chunk.
            return Some(start + diff.trailing_zeros() as usize);
        }
        start = end;
    }
    None
}
/// Finds the first mismatch when comparing `one` and `two` backwards from
/// their ends.
///
/// The slices are aligned at their ends and only the last
/// `min(one.len(), two.len())` bytes of each are compared. Returns the
/// smallest `i` such that `one[one.len() - 1 - i] != two[two.len() - 1 - i]`
/// (that is, the length of the common suffix), or `None` if no byte in the
/// compared range differs.
fn ne_idx_rev(one: &[u8], two: &[u8]) -> Option<usize> {
    let min_len = one.len().min(two.len());
    // Re-slice so both inputs end at the same index; without this, inputs of
    // different lengths would be compared start-aligned instead of end-aligned.
    let one = &one[one.len() - min_len..];
    let two = &two[two.len() - min_len..];

    // `end` is the exclusive upper bound of the chunk being examined.
    let mut end = min_len;
    while end > 0 {
        let mask = if end >= SSE_STRIDE {
            fast_cmpestr_mask(&one[end - SSE_STRIDE..end], &two[end - SSE_STRIDE..end])
        } else {
            // Fewer than a full block left: right-align the remaining bytes in
            // zero-filled scratch blocks. The leading padding is identical on
            // both sides, so it never registers as a difference.
            let pad = SSE_STRIDE - end;
            let mut one_buf = [0u8; SSE_STRIDE];
            let mut two_buf = [0u8; SSE_STRIDE];
            one_buf[pad..].copy_from_slice(&one[..end]);
            two_buf[pad..].copy_from_slice(&two[..end]);
            fast_cmpestr_mask(&one_buf, &two_buf)
        };
        // Only the low 16 bits of the mask carry data, so the narrowing cast
        // is lossless; leading zeros then count the matching bytes at the end
        // of this chunk.
        let matched = (mask as u16).leading_zeros() as usize;
        if matched != SSE_STRIDE {
            return Some(min_len - end + matched);
        }
        end = end.saturating_sub(SSE_STRIDE);
    }
    None
}
/// Scalar ("software") reference for finding where two strings first differ.
///
/// Walks both strings byte by byte from the front and returns the index of
/// the first position whose bytes differ. Only the first
/// `min(one.len(), two.len())` bytes are examined; returns `None` when they
/// are all equal.
fn str_match_end_sw(one: &str, two: &str) -> Option<usize> {
    one.bytes()
        .zip(two.bytes())
        .position(|(lhs, rhs)| lhs != rhs)
}
/// Scalar ("software") reference for finding where two strings first differ
/// when compared backwards from their ends.
///
/// The strings are aligned at their ends and only the last
/// `min(one.len(), two.len())` bytes of each are compared. Returns the
/// smallest `i` such that the bytes `i` positions before the last byte of each
/// string differ (that is, the length of the common suffix), or `None` when no
/// compared byte differs -- including when either string is empty.
fn str_rev_match_end_sw(one: &str, two: &str) -> Option<usize> {
    // Reversed iterators keep the index arithmetic out of the picture: there
    // is no `len - 1` to underflow on empty input and no unsigned counter to
    // decrement past zero when the strings match all the way through.
    one.bytes()
        .rev()
        .zip(two.bytes().rev())
        .position(|(lhs, rhs)| lhs != rhs)
}
/*
test bench_hw_200k ... bench: 35,689 ns/iter (+/- 735)
test bench_hw_50k ... bench: 8,935 ns/iter (+/- 192)
test bench_hw_rev_200k ... bench: 47,544 ns/iter (+/- 976)
test bench_hw_rev_50k ... bench: 38,113 ns/iter (+/- 1,023)
test bench_hw_small ... bench: 16 ns/iter (+/- 0)
test bench_sw_200k ... bench: 95,202 ns/iter (+/- 2,000)
test bench_sw_50k ... bench: 23,766 ns/iter (+/- 927)
test bench_sw_rev_200k ... bench: 117,995 ns/iter (+/- 2,955)
test bench_sw_rev_50k ... bench: 94,450 ns/iter (+/- 2,173)
test bench_sw_small ... bench: 5 ns/iter (+/- 1)
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment