Last active
February 8, 2020 16:31
-
-
Save notnullnotvoid/a4ebf552badf99395bcbd08e0fd2ded6 to your computer and use it in GitHub Desktop.
Optimizing simultaneous word and line counting for HMN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h> | |
#include <stdio.h> | |
#include <stdlib.h> | |
#include <stdint.h> | |
#include <string.h> | |
// Reads the whole file at `filepath` into a freshly malloc'd, NUL-terminated buffer (caller frees).
// On any I/O or allocation failure it prints a diagnostic and exits the process. It never returns
// NULL, so callers can use the result directly.
//NOTE: the standard does not guarantee that seeking to SEEK_END on a binary stream is meaningful,
// so the ftell() size is only used to size the buffer; the terminator is placed after however many
// bytes fread() actually delivered.
static char * read_entire_file(const char * filepath) {
    FILE * f = fopen(filepath, "rb");
    if (!f) {
        perror(filepath);
        exit(EXIT_FAILURE);
    }
    long fsize = -1;
    if (fseek(f, 0, SEEK_END) == 0) fsize = ftell(f); // ftell reports -1 on failure
    if (fsize < 0 || fseek(f, 0, SEEK_SET) != 0) {
        perror(filepath);
        fclose(f);
        exit(EXIT_FAILURE);
    }
    char * string = (char *) malloc((size_t) fsize + 1);
    if (!string) {
        fprintf(stderr, "%s: cannot allocate %ld bytes\n", filepath, fsize);
        fclose(f);
        exit(EXIT_FAILURE);
    }
    size_t nread = fread(string, 1, (size_t) fsize, f);
    if (ferror(f)) {
        perror(filepath);
        free(string);
        fclose(f);
        exit(EXIT_FAILURE);
    }
    fclose(f);
    string[nread] = 0;
    return string;
}
// Byte classification tables, indexed by an unsigned byte value (0..255); unlisted entries are 0.
// spacemap flags the bytes treated as word separators: 9 (tab), 10 (line feed),
// 13 (carriage return) and 32 (space).
static const unsigned char spacemap[256] = {
    /*  0..15 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0,
    /* 16..31 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    /* 32     */ 1,
};
// nlmap flags only byte 10 (line feed); summing it over a buffer counts its lines.
static const unsigned char nlmap[256] = {
    /*  0..10 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
};
// Table lookups yielding 0 or 1; the argument must be a valid index (0..255).
#define ISSP(x) (spacemap[x])
#define ISNL(x) (nlmap[x])
// lookup[m] = number of word starts inside an 8-byte group whose whitespace pattern is m,
// where bit j of m is set when byte j of the group is whitespace. A word start at position j
// (1 <= j <= 7) means bit j-1 is set and bit j is clear. Position 0 is never counted here because
// its predecessor lies in the previous group.
static unsigned char lookup[256];
static void precompute() {
    for (unsigned mask = 0; mask < 256; ++mask) {
        // bit j of `starts` is set exactly when bit j-1 of mask is set and bit j is clear
        unsigned starts = (mask << 1) & ~mask & 0xFFu;
        unsigned char total = 0;
        while (starts != 0) {
            starts &= starts - 1; // clear the lowest set bit
            ++total;
        }
        lookup[mask] = total;
    }
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// BASIC /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
void count_basic(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t numLines = 0; | |
size_t wasSp = 1; | |
for (size_t byte = 0; byte < len; ++byte) { | |
uint8_t c = text[byte]; | |
if (wasSp) numWords += !ISSP(c); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// BRANCHLESS /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
void count_branchless(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t numLines = 0; | |
size_t wasSp = 1; | |
for (size_t byte = 0; byte < len; ++byte) { | |
uint8_t c = text[byte]; | |
numWords += wasSp & (!ISSP(c)); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// ORIGINAL /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
void count_original(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t numLines = 0; | |
size_t wasSp = 1; | |
size_t byte = 0; | |
for (; byte < len - 7; byte += 8) { | |
uint64_t x = *((uint64_t *) (text + byte)); | |
uint64_t a0 = (x >> 0) & 0xff; | |
uint64_t a1 = (x >> 8) & 0xff; | |
uint64_t a2 = (x >> 16) & 0xff; | |
uint64_t a3 = (x >> 24) & 0xff; | |
uint64_t a4 = (x >> 32) & 0xff; | |
uint64_t a5 = (x >> 40) & 0xff; | |
uint64_t a6 = (x >> 48) & 0xff; | |
uint64_t a7 = (x >> 56) & 0xff; | |
numLines += ISNL(a0); | |
numLines += ISNL(a1); | |
numLines += ISNL(a2); | |
numLines += ISNL(a3); | |
numLines += ISNL(a4); | |
numLines += ISNL(a5); | |
numLines += ISNL(a6); | |
numLines += ISNL(a7); | |
if (wasSp) numWords += !ISSP(a0); | |
wasSp = ISSP(a7); | |
uint64_t spaces = | |
(ISSP(a0) << 0) | | |
(ISSP(a1) << 1) | | |
(ISSP(a2) << 2) | | |
(ISSP(a3) << 3) | | |
(ISSP(a4) << 4) | | |
(ISSP(a5) << 5) | | |
(ISSP(a6) << 6) | | |
(ISSP(a7) << 7); | |
numWords += lookup[spaces]; | |
} | |
for (; byte < len; ++byte) { | |
unsigned char c = text[byte]; | |
if (wasSp) numWords += !ISSP(c); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// IMPROVED /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
void count_improved(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t numLines = 0; | |
size_t wasSp = 1; | |
size_t byte = 0; | |
for (; byte < len - 7; byte += 8) { | |
uint8_t * c = (uint8_t *) (text + byte); | |
numLines += ISNL(c[0]); | |
numLines += ISNL(c[1]); | |
numLines += ISNL(c[2]); | |
numLines += ISNL(c[3]); | |
numLines += ISNL(c[4]); | |
numLines += ISNL(c[5]); | |
numLines += ISNL(c[6]); | |
numLines += ISNL(c[7]); | |
numWords += wasSp & (!ISSP(c[0])); | |
wasSp = ISSP(c[7]); | |
uint64_t spaces = | |
(ISSP(c[0]) << 0) | | |
(ISSP(c[1]) << 1) | | |
(ISSP(c[2]) << 2) | | |
(ISSP(c[3]) << 3) | | |
(ISSP(c[4]) << 4) | | |
(ISSP(c[5]) << 5) | | |
(ISSP(c[6]) << 6) | | |
(ISSP(c[7]) << 7); | |
numWords += lookup[spaces]; | |
} | |
for (; byte < len; ++byte) { | |
unsigned char c = text[byte]; | |
numWords += wasSp & (!ISSP(c)); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// SIMD-1 /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
#include <emmintrin.h> | |
void count_simd1(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t wasSp = 1; | |
const __m128i ff = _mm_set1_epi32(0xFF); | |
const __m128i one = _mm_set1_epi32(1); | |
const __m128i nine = _mm_set1_epi32(9); | |
const __m128i ten = _mm_set1_epi32(10); | |
const __m128i thirteen = _mm_set1_epi32(13); | |
const __m128i thirtyTwo = _mm_set1_epi32(32); | |
__m128i lineCounts = _mm_set1_epi32(0); | |
__m128i wordCounts = _mm_set1_epi32(0); | |
size_t byte = 0; | |
for (; byte < len - 15; byte += 16) { | |
uint8_t * c = (uint8_t *) (text + byte); | |
__m128i a = _mm_load_si128((__m128i *) c); | |
__m128i a1 = _mm_and_si128(a, ff); | |
__m128i a2 = _mm_and_si128(_mm_srli_epi32(a, 8), ff); | |
__m128i a3 = _mm_and_si128(_mm_srli_epi32(a, 16), ff); | |
__m128i a4 = _mm_and_si128(_mm_srli_epi32(a, 24), ff); | |
//sum newline counts | |
__m128i l1 = _mm_and_si128(_mm_cmpeq_epi32(a1, ten), one); | |
__m128i l2 = _mm_and_si128(_mm_cmpeq_epi32(a2, ten), one); | |
__m128i l3 = _mm_and_si128(_mm_cmpeq_epi32(a3, ten), one); | |
__m128i l4 = _mm_and_si128(_mm_cmpeq_epi32(a4, ten), one); | |
__m128i l = _mm_add_epi32(_mm_add_epi32(l1, l2), _mm_add_epi32(l3, l4)); | |
lineCounts = _mm_add_epi32(lineCounts, l); | |
//sum word counts | |
__m128i s1 = _mm_or_si128(_mm_cmpeq_epi32(a1, nine), _mm_cmpeq_epi32(a1, ten)); | |
__m128i s2 = _mm_or_si128(_mm_cmpeq_epi32(a2, nine), _mm_cmpeq_epi32(a2, ten)); | |
__m128i s3 = _mm_or_si128(_mm_cmpeq_epi32(a3, nine), _mm_cmpeq_epi32(a3, ten)); | |
__m128i s4 = _mm_or_si128(_mm_cmpeq_epi32(a4, nine), _mm_cmpeq_epi32(a4, ten)); | |
__m128i p1 = _mm_or_si128(_mm_cmpeq_epi32(a1, thirteen), _mm_cmpeq_epi32(a1, thirtyTwo)); | |
__m128i p2 = _mm_or_si128(_mm_cmpeq_epi32(a2, thirteen), _mm_cmpeq_epi32(a2, thirtyTwo)); | |
__m128i p3 = _mm_or_si128(_mm_cmpeq_epi32(a3, thirteen), _mm_cmpeq_epi32(a3, thirtyTwo)); | |
__m128i p4 = _mm_or_si128(_mm_cmpeq_epi32(a4, thirteen), _mm_cmpeq_epi32(a4, thirtyTwo)); | |
__m128i sp1 = _mm_or_si128(s1, p1); | |
__m128i sp2 = _mm_or_si128(s2, p2); | |
__m128i sp3 = _mm_or_si128(s3, p3); | |
__m128i sp4 = _mm_or_si128(s4, p4); | |
__m128i w1 = _mm_and_si128(_mm_andnot_si128(sp1, _mm_bslli_si128(sp4, 4)), one); | |
__m128i w2 = _mm_and_si128(_mm_andnot_si128(sp2, sp1), one); | |
__m128i w3 = _mm_and_si128(_mm_andnot_si128(sp3, sp2), one); | |
__m128i w4 = _mm_and_si128(_mm_andnot_si128(sp4, sp3), one); | |
__m128i w = _mm_add_epi32(_mm_add_epi32(w1, w2), _mm_add_epi32(w3, w4)); | |
wordCounts = _mm_add_epi32(wordCounts, w); | |
numWords += wasSp & (!ISSP(c[0])); | |
wasSp = ISSP(c[15]); | |
} | |
//sum SIMD lanes | |
uint32_t lineLanes[4] = {}; | |
_mm_store_si128((__m128i *) lineLanes, lineCounts); | |
size_t numLines = lineLanes[0] + lineLanes[1] + lineLanes[2] + lineLanes[3]; | |
uint32_t wordLanes[4] = {}; | |
_mm_store_si128((__m128i *) wordLanes, wordCounts); | |
numWords += wordLanes[0] + wordLanes[1] + wordLanes[2] + wordLanes[3]; | |
//cleanup | |
for (; byte < len; ++byte) { | |
unsigned char c = text[byte]; | |
numWords += wasSp & (!ISSP(c)); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// SIMD-2 /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
void count_simd2(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t numLines = 0; | |
size_t wasSp = 1; | |
const __m128i one = _mm_set1_epi8(1); | |
const __m128i nine = _mm_set1_epi8(9); | |
const __m128i ten = _mm_set1_epi8(10); | |
const __m128i thirteen = _mm_set1_epi8(13); | |
const __m128i thirtyTwo = _mm_set1_epi8(32); | |
size_t byte = 0; | |
while (byte < len - 4095) { | |
__m128i lineCounts = _mm_set1_epi8(0); | |
__m128i wordCounts = _mm_set1_epi8(0); | |
for (int i = 0; i < 256; ++i) { | |
uint8_t * c = (uint8_t *) (text + byte); | |
__m128i a = _mm_load_si128((__m128i *) c); | |
//sum line endings | |
lineCounts = _mm_add_epi8(lineCounts, _mm_and_si128(_mm_cmpeq_epi8(a, ten), one)); | |
//sum word boundaries | |
__m128i sp = _mm_or_si128(_mm_or_si128(_mm_cmpeq_epi8(a, nine), _mm_cmpeq_epi8(a, ten)), | |
_mm_or_si128(_mm_cmpeq_epi8(a, thirteen), _mm_cmpeq_epi8(a, thirtyTwo))); | |
__m128i w = _mm_and_si128(_mm_andnot_si128(sp, _mm_bslli_si128(sp, 1)), one); | |
wordCounts = _mm_add_epi8(wordCounts, w); | |
//handle word boundary edge cases | |
numWords += wasSp & (!ISSP(c[0])); | |
wasSp = ISSP(c[15]); | |
byte += 16; | |
} | |
//sum SIMD lanes | |
uint8_t lineLanes[16] = {}; | |
_mm_store_si128((__m128i *) lineLanes, lineCounts); | |
for (int i = 0; i < 16; ++i) numLines += lineLanes[i]; | |
uint8_t wordLanes[16] = {}; | |
_mm_store_si128((__m128i *) wordLanes, wordCounts); | |
for (int i = 0; i < 16; ++i) numWords += wordLanes[i]; | |
} | |
for (; byte < len; ++byte) { | |
unsigned char c = text[byte]; | |
numWords += wasSp & (!ISSP(c)); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// SIMD-3 /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
#include <immintrin.h> | |
void count_simd3(char * text, size_t len) { | |
size_t numWords = 0; | |
size_t numLines = 0; | |
size_t wasSp = 1; | |
const __m256i one = _mm256_set1_epi8(1); | |
const __m256i nine = _mm256_set1_epi8(9); | |
const __m256i ten = _mm256_set1_epi8(10); | |
const __m256i thirteen = _mm256_set1_epi8(13); | |
const __m256i thirtyTwo = _mm256_set1_epi8(32); | |
size_t byte = 0; | |
while (byte < len - 8191) { | |
__m256i lineCounts = _mm256_set1_epi8(0); | |
__m256i wordCounts = _mm256_set1_epi8(0); | |
for (int i = 0; i < 256; ++i) { | |
uint8_t * c = (uint8_t *) (text + byte); | |
__m256i a = _mm256_load_si256((__m256i *) c); | |
//sum line endings | |
lineCounts = _mm256_add_epi8(lineCounts, _mm256_and_si256(_mm256_cmpeq_epi8(a, ten), one)); | |
//sum word boundaries | |
__m256i sp = _mm256_or_si256(_mm256_or_si256(_mm256_cmpeq_epi8(a, nine), _mm256_cmpeq_epi8(a, ten)), | |
_mm256_or_si256(_mm256_cmpeq_epi8(a, thirteen), _mm256_cmpeq_epi8(a, thirtyTwo))); | |
__m256i w = _mm256_and_si256(_mm256_andnot_si256(sp, _mm256_slli_si256(sp, 1)), one); | |
wordCounts = _mm256_add_epi8(wordCounts, w); | |
//handle word boundary edge cases | |
numWords += wasSp & (!ISSP(c[0])); | |
numWords += ISSP(c[15]) & (!ISSP(c[16])); | |
wasSp = ISSP(c[31]); | |
byte += 32; | |
} | |
//sum SIMD lanes | |
uint8_t lineLanes[32] = {}; | |
_mm256_store_si256((__m256i *) lineLanes, lineCounts); | |
for (int i = 0; i < 32; ++i) numLines += lineLanes[i]; | |
uint8_t wordLanes[32] = {}; | |
_mm256_store_si256((__m256i *) wordLanes, wordCounts); | |
for (int i = 0; i < 32; ++i) numWords += wordLanes[i]; | |
} | |
for (; byte < len; ++byte) { | |
unsigned char c = text[byte]; | |
numWords += wasSp & (!ISSP(c)); | |
numLines += ISNL(c); | |
wasSp = ISSP(c); | |
} | |
printf("lines, words: %zu, %zu\n", numLines, numWords); | |
} | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// MAIN (DRIVER) /// | |
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
#include <time.h> | |
// Calls func(text, len) five times, measuring process CPU time (CLOCK_PROCESS_CPUTIME_ID) around
// each call. It prints the smallest of the five durations in milliseconds; taking the minimum is
// meant to reduce run-to-run noise. It exits the process if the clock cannot be read.
void time_func(void (* func) (char *, size_t), char * text, size_t len) {
    uint64_t ns = UINT64_MAX;
    for (int i = 0; i < 5; ++i) {
        struct timespec start, end;
        // clock_gettime() must not be called inside assert(): with NDEBUG the call would be
        // compiled out and start/end read uninitialized.
        if (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start) != 0) {
            perror("clock_gettime");
            exit(EXIT_FAILURE);
        }
        func(text, len);
        if (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end) != 0) {
            perror("clock_gettime");
            exit(EXIT_FAILURE);
        }
        // widen before multiplying so a 32-bit time_t cannot overflow
        uint64_t ns1 = (uint64_t) start.tv_sec * 1000000000u + (uint64_t) start.tv_nsec;
        uint64_t ns2 = (uint64_t) end.tv_sec * 1000000000u + (uint64_t) end.tv_nsec;
        uint64_t diff = ns2 - ns1;
        if (diff < ns) ns = diff;
    }
    printf("TIME: %.3fms\n", ns / 1000000.0);
    fflush(stdout);
}
int main(int argc, const char ** argv) { | |
if (argc != 2) { printf("Usage: ./wc FILE\n"); return 1; } | |
char * text = read_entire_file(argv[1]); | |
size_t len = strlen(text); | |
precompute(); | |
time_func(count_basic, text, len); //.300 seconds | |
time_func(count_branchless, text, len); //.269 seconds | |
time_func(count_original, text, len); //.234 seconds | |
time_func(count_improved, text, len); //.184 seconds | |
time_func(count_simd1, text, len); //.130 seconds | |
time_func(count_simd2, text, len); //.035 seconds | |
time_func(count_simd3, text, len); //.026 seconds | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment