Created
September 11, 2018 23:17
-
-
Save cmyr/8b05bcd024934d1ea37c67ac64252e0a to your computer and use it in GitHub Desktop.
Using simd in rust for suffix/prefix finding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Control word passed to `_mm_cmpestrm` by [`fast_cmpestr_mask`].
///
/// The flags select unsigned-byte operands compared position by position
/// (`_SIDD_CMP_EQUAL_EACH`), with the result negated (`_SIDD_NEGATIVE_POLARITY`) so that
/// set entries mark *mismatches*, and expanded to one full byte per position
/// (`_SIDD_UNIT_MASK`) so that `_mm_movemask_epi8` can collapse it to one bit per byte.
const SSID_OPTS: i32 = _SIDD_UBYTE_OPS
    | _SIDD_CMP_EQUAL_EACH
    | _SIDD_UNIT_MASK
    | _SIDD_LEAST_SIGNIFICANT
    | _SIDD_NEGATIVE_POLARITY;

/// Compares two 16-byte slices and returns a bitmask of the positions that differ.
///
/// Bit `i` of the result (counting from the least significant bit) is set when
/// `one[i] != two[i]`. Only the low 16 bits are ever set, and `0` means the two
/// slices are identical.
///
/// The CPU must support SSE4.2; this is only verified in debug builds.
///
/// # Panics
///
/// Panics if either slice is not exactly 16 bytes long.
///
/// # Examples
///
/// ```
/// let one = "aaaaaaaaaaaaaaaa";
/// let two = "aa3aaaaa9aaaEaaa";
/// // NOTE: reversed -- bit 0 (byte 0) is the rightmost digit.
/// let exp = "0001000100000100";
/// let result = format!("{:016b}", fast_cmpestr_mask(one.as_bytes(), two.as_bytes()));
/// assert_eq!(result.as_str(), exp);
/// ```
///
#[inline(always)]
fn fast_cmpestr_mask(one: &[u8], two: &[u8]) -> i32 {
    // These must be real asserts: the loads below read 16 bytes unconditionally, so a
    // shorter slice would be an out-of-bounds read in a release build.
    assert_eq!(one.len(), 16);
    assert_eq!(two.len(), 16);
    debug_assert!(is_x86_feature_detected!("sse4.2"));
    // SAFETY: both slices were just checked to hold exactly 16 readable bytes, and
    // `_mm_loadu_si128` has no alignment requirement. `_mm_cmpestrm` additionally needs
    // SSE4.2, which callers are expected to guarantee (checked above in debug builds).
    unsafe {
        let onev = _mm_loadu_si128(one.as_ptr() as *const _);
        let twov = _mm_loadu_si128(two.as_ptr() as *const _);
        let mask = _mm_cmpestrm(onev, 16, twov, 16, SSID_OPTS);
        _mm_movemask_epi8(mask)
    }
}
/// Number of bytes compared per SIMD step (one 128-bit SSE register).
const SSE_STRIDE: usize = 16;
/// Returns the lowest `i` for which `one[i] != two[i]`, if one exists. | |
fn ne_idx(one: &[u8], two: &[u8]) -> Option<usize> { | |
let min_len = one.len().min(two.len()); | |
let mut idx = 0; | |
loop { | |
let mask: i32; | |
// if slice is less than 16 bytes we manually pad it | |
if idx + SSE_STRIDE >= min_len { | |
let mut one_buf: [u8; 16] = [0; 16]; | |
let mut two_buf: [u8; 16] = [0; 16]; | |
let mut temp_idx = 0; | |
for i in idx..min_len { | |
one_buf[temp_idx] = one[i]; | |
two_buf[temp_idx] = two[i]; | |
temp_idx += 1; | |
} | |
mask = fast_cmpestr_mask(&one_buf, &two_buf); | |
} else { | |
mask = fast_cmpestr_mask(&one[idx..idx+SSE_STRIDE], &two[idx..idx+SSE_STRIDE]); | |
} | |
let i = mask.trailing_zeros() as usize; | |
if i != 32 { return Some(idx + i); } | |
idx += SSE_STRIDE; | |
if idx >= min_len { break; } | |
} | |
None | |
} | |
/// Returns the lowest `i` such that `one[one.len()-i] != two[two.len()-i]`, | |
/// if one exists. | |
fn ne_idx_rev(one: &[u8], two: &[u8]) -> Option<usize> { | |
let min_len = one.len().min(two.len()); | |
let mut idx = min_len; | |
loop { | |
let mask: i32; | |
if idx < SSE_STRIDE { | |
let mut one_buf: [u8; 16] = [0; 16]; | |
let mut two_buf: [u8; 16] = [0; 16]; | |
let mut temp_idx = SSE_STRIDE - idx; | |
for i in 0..idx { | |
one_buf[temp_idx] = one[i]; | |
two_buf[temp_idx] = two[i]; | |
temp_idx += 1; | |
} | |
mask = fast_cmpestr_mask(&one_buf, &two_buf); | |
} else { | |
mask = fast_cmpestr_mask(&one[idx-SSE_STRIDE..idx], &two[idx-SSE_STRIDE..idx]); | |
} | |
let i = mask.leading_zeros() as usize - 16; | |
if i != 16 { return Some(min_len - (idx - i)); } | |
if idx < SSE_STRIDE { break; } | |
idx -= SSE_STRIDE; | |
} | |
None | |
} | |
/// Scalar counterpart of `ne_idx`: returns the index of the first byte at which `one`
/// and `two` differ.
///
/// Only the first `min(one.len(), two.len())` bytes are compared; `None` means that
/// common-length prefix is identical.
fn str_match_end_sw(one: &str, two: &str) -> Option<usize> {
    // `zip` stops at the shorter input, which gives the common-length bound for free.
    one.bytes().zip(two.bytes()).position(|(a, b)| a != b)
}
/// Scalar counterpart of `ne_idx_rev`: returns the lowest `i` such that
/// `one[one.len() - 1 - i] != two[two.len() - 1 - i]`, if one exists.
///
/// The strings are compared byte by byte from their ends, so `Some(0)` means the last
/// bytes differ. Only the final `min(one.len(), two.len())` bytes are examined; `None`
/// means that common-length suffix is identical (including when either input is empty).
fn str_rev_match_end_sw(one: &str, two: &str) -> Option<usize> {
    // Walking both byte sequences in reverse aligns them on their ends and stops at the
    // shorter one, with no index arithmetic that could underflow.
    one.bytes()
        .rev()
        .zip(two.bytes().rev())
        .position(|(a, b)| a != b)
}
/*
test bench_hw_200k ... bench: 35,689 ns/iter (+/- 735)
test bench_hw_50k ... bench: 8,935 ns/iter (+/- 192)
test bench_hw_rev_200k ... bench: 47,544 ns/iter (+/- 976)
test bench_hw_rev_50k ... bench: 38,113 ns/iter (+/- 1,023)
test bench_hw_small ... bench: 16 ns/iter (+/- 0)
test bench_sw_200k ... bench: 95,202 ns/iter (+/- 2,000)
test bench_sw_50k ... bench: 23,766 ns/iter (+/- 927)
test bench_sw_rev_200k ... bench: 117,995 ns/iter (+/- 2,955)
test bench_sw_rev_50k ... bench: 94,450 ns/iter (+/- 2,173)
test bench_sw_small ... bench: 5 ns/iter (+/- 1)
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment