Skip to content

Instantly share code, notes, and snippets.

@ned14
Created November 9, 2022 17:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ned14/617ce47171c6324fc388306e5c141633 to your computer and use it in GitHub Desktop.
Save ned14/617ce47171c6324fc388306e5c141633 to your computer and use it in GitHub Desktop.
Many SIMD ways of finding the zero byte that follows the last non-zero byte in a fixed length string
#include <cassert>
#include <chrono>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <span>
#include <vector>
/*
Straight C: 5.9, 84.37
SSE2: 4.38, 228
bitscan: 4.8, 86.3
*/
//! Minimal stand-in for a test assertion: if the expression is false, prints `!(expression)` to stderr
//! and aborts the process.
//! The body is wrapped in `do { } while(0)` so that the macro plus its trailing semicolon forms exactly
//! one statement and can be used in an unbraced `if`/`else`. The stringised expression is passed as a
//! `%s` argument rather than as the format string, so a `%` inside the expression is printed verbatim.
#define BOOST_CHECK(...) do { if(!(__VA_ARGS__)) { fprintf(stderr, "%s", "!(" #__VA_ARGS__ ")\n"); abort(); } } while(0)
namespace utils {
/*! \class small_prng
\brief Small, fast pseudo random number generator following
http://burtleburtle.net/bob/rand/smallprng.html. Not suitable for anything
needing high quality randomness.
*/
class small_prng {
protected:
  // The four words of generator state.
  uint32_t a;
  uint32_t b;
  uint32_t c;
  uint32_t d;

  //! Rotates `x` left by `k` bits. Only ever called with k = 27 and k = 17 here.
  static inline uint32_t rot(uint32_t x, uint32_t k) noexcept {
    const uint32_t upper = x << k;
    const uint32_t lower = x >> (32 - k);
    return upper | lower;
  }

public:
  //! The type produced by the small prng
  using value_type = uint32_t;

  //! Seeds the state from `seed`, then discards the first 20 outputs.
  explicit small_prng(uint32_t seed = 0xdeadbeef) noexcept
      : a(0xf1ea5eed)
      , b(seed)
      , c(seed)
      , d(seed) {
    int remaining = 20;
    while(remaining-- > 0) {
      (void) operator()();
    }
  }

  //! Advances the state by one step and returns the next `value_type` of pseudo-randomness
  inline uint32_t operator()() noexcept {
    const uint32_t e = a - rot(b, 27);
    const uint32_t next_a = b ^ rot(c, 17);
    const uint32_t next_b = c + d;
    const uint32_t next_c = d + e;
    const uint32_t next_d = e + next_a;  // uses the freshly computed `a`
    a = next_a;
    b = next_b;
    c = next_c;
    d = next_d;
    return next_d;
  }
};
}  // namespace utils
namespace mdx {
/*! \brief A binary symbol identifier.

This is a binary symbol identifier of up to 24 bytes in length, padded to
the right with all bits zero bytes. Any contiguous run of all bits zero
bytes to the right are used to compress storage.

The binary symbol identifier may contain all bits zero bytes. `size()` works
exclusively with the last non-zero byte.
*/
struct symbol {
  // longest name currently possible is ISIN at 22 bytes
  char name[24];

  //! Constructs an empty symbol: every byte of `name` is zero.
  constexpr symbol() noexcept
      : name{} {}

  /*! Construct from a character array and length.

  \param s Pointer to `l` readable bytes. May be null only when `l` is zero.
  \param l Number of bytes to copy. The process is aborted if this exceeds `sizeof(name)`.

  Bytes after the first `l` remain zero.
  */
  symbol(const char* s, size_t l) noexcept
      : symbol() {
    if(l > sizeof(name)) {
      abort();
    }
    // memcpy() has undefined behaviour when given a null pointer, even for a zero
    // length, so an empty input skips the copy altogether.
    if(l != 0) {
      memcpy(name, s, l);
    }
  }

  //! Construct from a byte array and length
  // symbol(const byte* s, size_t l) noexcept
  // : symbol() {
  // if(l > sizeof(name)) {
  // abort();
  // }
  // memcpy(name, s, l);
  // }
  //! Construct from a span of byte equivalents
  // template <concepts::byte_equivalent T>
  // explicit symbol(span<const T> s) noexcept
  // : symbol(s.data(), s.size()) {}
  //! Construct from a string view of characters
  // explicit symbol(string_view s) noexcept
  // : symbol(s.data(), s.size()) {}
  //! Construct from a string literal
  // template <concepts::byte_equivalent T, size_t N>
  // explicit symbol(const T (&arr)[N]) noexcept
  // : symbol(reinterpret_cast<const char*>(arr), N - std::is_same_v<std::decay_t<T>, char>) {}

  //! Equality: all 24 bytes compare equal
  bool operator==(const symbol& o) const noexcept { return 0 == memcmp(name, o.name, sizeof(name)); }
  //! Inequality: any of the 24 bytes differ
  bool operator!=(const symbol& o) const noexcept { return 0 != memcmp(name, o.name, sizeof(name)); }
  //! Ordering: `memcmp()` order over all 24 bytes
  bool operator<(const symbol& o) const noexcept { return memcmp(name, o.name, sizeof(name)) < 0; }

  //! Returns a pointer to the beginning of the binary symbol identifier
  char* data() noexcept { return name; }
  //! \overload
  const char* data() const noexcept { return name; }

  /*! Returns the index of the null byte after the very final non-null byte up to `limit`, or `limit` if
  `name[limit - 1] != 0`. A `limit` larger than `sizeof(name)` is treated as `sizeof(name)`.

  Make SURE that the 24 bytes after `this` are valid to read from before calling this function (the
  implementation uses SIMD to load either 24 or 16 bytes in a single cycle).

  Defined out of line further down this file, with one variant selected per target by the preprocessor.
  */
  inline size_t size_within_maximum_length(size_t limit) const noexcept;

  //! Returns the index of the null byte after the very final non-null byte, or `sizeof(name)` if the last byte is
  //! non-null.
  size_t size() const noexcept { return size_within_maximum_length(sizeof(name)); }
  //! Returns true if the symbol's name has zero length.
  [[nodiscard]] bool empty() const noexcept { return size() == 0; }
  //! Returns the symbol's name as a string view (note it may contain unprintable characters).
  //string_view as_string_view() const& noexcept { return string_view(name, size()); }
  //string_view as_string_view() && = delete;
  //string_view as_string_view() const&& = delete;
};
}  // namespace mdx
#if 0 // defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64)
#include <emmintrin.h> // for SIMD memrchr implementation
namespace mdx {
/*! SSE2 variant.

Returns one past the index of the last non-zero byte among the first `limit` bytes of `name`
(`limit` is clamped to `sizeof(name)`), or 0 if all of those bytes are zero.

Each 16 byte load is compared against zero and reduced with `_mm_movemask_epi8()` to one bit per
byte; after inversion a set bit marks a non-zero byte, and the highest set bit is the last non-zero
byte.
*/
size_t symbol::size_within_maximum_length(size_t limit) const noexcept {
  if(limit > sizeof(name)) {
    limit = sizeof(name);
  }
  // A zero limit covers no bytes. Returning here keeps the `name[limit - 1]` probe below from
  // indexing name[-1].
  if(limit == 0) {
    return 0;
  }
  // Fast path: the last byte inside the limit is in use, so the length is the limit itself.
  if(name[limit - 1] != 0) {
    return limit;
  }
  // Index of the most significant set bit. `value` must be non-zero.
  auto bsr = [](int value) -> unsigned {
#ifdef _MSC_VER
    unsigned long bitpos;
    _BitScanReverse(&bitpos, value);
    return bitpos;
#elif defined(__GNUC__)
    return (sizeof(unsigned) * CHAR_BIT - 1) - (unsigned) __builtin_clz(value);
#else
#error Unknown compiler
#endif
  };
  const __m128i zeros = _mm_setzero_si128();
  if(limit > 16) {
    // We are always 24 bytes long. That is 1.5 SSE registers.
    // `back` covers bytes 8..23 and `front` bytes 0..15; only the low 8 bits of the front mask
    // (bytes 0..7) are kept because bytes 8..15 are already represented in the back mask.
    const __m128i back = _mm_loadu_si128((const __m128i*) (name + 8));  // don't spill past 24 bytes
    const __m128i front = _mm_loadu_si128((const __m128i*) (name + 0));
    const __m128i cback = _mm_cmpeq_epi8(back, zeros);
    const __m128i cfront = _mm_cmpeq_epi8(front, zeros);
    uint16_t mback = ~(uint16_t) _mm_movemask_epi8(cback);
    const uint8_t mfront = ~(uint8_t) _mm_movemask_epi8(cfront);
    if(limit < 24) {
      // Drop the bits belonging to bytes at or beyond `limit`.
      mback = uint16_t(mback << (24 - limit)) >> (24 - limit);
    }
    // Bit i of mback is byte 8 + i, so the length is 8 + i + 1.
    unsigned lastzeroidx = (mback != 0) ? (9 + bsr(mback)) : ((mfront != 0) ? (1 + bsr(mfront)) : 0);
    if(lastzeroidx > limit) {
      abort();  // sanity check: the result can never exceed the limit
    }
    return lastzeroidx;
  } else {
    // Everything of interest fits in a single 16 byte load.
    const __m128i content = _mm_loadu_si128((const __m128i*) (name + 0));
    const __m128i cmp = _mm_cmpeq_epi8(content, zeros);
    uint16_t nonzeros = ~(uint16_t) _mm_movemask_epi8(cmp);
    if(limit < 16) {
      // Drop the bits belonging to bytes at or beyond `limit`.
      nonzeros = uint16_t(nonzeros << (16 - limit)) >> (16 - limit);
    }
    unsigned lastzeroidx = (nonzeros != 0) ? (1 + bsr(nonzeros)) : 0;
    if(lastzeroidx > limit) {
      abort();  // sanity check: the result can never exceed the limit
    }
    return lastzeroidx;
  }
}
}  // namespace mdx
#elif defined(__aarch64__) || defined(_M_ARM64)
#include <arm_neon.h> // for SIMD memrchr implementation
namespace mdx {
/*! NEON (AArch64) variant.

Returns one past the index of the last non-zero byte among the first `limit` bytes of `name`
(`limit` is clamped to `sizeof(name)`), or 0 if all of those bytes are zero.

Every non-zero byte inside the limit contributes one distinct bit; repeated pairwise additions
then fold each group of eight lanes into a single byte, giving a bitmap in which the lowest set
bit identifies the last non-zero byte.
*/
size_t symbol::size_within_maximum_length(size_t limit) const noexcept {
  if(limit > sizeof(name)) {
    limit = sizeof(name);
  }
  // A zero limit covers no bytes. Returning here keeps the `name[limit - 1]` probe below from
  // indexing name[-1].
  if(limit == 0) {
    return 0;
  }
  // Fast path: the last byte inside the limit is in use, so the length is the limit itself.
  if(name[limit - 1] != 0) {
    return limit;
  }
  // Index of the least significant set bit. `value` must be non-zero.
  auto bsf = [](int value) -> unsigned {
#ifdef _MSC_VER
    unsigned long bitpos;
    _BitScanForward(&bitpos, value);
    return bitpos;
#elif defined(__GNUC__)
    return (unsigned) __builtin_ctz(value);
#else
#error Unknown compiler
#endif
  };
  // Use one distinct bit per lane within each group of eight lanes (0x80 for the first lane down
  // to 0x01 for the eighth), as we'll use pairwise addition to collapse each group into a byte.
  // Lanes at or beyond `limit` get a zero mask so they can never contribute.
  const uint8x16_t mask_front = [&] {
    union{uint8x16_t ret; uint8_t ret_bytes[16];};
    ret = (uint8x16_t) vdupq_n_u64(0x0102040810204080ULL);
    if(limit < 16) {
      memset(ret_bytes+limit, 0, 16-limit);
    }
    return ret;
  }();
  // Same scheme for bytes 16..23.
  const uint8x8_t mask_back = [&] {
    union{uint8x8_t ret; uint8_t ret_bytes[8];};
    if(limit<16)
    {
      ret = vcreate_u8(0);
    } else {
      ret = vcreate_u8(0x0102040810204080ULL);
      if(limit < 24) {
        memset(ret_bytes+limit-16, 0, 24-limit);
      }
    }
    return ret;
  }();
  if(limit > 16) {
    // We are always 24 bytes long. NEON unlike SSE can do half registers.
    const uint8x8_t back = vld1_u8((const uint8_t*) (name + 16));
    const uint8x16_t front = vld1q_u8((const uint8_t*) (name + 0));
    const uint8x8_t cback = vand_u8(vtst_u8(back, back), mask_back);
    const uint8x16_t cfront = vandq_u8(vtstq_u8(front, front), mask_front);
    // mask bits will be set in lanes where there was a non-zero byte.
    // Now add adjacent lanes to reduce 24 lanes to three lanes (i.e. one bit per whether lane was non-zero)
    const uint8x8_t rback = vpadd_u8(cback, cback);  // 8 to 4 (duplicated into the upper half)
    const uint8x8_t rfront = vget_low_u8(vpaddq_u8(cfront, cfront));  // 16 to 8
    const uint8x16_t r1 = vcombine_u8(rfront, rback);  // 8 + 4 = 12 meaningful lanes
    const uint8x8_t r2 = vget_low_u8(vpaddq_u8(r1, r1));  // 12 to 6
    union {
      uint32_t r3_uints[2];
      uint8x8_t r3;
    };
    r3 = vpadd_u8(r2, r2);  // 6 to 3
#if 0
    auto dump = [](const char *desc, auto x){ printf("%s:", desc); for(size_t n=0; n<sizeof(x); n++){ printf(" %.2x", ((const uint8_t *)&x)[n]);} printf("\n");};
    dump(" front", front);
    dump(" cfront", cfront);
    dump(" rfront", rfront);
    printf("\n");
    dump(" back", back);
    dump(" cback", cback);
    dump(" rback", rback);
    printf("\n");
    dump(" r1", r1);
    dump(" r2", r2);
    dump(" r3", r3);
#endif
    // Put the three bitmap bytes in order so that bit 23 is name[0] and bit 0 is name[23]; the
    // shift discards the fourth (duplicate) lane.
    const uint32_t m = __builtin_bswap32(r3_uints[0])>>8;
    //printf("m=%x\n", m);
    // If the entire input had no zero bytes, m would be 0xffffff
    // If the entire input were zero bytes, m would be 0x00
    unsigned lastzeroidx = 0;
    if(m!=0) {
      // The lowest set bit is the highest-indexed non-zero byte.
      lastzeroidx = 24 - bsf(m);
    }
    if(lastzeroidx > limit) {
      abort();  // sanity check: the result can never exceed the limit
    }
    //printf("lastzeroidx=%u\n", lastzeroidx);
    return lastzeroidx;
  } else {
    // Everything of interest fits in the first 16 bytes.
    const uint8x16_t content = vld1q_u8((const uint8_t*) (name + 0));
    const uint8x16_t cmp = vandq_u8(vtstq_u8(content, content), mask_front);
    const uint8x8_t r1 = vget_low_u8(vpaddq_u8(cmp, cmp));  // 16 to 8
    const uint8x8_t r2 = vpadd_u8(r1, r1);  // 8 to 4
    union {
      uint16_t r3_uints[4];
      uint8x8_t r3;
    };
    r3 = vpadd_u8(r2, r2);  // 4 to 2
    // Bit 15 is name[0] and bit 0 is name[15].
    const uint16_t m = __builtin_bswap16(r3_uints[0]);
    unsigned lastzeroidx = 0;
    if(m!=0) {
      lastzeroidx = 16 - bsf(m);
    }
    if(lastzeroidx > limit) {
      abort();  // sanity check: the result can never exceed the limit
    }
    return lastzeroidx;
  }
}
}  // namespace mdx
#elif 1
namespace mdx {
/*! Portable 64-bit word variant.

Returns one past the index of the last non-zero byte among the first `limit` bytes of `name`
(`limit` is clamped to `sizeof(name)`), or 0 if all of those bytes are zero.

The name is examined as three 8 byte words, last word first. Within a word, the count of leading
zero bits divided by eight is the number of trailing zero bytes of that part of the name.
*/
size_t symbol::size_within_maximum_length(size_t limit) const noexcept {
  if(limit > sizeof(name)) {
    limit = sizeof(name);
  }
  // A zero limit covers no bytes. Returning here keeps the `name[limit - 1]` probe below from
  // indexing name[-1].
  if(limit == 0) {
    return 0;
  }
  // Fast path: the last byte inside the limit is in use, so the length is the limit itself.
  if(name[limit - 1] != 0) {
    return limit;
  }
  // Number of leading (most significant) zero bits. `value` must be non-zero.
  auto clz64 = [](uint64_t value) -> unsigned {
#ifdef _MSC_VER
    // _BitScanReverse64 yields the INDEX of the highest set bit; convert it to a zero count.
    unsigned long bitpos;
    _BitScanReverse64(&bitpos, value);
    return 63 - (unsigned) bitpos;
#elif defined(__GNUC__)
    return (unsigned) __builtin_clzll(value);
#else
#error Unknown compiler
#endif
  };
  // Examines the 8 byte word starting at byte `first`. Returns the length if that word holds a
  // non-zero byte below `limit`, otherwise 0 (a found length is always at least `first + 1`).
  auto scan_word = [&](size_t first) -> size_t {
    if(limit <= first) {
      return 0;
    }
    // memcpy rather than a pointer cast: `name` is a char array with no alignment guarantee.
    uint64_t x;
    memcpy(&x, name + first, sizeof(x));
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    // The arithmetic below needs the lowest-indexed byte in the least significant position.
    x = __builtin_bswap64(x);
#endif
    const size_t end = first + sizeof(x);
    if(limit < end) {
      // Clear the bytes at or beyond `limit`. The shift is at most 56 because limit > first.
      const unsigned shift = (unsigned) (end - limit) << 3;
      x = (x << shift) >> shift;
    }
    if(x == 0) {
      return 0;
    }
    return end - (clz64(x) >> 3);
  };
  if(const size_t found = scan_word(16)) {
    return found;
  }
  if(const size_t found = scan_word(8)) {
    return found;
  }
  return scan_word(0);
}
}  // namespace mdx
#else
namespace mdx {
/*! Plain scalar variant.

Returns one past the index of the last non-zero byte among the first `limit` bytes of `name`
(`limit` is clamped to `sizeof(name)`), or 0 if all of those bytes are zero.
*/
size_t symbol::size_within_maximum_length(size_t limit) const noexcept {
  if(limit > sizeof(name)) {
    limit = sizeof(name);
  }
  // Step back over trailing zero bytes. Testing `limit > 0` first means a zero limit never
  // touches name[-1], and a non-zero final byte exits on the first test.
  while(limit > 0 && name[limit - 1] == 0) {
    --limit;
  }
  return limit;
}
}  // namespace mdx
#endif
/*! Benchmark driver.

Builds a large table of random 24 byte symbols together with their expected lengths, then times
repeated length queries over the table, checking every result. Prints the elapsed seconds.
*/
int main(void)
{
  using mdx::symbol;
  constexpr size_t symbol_count = 50000000;   // number of distinct random symbols
  constexpr size_t iterations = 500000000;    // number of timed loop iterations
  utils::small_prng rand;
  {
    // Spin for one second before doing any measured work.
    const auto begin = std::chrono::high_resolution_clock::now();
    while(std::chrono::high_resolution_clock::now() - begin < std::chrono::seconds(1));
  }
  std::vector<std::pair<symbol, size_t>> symbols(symbol_count);
  for(size_t n = 0; n < symbol_count; n++) {
    auto &s = symbols[n].first;
    // Fill all 24 bytes from six consecutive 32 bit random values. memcpy is used instead of
    // writing through an `unsigned*` alias of the char array.
    for(size_t w = 0; w < sizeof(s.name) / sizeof(uint32_t); w++) {
      const uint32_t word = rand();
      memcpy(s.name + w * sizeof(word), &word, sizeof(word));
    }
    auto r = rand();
    // Place a random zero byte somewhere
    s.name[r % 24] = 0;
    //printf("\n\nzero byte set at %u\n", r % 24);
    // Make some number of the end zero bytes (0 to 15 of them)
    r >>= 28;
    size_t l = 24 - r;
    memset(s.name + l, 0, r);
    // Trim any further zero bytes that happen to precede the padding. The `l > 0` bound stops
    // the scan at the front of the array if every byte turned out to be zero.
    while(l > 0 && s.name[l - 1] == 0) {
      l--;
    }
    symbols[n].second = l;
  }
  const auto begin = std::chrono::high_resolution_clock::now();
  for(size_t n = 0; n < iterations; n++) {
    const auto &entry = symbols[n % symbol_count];
    const auto &s = entry.first;
    const auto l = entry.second;
    //printf("\nlength: %zu\n", l);
    BOOST_CHECK(s.size() == l);
#if 1
    // Limits just above, at, and just below the true length.
    BOOST_CHECK(s.size_within_maximum_length(l + 1) == l);
    BOOST_CHECK(s.size_within_maximum_length(l) == l);
    BOOST_CHECK(s.size_within_maximum_length(l - 1) <= l - 1);
#endif
  }
  const auto end = std::chrono::high_resolution_clock::now();
  const auto diff = std::chrono::duration_cast<std::chrono::milliseconds>(end - begin);
  printf("%f\n", diff.count()/1000.0);
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment