Last active
April 7, 2019 05:09
-
-
Save saka1/5bdd0f0ad4e498d22258982f6b9699c5 to your computer and use it in GitHub Desktop.
Count UTF-8 code points
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdbool.h>
#include <string.h>
#include <sys/time.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>

#include <immintrin.h>
// Debug helper: print the 256-bit vector `v` as an array of T elements,
// starting from the lowest-addressed element ("LSB") and ending with the
// highest ("MSB"). For example, dump<int8_t>(v) prints 32 values.
template<typename T> __attribute__((noinline)) void dump(__m256i v)
{
    constexpr size_t lane_count = sizeof(__m256i) / sizeof(T);
    T lanes[lane_count];
    // Copy the vector's bytes out instead of reading through a T* alias.
    memcpy(lanes, &v, sizeof lanes);

    std::cout << "LSB ";
    for (const T &lane : lanes) {
        std::cout << std::to_string(lane) << ", ";
    }
    std::cout << "MSB" << std::endl;
}
// Read the x86 time-stamp counter via the RDTSC instruction.
// RDTSC leaves the low 32 bits in EAX and the high 32 bits in EDX; the "=a"
// and "=d" output constraints bind those two registers to `lo` and `hi`.
// https://gcc.gnu.org/onlinedocs/gcc-8.2.0/gcc/Machine-Constraints.html#Machine-Constraints
inline uint64_t rdtsc()
{
    uint32_t lo = 0;
    uint32_t hi = 0;
    __asm__ __volatile__("rdtsc" : "=a"(lo), "=d"(hi));
    const uint64_t upper = static_cast<uint64_t>(hi) << 32;
    return upper | lo;
}
inline void time(std::function<void()> f) | |
{ | |
struct timeval s0, s1; | |
gettimeofday(&s0, NULL); | |
auto t0 = rdtsc(); | |
f(); | |
auto t1 = rdtsc(); | |
gettimeofday(&s1, NULL); | |
std::cout << "clock = " << (t1 - t0)/1000/1000 << " [Mclk]" << std::endl; | |
std::cout << "time = " << (s1.tv_sec - s0.tv_sec) + (s1.tv_usec - s0.tv_usec)*1.0E-6 << std::endl; | |
} | |
// True if byte `c` is NOT a UTF-8 continuation byte (10xxxxxx), i.e. it is
// either ASCII or the first byte of a multi-byte sequence. Counting such
// bytes counts code points in well-formed UTF-8.
// A constexpr function (instead of the former macro) is type-checked and
// evaluates its argument exactly once; `char` arguments convert implicitly.
constexpr bool is_utf8_lead_byte(unsigned char c) noexcept
{
    return (c & 0xC0) != 0x80;
}

// Bit 7 of every byte of a 64-bit word.
constexpr uint64_t NONASCII_MASK = 0x8080808080808080ULL;

// Count the UTF-8 lead bytes in the machine word stored at `s`.
// Per byte, bit 0 of (d >> 6) is that byte's bit 6, and bit 0 of (~d >> 7)
// is the inverse of its bit 7; their OR is 1 exactly when the byte is not
// 10xxxxxx. Masking with 0x01 in every byte keeps only those flags, and
// popcount adds them up.
inline uintptr_t count_utf8_lead_bytes_with_word(const uintptr_t *s)
{
    // Callers point `s` into a char buffer; dereferencing it as uintptr_t
    // would violate strict aliasing, so copy the bytes out with memcpy
    // (compiles to a single load).
    uintptr_t d;
    memcpy(&d, s, sizeof d);
    d = (d >> 6) | (~d >> 7);
    d &= static_cast<uintptr_t>(NONASCII_MASK >> 7);
    return static_cast<uintptr_t>(__builtin_popcountll(d));
}
#if defined(__AVX2__)
// Sum the 32 bytes of `x`, each treated as an unsigned value (0..255).
// _mm256_sad_epu8 against zero yields, in each 64-bit lane, the sum of that
// lane's 8 bytes; adding the four lane totals gives the whole-vector sum.
inline int32_t avx2_horizontal_sum_epi8(__m256i x)
{
    const __m256i lane_totals = _mm256_sad_epu8(x, _mm256_setzero_si256());
    const uint64_t sum =
        static_cast<uint64_t>(_mm256_extract_epi64(lane_totals, 0)) +
        static_cast<uint64_t>(_mm256_extract_epi64(lane_totals, 1)) +
        static_cast<uint64_t>(_mm256_extract_epi64(lane_totals, 2)) +
        static_cast<uint64_t>(_mm256_extract_epi64(lane_totals, 3));
    return static_cast<int32_t>(sum);
}

// Count the bytes that are not UTF-8 continuation bytes (10xxxxxx) in the
// whole 32-byte chunks of [p, e). A trailing partial chunk (< 32 bytes) is
// not examined; the caller handles it.
//
// Precondition: `p` is 32-byte aligned (aligned loads are used below).
int64_t avx_count_utf8_codepoint(const char *p, const char *e)
{
    constexpr size_t kVecBytes = 32;
    // Each byte of the accumulator grows by at most 1 per vector, so fold it
    // into the total after at most 255 vectors to avoid 8-bit wrap-around.
    constexpr size_t kMaxVecsPerBatch = 255;

    const char *base = static_cast<const char *>(__builtin_assume_aligned(p, 32));
    const size_t size = e - base;

    // pshufb lookup keyed by a byte's high nibble (same table in both 128-bit
    // lanes): 1 for 0x0..0x7 and 0xC..0xF, 0 for continuation nibbles 0x8..0xB.
    const __m256i lead_table = _mm256_setr_epi8(
        1, 1, 1, 1, 1, 1, 1, 1, // .. 0x7
        0, 0, 0, 0,             // 0x8 .. 0xB
        1, 1, 1, 1,             // 0xC .. 0xF
        1, 1, 1, 1, 1, 1, 1, 1, // .. 0x7
        0, 0, 0, 0,             // 0x8 .. 0xB
        1, 1, 1, 1              // 0xC .. 0xF
    );
    // srli_epi16 shifts across byte boundaries; keep only each byte's own nibble.
    const __m256i nibble_mask = _mm256_set1_epi8(0x0F);

    int64_t codepoints = 0;
    size_t offset = 0;
    while (offset + kVecBytes <= size) {
        __m256i lane_counts = _mm256_setzero_si256();
        for (size_t n_vecs = 0;
             n_vecs < kMaxVecsPerBatch && offset + kVecBytes <= size;
             ++n_vecs, offset += kVecBytes) {
            const __m256i bytes =
                _mm256_load_si256(reinterpret_cast<const __m256i *>(base + offset));
            const __m256i high_nibbles =
                _mm256_and_si256(_mm256_srli_epi16(bytes, 4), nibble_mask);
            lane_counts = _mm256_add_epi8(lane_counts,
                                          _mm256_shuffle_epi8(lead_table, high_nibbles));
        }
        codepoints += avx2_horizontal_sum_epi8(lane_counts);
    }
    return codepoints;
}
#endif
int64_t ruby_count_utf8_codepoint(const char *p, const char *e) | |
{ | |
uintptr_t len = 0; | |
if ((int)sizeof(uintptr_t) * 2 < e - p) { | |
const uintptr_t *s, *t; | |
const uintptr_t lowbits = sizeof(uintptr_t) - 1; | |
s = (const uintptr_t*)(~lowbits & ((uintptr_t)p + lowbits)); | |
t = (const uintptr_t*)(~lowbits & (uintptr_t)e); | |
while (p < (const char *)s) { | |
if (is_utf8_lead_byte(*p)) len++; | |
p++; | |
} | |
while (s < t) { | |
len += count_utf8_lead_bytes_with_word(s); | |
s++; | |
} | |
p = (const char *)s; | |
} | |
while (p < e) { | |
if (is_utf8_lead_byte(*p)) len++; | |
p++; | |
} | |
return (long)len; | |
} | |
__attribute__((noinline)) | |
int64_t count_utf8_codepoint(const char *p, const char *e) | |
{ | |
int64_t count = 0; | |
#if defined(__AVX2__) | |
if (32 <= e - p) { | |
// increment `p` to 32B boundary | |
while (((uintptr_t)p % 32) != 0) { | |
if (is_utf8_lead_byte(*p)) count++; | |
p++; | |
} | |
// vectorized count | |
count += avx_count_utf8_codepoint(p, e); | |
p += static_cast<uintptr_t>(e - p) / 32 * 32; | |
} | |
#endif | |
while (p < e) { | |
if (is_utf8_lead_byte(*p)) count++; | |
p++; | |
} | |
return count; | |
} | |
int64_t scalar_count_utf8_codepoint(const char *p, const char *e) | |
{ | |
int64_t count = 0; | |
while (p < e) { | |
if (is_utf8_lead_byte(*p)) count++; | |
p++; | |
} | |
return count; | |
} | |
// xorshift64 pseudo-random generator (shift triple 13, 7, 17) with a fixed
// seed held in function-local static state. Each call advances the 64-bit
// state and returns its low 32 bits. Not thread-safe: the state is shared.
uint32_t xor64()
{
    static uint64_t state = 88172645463325252ULL;
    state ^= state << 13;
    state ^= state >> 7;
    state ^= state << 17;
    return static_cast<uint32_t>(state);
}
int main() | |
{ | |
constexpr auto N = 100 * 1024 * 1024; | |
char *p = (char *)calloc(N, sizeof(char)); | |
const char *e = p + N; | |
for (size_t i = 0; i < N; i++) { | |
p[i] = xor64(); | |
} | |
time([&p, &e]() noexcept { | |
volatile uint64_t tmp; // use `volatile` to supress optimization | |
for (int i = 0; i < 100; i++) { | |
//tmp = count_utf8_codepoint(p, e); | |
//tmp = ruby_count_utf8_codepoint(p, e); | |
tmp = scalar_count_utf8_codepoint(p, e); | |
} | |
printf("count = %lu\n", tmp); | |
}); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment