Skip to content

Instantly share code, notes, and snippets.

@saka1
Last active April 7, 2019 05:09
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saka1/5bdd0f0ad4e498d22258982f6b9699c5 to your computer and use it in GitHub Desktop.
Save saka1/5bdd0f0ad4e498d22258982f6b9699c5 to your computer and use it in GitHub Desktop.
Count UTF-8 code points
#include <immintrin.h>
#include <stdbool.h>
#include <string.h>
#include <sys/time.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
// Debug helper: print the 256-bit vector `v` to stdout, reinterpreted as an
// array of T-typed lanes, starting from the lowest-addressed lane
// (output shape: "LSB a, b, ..., MSB").
template<typename T> __attribute__((noinline)) void dump(__m256i v)
{
    const T *lanes = reinterpret_cast<const T *>(&v);
    constexpr size_t lane_count = sizeof(__m256i) / sizeof(T);
    std::cout << "LSB ";
    for (size_t k = 0; k != lane_count; ++k) {
        std::cout << std::to_string(lanes[k]) << ", ";
    }
    std::cout << "MSB" << std::endl;
}
// Read the x86 time-stamp counter with the RDTSC instruction.
// RDTSC leaves the low half of the counter in EAX and the high half in EDX;
// the "=a" / "=d" output constraints bind those two registers, see
// https://gcc.gnu.org/onlinedocs/gcc-8.2.0/gcc/Machine-Constraints.html#Machine-Constraints
// No fence is emitted around the instruction.
inline uint64_t rdtsc()
{
    uint32_t lo;
    uint32_t hi;
    __asm__ __volatile__("rdtsc" : "=a"(lo), "=d"(hi));
    const uint64_t high_part = static_cast<uint64_t>(hi) << 32;
    return high_part | static_cast<uint64_t>(lo);
}
// Run `f` once and print two measurements of its cost to stdout:
//   "clock = N [Mclk]" - elapsed time-stamp-counter ticks (via rdtsc()),
//                        in whole millions (integer division truncates);
//   "time = S"         - elapsed seconds as a double.
// Seconds come from std::chrono::steady_clock, which is monotonic. The
// previous gettimeofday() reads the adjustable system clock, which can jump
// (NTP slew/step, manual changes) in the middle of a measurement.
inline void time(std::function<void()> f)
{
    const auto wall0 = std::chrono::steady_clock::now();
    const auto t0 = rdtsc();
    f();
    const auto t1 = rdtsc();
    const auto wall1 = std::chrono::steady_clock::now();
    std::cout << "clock = " << (t1 - t0)/1000/1000 << " [Mclk]" << std::endl;
    const std::chrono::duration<double> elapsed = wall1 - wall0;
    std::cout << "time = " << elapsed.count() << std::endl;
}
// True when byte `c` is NOT a UTF-8 continuation byte (continuation bytes
// have the bit pattern 10xxxxxx). Counting such bytes over well-formed UTF-8
// counts code points. Takes unsigned char so a plain (possibly signed) char
// argument is interpreted by its bit pattern.
constexpr bool is_utf8_lead_byte(unsigned char c)
{
    return (c & 0xC0) != 0x80;
}
// Bit 7 of every byte in a 64-bit word.
constexpr uint64_t NONASCII_MASK = UINT64_C(0x8080808080808080);
// Count the UTF-8 lead (non-continuation) bytes in the machine word at `s`.
// For each byte b, bit 0 of ((b >> 6) | (~b >> 7)) is (bit6 | !bit7), which is
// 0 only for continuation bytes 10xxxxxx. Masking with 0x0101... keeps that
// bit of every byte, and popcount tallies them.
// The word is read with memcpy rather than `*s`: `s` points into char data,
// and reading it through a uintptr_t lvalue violates strict aliasing.
// memcpy of a fixed small size compiles to a single load.
inline uintptr_t count_utf8_lead_bytes_with_word(const uintptr_t *s)
{
    uintptr_t d;
    memcpy(&d, s, sizeof d);
    d = (d >> 6) | (~d >> 7);
    d &= NONASCII_MASK >> 7;
    return __builtin_popcountll(d);
}
#if defined(__AVX2__)
inline int32_t avx2_horizontal_sum_epi8(__m256i x)
{
__m256i sumhi = _mm256_unpackhi_epi8(x, _mm256_setzero_si256());
__m256i sumlo = _mm256_unpacklo_epi8(x, _mm256_setzero_si256());
__m256i sum16x16 = _mm256_add_epi16(sumhi, sumlo);
__m256i sum16x8 = _mm256_add_epi16(sum16x16, _mm256_permute2x128_si256(sum16x16, sum16x16, 1));
__m256i sum16x4 = _mm256_add_epi16(sum16x8, _mm256_shuffle_epi32(sum16x8, _MM_SHUFFLE(0, 0, 2, 3)));
uint64_t tmp = _mm256_extract_epi64(sum16x4, 0);
int32_t result = 0;
result += (tmp >> 0 ) & 0xffff;
result += (tmp >> 16) & 0xffff;
result += (tmp >> 32) & 0xffff;
result += (tmp >> 48) & 0xffff;
return result;
}
// Count UTF-8 lead (non-continuation) bytes in the first
// floor((e - p) / 32) * 32 bytes of [p, e). The trailing (e - p) % 32 bytes
// are NOT examined; the caller must count them. Returns 0 when e <= p.
// Loads are unaligned (_mm256_loadu_si256), so `p` need not be 32-byte
// aligned. The previous aligned load plus __builtin_assume_aligned faulted
// or was UB on a misaligned pointer. When `p` happens to be aligned, the
// unaligned load generally costs the same on AVX2-capable CPUs.
int64_t avx_count_utf8_codepoint(const char *p, const char *e)
{
    if (e <= p) {
        return 0; // guard: `e - p` would wrap to a huge size_t below
    }
    // Lookup keyed by a byte's high nibble: 1 for 0x0-0x7 (ASCII) and
    // 0xC-0xF (multi-byte lead), 0 for 0x8-0xB (continuation 10xxxxxx).
    // vpshufb indexes within each 128-bit lane, so the 16-entry table is
    // repeated in both lanes.
    const __m256i table = _mm256_setr_epi8(
        1, 1, 1, 1, 1, 1, 1, 1, // 0x0 .. 0x7
        0, 0, 0, 0,             // 0x8 .. 0xB
        1, 1, 1, 1,             // 0xC .. 0xF
        1, 1, 1, 1, 1, 1, 1, 1, // 0x0 .. 0x7
        0, 0, 0, 0,             // 0x8 .. 0xB
        1, 1, 1, 1              // 0xC .. 0xF
    );
    const __m256i low_nibble_mask = _mm256_set1_epi8(0x0F);
    // Each 8-bit accumulator lane grows by at most 1 per 32-byte block, so the
    // byte accumulators are flushed into `result` at least every 255 blocks
    // to avoid wrap-around.
    constexpr size_t kMaxBlocksPerFlush = 255;
    const size_t size = static_cast<size_t>(e - p);
    int64_t result = 0;
    for (size_t i = 0; i + 31 < size;) {
        __m256i sum = _mm256_setzero_si256();
        size_t j = 0;
        for (; j < kMaxBlocksPerFlush * 32 && (i + 31) + j < size; j += 32) {
            __m256i s = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(p + i + j));
            // High nibble of each byte: the 16-bit shift drags in bits from
            // the neighbouring byte, which the mask removes.
            s = _mm256_and_si256(_mm256_srli_epi16(s, 4), low_nibble_mask);
            s = _mm256_shuffle_epi8(table, s);
            sum = _mm256_add_epi8(sum, s);
        }
        i += j;
        result += avx2_horizontal_sum_epi8(sum);
    }
    return result;
}
#endif
// Count UTF-8 code points in [p, e) by counting lead (non-continuation)
// bytes with a word-at-a-time scan (the name suggests it is adapted from
// CRuby's string.c):
//   1. byte loop up to the first word-aligned address `s`,
//   2. one count_utf8_lead_bytes_with_word() call per aligned word in [s, t),
//   3. byte loop over what remains.
// The word path is taken only when the range is longer than two words, which
// guarantees at least one whole aligned word lies inside it.
// Returns int64_t directly; the former `(long)len` truncated counts >= 2^31
// on LLP64 targets where long is 32-bit.
int64_t ruby_count_utf8_codepoint(const char *p, const char *e)
{
    uintptr_t len = 0;
    if (static_cast<int>(sizeof(uintptr_t)) * 2 < e - p) {
        const uintptr_t lowbits = sizeof(uintptr_t) - 1;
        // Round p up and e down to word boundaries.
        const uintptr_t *s = reinterpret_cast<const uintptr_t *>(
            ~lowbits & (reinterpret_cast<uintptr_t>(p) + lowbits));
        const uintptr_t *t = reinterpret_cast<const uintptr_t *>(
            ~lowbits & reinterpret_cast<uintptr_t>(e));
        while (p < reinterpret_cast<const char *>(s)) {
            if (is_utf8_lead_byte(*p)) len++;
            p++;
        }
        while (s < t) {
            len += count_utf8_lead_bytes_with_word(s);
            s++;
        }
        p = reinterpret_cast<const char *>(s);
    }
    while (p < e) {
        if (is_utf8_lead_byte(*p)) len++;
        p++;
    }
    return static_cast<int64_t>(len);
}
// Count UTF-8 code points (lead bytes) in [p, e).
// With AVX2 and at least 32 bytes of input, the scalar loop first advances to
// a 32-byte boundary. avx_count_utf8_codepoint() then handles every whole
// 32-byte chunk after it, and the final scalar loop counts the leftover tail.
// Without AVX2 the whole range is counted byte by byte.
__attribute__((noinline))
int64_t count_utf8_codepoint(const char *p, const char *e)
{
    int64_t n = 0;
    const char *cur = p;
#if defined(__AVX2__)
    if (e - cur >= 32) {
        for (; reinterpret_cast<uintptr_t>(cur) % 32 != 0; ++cur) {
            n += is_utf8_lead_byte(*cur) ? 1 : 0;
        }
        n += avx_count_utf8_codepoint(cur, e);
        // Skip the whole 32-byte chunks the vector kernel consumed.
        cur += static_cast<uintptr_t>(e - cur) / 32 * 32;
    }
#endif
    for (; cur < e; ++cur) {
        if (is_utf8_lead_byte(*cur)) ++n;
    }
    return n;
}
// Reference implementation: count UTF-8 code points in [p, e) one byte at a
// time, i.e. count every byte that is not a continuation byte (10xxxxxx).
// An empty or reversed range (e <= p) yields 0.
int64_t scalar_count_utf8_codepoint(const char *p, const char *e)
{
    if (e <= p) {
        return 0;
    }
    const auto is_lead = [](char ch) {
        return (static_cast<unsigned char>(ch) & 0xC0) != 0x80;
    };
    return static_cast<int64_t>(std::count_if(p, e, is_lead));
}
// Xorshift pseudo-random generator (shift triple 13/7/17 on a 64-bit state).
// The state lives in a function-local static seeded with a fixed constant, so
// the sequence is the same on every run. The state is unsynchronized, so the
// function is not safe to call from several threads. Each call advances the
// state and returns its low 32 bits.
uint32_t xor64(void) {
    static uint64_t state = 88172645463325252ULL;
    state ^= state << 13;
    state ^= state >> 7;
    state ^= state << 17;
    return static_cast<uint32_t>(state);
}
int main()
{
constexpr auto N = 100 * 1024 * 1024;
char *p = (char *)calloc(N, sizeof(char));
const char *e = p + N;
for (size_t i = 0; i < N; i++) {
p[i] = xor64();
}
time([&p, &e]() noexcept {
volatile uint64_t tmp; // use `volatile` to supress optimization
for (int i = 0; i < 100; i++) {
//tmp = count_utf8_codepoint(p, e);
//tmp = ruby_count_utf8_codepoint(p, e);
tmp = scalar_count_utf8_codepoint(p, e);
}
printf("count = %lu\n", tmp);
});
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment