Skip to content

Instantly share code, notes, and snippets.

@notnullnotvoid
Last active February 8, 2020 16:31
Show Gist options
  • Save notnullnotvoid/a4ebf552badf99395bcbd08e0fd2ded6 to your computer and use it in GitHub Desktop.
Save notnullnotvoid/a4ebf552badf99395bcbd08e0fd2ded6 to your computer and use it in GitHub Desktop.
Optimizing simultaneous word and line counting for HMN
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
//Reads the whole file at `filepath` into a freshly malloc'd buffer and appends a terminating NUL byte.
//The caller owns the returned buffer and releases it with free().
//This function never returns NULL: on any failure (open, seek, tell, allocation, short read) it prints
//a diagnostic to stderr and terminates the process, so callers may use the result unconditionally.
//NOTE: the correct behavior of this function is unfortunately not guaranteed by the standard
//      (a binary stream is not required to meaningfully support fseek(..., SEEK_END))
static void read_entire_file_die(const char * filepath) {
    perror(filepath);
    exit(EXIT_FAILURE);
}

static char * read_entire_file(const char * filepath) {
    FILE * f = fopen(filepath, "rb");
    if (!f) read_entire_file_die(filepath);

    //measure the file by seeking to its end, then rewind for the actual read
    if (fseek(f, 0, SEEK_END) != 0) read_entire_file_die(filepath);
    long fsize = ftell(f);
    if (fsize < 0) read_entire_file_die(filepath); //ftell reports failure as -1
    if (fseek(f, 0, SEEK_SET) != 0) read_entire_file_die(filepath);

    //+1 leaves room for the NUL terminator
    char * string = (char *) malloc((size_t) fsize + 1);
    if (!string) read_entire_file_die(filepath);

    //read byte-wise so the return value is the exact number of bytes obtained
    size_t got = fread(string, 1, (size_t) fsize, f);
    if (got != (size_t) fsize) {
        fprintf(stderr, "%s: short read (%zu of %ld bytes)\n", filepath, got, fsize);
        exit(EXIT_FAILURE);
    }
    fclose(f);
    string[fsize] = 0;
    return string;
}
//Byte classification tables, indexed by the unsigned value of a byte.
//Entries not listed explicitly are zero-initialized.

//1 for the bytes treated as word separators: '\t' (9), '\n' (10), '\r' (13) and ' ' (32)
static const unsigned char spacemap[256] = {
    /* 0x00 - 0x0F */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0,
    /* 0x10 - 0x1F */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    /* 0x20        */ 1,
};

//1 only for '\n' (10)
static const unsigned char nlmap[256] = {
    /* 0x00 - 0x09 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    /* 0x0A        */ 1,
};

//NOTE: the argument must be in 0..255 (convert `char` to an unsigned byte type before using these)
#define ISSP(x) (spacemap[(x)])
#define ISNL(x) (nlmap[(x)])
//lookup[m] = number of word starts inside an 8-byte group whose "is space" flags are packed into m
//(bit j set <=> byte j of the group is a space). A word start is a non-space byte directly preceded
//by a space byte *within the group*; byte 0 has no predecessor here, so it is never counted.
static unsigned char lookup[256];

//Fills `lookup`. Must run before any counter that indexes the table.
static void precompute() {
    for (unsigned mask = 0; mask < 256; ++mask) {
        unsigned char starts = 0;
        unsigned prevSpace = 0; //byte 0 is treated as not preceded by a space
        for (unsigned bit = 0; bit < 8; ++bit) {
            unsigned curSpace = (mask >> bit) & 1u;
            //space -> non-space transition marks the beginning of a word
            starts += (unsigned char) (prevSpace & (curSpace ^ 1u));
            prevSpace = curSpace;
        }
        lookup[mask] = starts;
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// BASIC ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Reference implementation: walks `text` one byte at a time for `len` bytes, counting newline
//bytes and word starts (a non-space byte whose predecessor was a space; the start of the buffer
//counts as "after a space"). Prints both totals to stdout.
void count_basic(char * text, size_t len) {
    size_t lines = 0;
    size_t words = 0;
    size_t prevWasSpace = 1;
    const char * end = text + len;
    for (const char * p = text; p != end; ++p) {
        uint8_t ch = (uint8_t) *p; //unsigned so it is a valid table index
        size_t space = ISSP(ch);
        if (prevWasSpace && !space) {
            ++words;
        }
        lines += ISNL(ch);
        prevWasSpace = space;
    }
    printf("lines, words: %zu, %zu\n", lines, words);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// BRANCHLESS ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Same counting rules as the byte-at-a-time reference loop, but the word-start test is expressed
//as pure arithmetic on 0/1 flags instead of an `if`. Prints both totals to stdout.
void count_branchless(char * text, size_t len) {
    size_t lines = 0;
    size_t words = 0;
    size_t prevSpace = 1; //start of buffer behaves as if preceded by a space
    size_t i = 0;
    while (i < len) {
        const uint8_t ch = (uint8_t) text[i++];
        const size_t space = ISSP(ch); //table holds only 0 or 1
        words += prevSpace & (space ^ 1); //1 exactly when a space is followed by a non-space
        lines += ISNL(ch);
        prevSpace = space;
    }
    printf("lines, words: %zu, %zu\n", lines, words);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// ORIGINAL ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Counts lines and words 8 bytes at a time: each group is loaded as one 64-bit integer, split into
//bytes with shifts, and the word starts inside the group are found through the precomputed
//`lookup` table (indexed by the 8 packed "is space" flags). The word start at byte 0 of a group
//depends on the previous group and is handled separately via `wasSp`.
//Requires precompute() to have been called. Prints both totals to stdout.
//NOTE: the shift-based byte extraction assumes a little-endian machine (a0 must be the first byte).
void count_original(char * text, size_t len) {
    size_t numWords = 0;
    size_t numLines = 0;
    size_t wasSp = 1;
    size_t byte = 0;
    //`len - byte >= 8` instead of `byte < len - 7`: the latter wraps around for len < 7 (size_t is
    //unsigned) and would read far past the end of the buffer
    for (; len - byte >= 8; byte += 8) {
        //memcpy instead of a pointer cast: the buffer is not guaranteed to be 8-byte aligned and
        //reading chars through a uint64_t lvalue violates strict aliasing
        uint64_t x;
        memcpy(&x, text + byte, sizeof x);
        uint64_t a0 = (x >> 0) & 0xff;
        uint64_t a1 = (x >> 8) & 0xff;
        uint64_t a2 = (x >> 16) & 0xff;
        uint64_t a3 = (x >> 24) & 0xff;
        uint64_t a4 = (x >> 32) & 0xff;
        uint64_t a5 = (x >> 40) & 0xff;
        uint64_t a6 = (x >> 48) & 0xff;
        uint64_t a7 = (x >> 56) & 0xff;
        numLines += ISNL(a0);
        numLines += ISNL(a1);
        numLines += ISNL(a2);
        numLines += ISNL(a3);
        numLines += ISNL(a4);
        numLines += ISNL(a5);
        numLines += ISNL(a6);
        numLines += ISNL(a7);
        //word start straddling the group boundary (previous group's last byte -> this group's first)
        if (wasSp) numWords += !ISSP(a0);
        wasSp = ISSP(a7);
        //pack the 8 space flags into one byte: bit j <=> byte j is a space
        uint64_t spaces =
            (ISSP(a0) << 0) |
            (ISSP(a1) << 1) |
            (ISSP(a2) << 2) |
            (ISSP(a3) << 3) |
            (ISSP(a4) << 4) |
            (ISSP(a5) << 5) |
            (ISSP(a6) << 6) |
            (ISSP(a7) << 7);
        numWords += lookup[spaces];
    }
    //remaining 0..7 bytes, one at a time
    for (; byte < len; ++byte) {
        unsigned char c = text[byte];
        if (wasSp) numWords += !ISSP(c);
        numLines += ISNL(c);
        wasSp = ISSP(c);
    }
    printf("lines, words: %zu, %zu\n", numLines, numWords);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// IMPROVED ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Counts lines and words 8 bytes at a time, reading the bytes individually (no 64-bit load) and
//using branch-free arithmetic for the word start that straddles a group boundary. Word starts
//inside a group come from the precomputed `lookup` table, indexed by the packed "is space" flags.
//Requires precompute() to have been called. Prints both totals to stdout.
void count_improved(char * text, size_t len) {
    size_t numWords = 0;
    size_t numLines = 0;
    size_t wasSp = 1;
    size_t byte = 0;
    //`len - byte >= 8` instead of `byte < len - 7`: the latter wraps around for len < 7 (size_t is
    //unsigned) and would read far past the end of the buffer
    for (; len - byte >= 8; byte += 8) {
        const uint8_t * c = (const uint8_t *) (text + byte);
        numLines += ISNL(c[0]);
        numLines += ISNL(c[1]);
        numLines += ISNL(c[2]);
        numLines += ISNL(c[3]);
        numLines += ISNL(c[4]);
        numLines += ISNL(c[5]);
        numLines += ISNL(c[6]);
        numLines += ISNL(c[7]);
        //word start straddling the group boundary (previous group's last byte -> this group's first)
        numWords += wasSp & (!ISSP(c[0]));
        wasSp = ISSP(c[7]);
        //pack the 8 space flags into one byte: bit j <=> byte j is a space
        uint64_t spaces =
            (ISSP(c[0]) << 0) |
            (ISSP(c[1]) << 1) |
            (ISSP(c[2]) << 2) |
            (ISSP(c[3]) << 3) |
            (ISSP(c[4]) << 4) |
            (ISSP(c[5]) << 5) |
            (ISSP(c[6]) << 6) |
            (ISSP(c[7]) << 7);
        numWords += lookup[spaces];
    }
    //remaining 0..7 bytes, one at a time
    for (; byte < len; ++byte) {
        unsigned char c = text[byte];
        numWords += wasSp & (!ISSP(c));
        numLines += ISNL(c);
        wasSp = ISSP(c);
    }
    printf("lines, words: %zu, %zu\n", numLines, numWords);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// SIMD-1 ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <emmintrin.h>
//SSE2 version working on 16 bytes per iteration, treated as four 32-bit lanes. The four bytes of
//every lane are separated into a1..a4 (a1 = lowest-addressed byte of each lane) and compared as
//32-bit values against '\n', '\t', '\r' and ' '. Counts are accumulated per lane in 32-bit
//counters and summed after the loop. Prints both totals to stdout.
//NOTE: the per-lane counters are 32 bits wide, so they can wrap on extremely large inputs.
void count_simd1(char * text, size_t len) {
    size_t numWords = 0;
    size_t wasSp = 1;
    const __m128i ff = _mm_set1_epi32(0xFF);
    const __m128i one = _mm_set1_epi32(1);
    const __m128i nine = _mm_set1_epi32(9);
    const __m128i ten = _mm_set1_epi32(10);
    const __m128i thirteen = _mm_set1_epi32(13);
    const __m128i thirtyTwo = _mm_set1_epi32(32);
    __m128i lineCounts = _mm_set1_epi32(0);
    __m128i wordCounts = _mm_set1_epi32(0);
    size_t byte = 0;
    //`len - byte >= 16` instead of `byte < len - 15`: the latter wraps around for len < 15 (size_t
    //is unsigned) and would read far past the end of the buffer
    for (; len - byte >= 16; byte += 16) {
        uint8_t * c = (uint8_t *) (text + byte);
        //unaligned load: nothing guarantees that `text` is 16-byte aligned
        __m128i a = _mm_loadu_si128((const __m128i *) c);
        __m128i a1 = _mm_and_si128(a, ff);
        __m128i a2 = _mm_and_si128(_mm_srli_epi32(a, 8), ff);
        __m128i a3 = _mm_and_si128(_mm_srli_epi32(a, 16), ff);
        __m128i a4 = _mm_and_si128(_mm_srli_epi32(a, 24), ff);
        //sum newline counts
        __m128i l1 = _mm_and_si128(_mm_cmpeq_epi32(a1, ten), one);
        __m128i l2 = _mm_and_si128(_mm_cmpeq_epi32(a2, ten), one);
        __m128i l3 = _mm_and_si128(_mm_cmpeq_epi32(a3, ten), one);
        __m128i l4 = _mm_and_si128(_mm_cmpeq_epi32(a4, ten), one);
        __m128i l = _mm_add_epi32(_mm_add_epi32(l1, l2), _mm_add_epi32(l3, l4));
        lineCounts = _mm_add_epi32(lineCounts, l);
        //sum word counts
        __m128i s1 = _mm_or_si128(_mm_cmpeq_epi32(a1, nine), _mm_cmpeq_epi32(a1, ten));
        __m128i s2 = _mm_or_si128(_mm_cmpeq_epi32(a2, nine), _mm_cmpeq_epi32(a2, ten));
        __m128i s3 = _mm_or_si128(_mm_cmpeq_epi32(a3, nine), _mm_cmpeq_epi32(a3, ten));
        __m128i s4 = _mm_or_si128(_mm_cmpeq_epi32(a4, nine), _mm_cmpeq_epi32(a4, ten));
        __m128i p1 = _mm_or_si128(_mm_cmpeq_epi32(a1, thirteen), _mm_cmpeq_epi32(a1, thirtyTwo));
        __m128i p2 = _mm_or_si128(_mm_cmpeq_epi32(a2, thirteen), _mm_cmpeq_epi32(a2, thirtyTwo));
        __m128i p3 = _mm_or_si128(_mm_cmpeq_epi32(a3, thirteen), _mm_cmpeq_epi32(a3, thirtyTwo));
        __m128i p4 = _mm_or_si128(_mm_cmpeq_epi32(a4, thirteen), _mm_cmpeq_epi32(a4, thirtyTwo));
        __m128i sp1 = _mm_or_si128(s1, p1);
        __m128i sp2 = _mm_or_si128(s2, p2);
        __m128i sp3 = _mm_or_si128(s3, p3);
        __m128i sp4 = _mm_or_si128(s4, p4);
        //wN = "byte N is not a space and the byte before it is"; for the first byte of a lane the
        //predecessor is the last byte of the previous lane, hence the 4-byte shift of sp4
        //(lane 0 receives zeros there; its first byte is handled by the scalar code below)
        __m128i w1 = _mm_and_si128(_mm_andnot_si128(sp1, _mm_bslli_si128(sp4, 4)), one);
        __m128i w2 = _mm_and_si128(_mm_andnot_si128(sp2, sp1), one);
        __m128i w3 = _mm_and_si128(_mm_andnot_si128(sp3, sp2), one);
        __m128i w4 = _mm_and_si128(_mm_andnot_si128(sp4, sp3), one);
        __m128i w = _mm_add_epi32(_mm_add_epi32(w1, w2), _mm_add_epi32(w3, w4));
        wordCounts = _mm_add_epi32(wordCounts, w);
        //word start straddling two 16-byte chunks
        numWords += wasSp & (!ISSP(c[0]));
        wasSp = ISSP(c[15]);
    }
    //sum SIMD lanes
    //(unaligned stores: a uint32_t array is not guaranteed to be 16-byte aligned;
    // the sums are widened to size_t before adding so they cannot wrap at 32 bits)
    uint32_t lineLanes[4] = {0};
    _mm_storeu_si128((__m128i *) lineLanes, lineCounts);
    size_t numLines = (size_t) lineLanes[0] + lineLanes[1] + lineLanes[2] + lineLanes[3];
    uint32_t wordLanes[4] = {0};
    _mm_storeu_si128((__m128i *) wordLanes, wordCounts);
    numWords += (size_t) wordLanes[0] + wordLanes[1] + wordLanes[2] + wordLanes[3];
    //cleanup
    for (; byte < len; ++byte) {
        unsigned char c = text[byte];
        numWords += wasSp & (!ISSP(c));
        numLines += ISNL(c);
        wasSp = ISSP(c);
    }
    printf("lines, words: %zu, %zu\n", numLines, numWords);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// SIMD-2 ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//SSE2 version comparing all 16 bytes of a chunk at once (8-bit lanes). Matches are accumulated in
//8-bit per-lane counters, which are flushed into the scalar totals after every block of chunks.
//Prints both totals to stdout.
void count_simd2(char * text, size_t len) {
    //An 8-bit lane counter gains at most 1 per chunk, so a block may hold at most 255 chunks
    //before the counters are flushed; with 256 a lane could reach 256 and wrap around to 0
    //(e.g. a block consisting only of newlines would contribute no lines at all).
    const int chunksPerBlock = 255;
    const size_t blockBytes = (size_t) 255 * 16;
    size_t numWords = 0;
    size_t numLines = 0;
    size_t wasSp = 1;
    const __m128i one = _mm_set1_epi8(1);
    const __m128i nine = _mm_set1_epi8(9);
    const __m128i ten = _mm_set1_epi8(10);
    const __m128i thirteen = _mm_set1_epi8(13);
    const __m128i thirtyTwo = _mm_set1_epi8(32);
    size_t byte = 0;
    //`len - byte >= blockBytes` cannot wrap around (byte <= len always holds), unlike a
    //`byte < len - N` test, which misbehaves for len < N because size_t is unsigned
    while (len - byte >= blockBytes) {
        __m128i lineCounts = _mm_set1_epi8(0);
        __m128i wordCounts = _mm_set1_epi8(0);
        for (int i = 0; i < chunksPerBlock; ++i) {
            uint8_t * c = (uint8_t *) (text + byte);
            //unaligned load: nothing guarantees that `text` is 16-byte aligned
            __m128i a = _mm_loadu_si128((const __m128i *) c);
            //sum line endings
            lineCounts = _mm_add_epi8(lineCounts, _mm_and_si128(_mm_cmpeq_epi8(a, ten), one));
            //sum word boundaries
            __m128i sp = _mm_or_si128(_mm_or_si128(_mm_cmpeq_epi8(a, nine), _mm_cmpeq_epi8(a, ten)),
                                      _mm_or_si128(_mm_cmpeq_epi8(a, thirteen), _mm_cmpeq_epi8(a, thirtyTwo)));
            //byte k starts a word if it is not a space and byte k-1 is; the one-byte shift lines up
            //each byte with its predecessor (byte 0 gets a zero and is handled below)
            __m128i w = _mm_and_si128(_mm_andnot_si128(sp, _mm_bslli_si128(sp, 1)), one);
            wordCounts = _mm_add_epi8(wordCounts, w);
            //handle word boundary edge cases
            numWords += wasSp & (!ISSP(c[0]));
            wasSp = ISSP(c[15]);
            byte += 16;
        }
        //sum SIMD lanes
        //(unaligned stores: a uint8_t array is not guaranteed to be 16-byte aligned)
        uint8_t lineLanes[16] = {0};
        _mm_storeu_si128((__m128i *) lineLanes, lineCounts);
        for (int i = 0; i < 16; ++i) numLines += lineLanes[i];
        uint8_t wordLanes[16] = {0};
        _mm_storeu_si128((__m128i *) wordLanes, wordCounts);
        for (int i = 0; i < 16; ++i) numWords += wordLanes[i];
    }
    //remaining bytes (less than one block), one at a time
    for (; byte < len; ++byte) {
        unsigned char c = text[byte];
        numWords += wasSp & (!ISSP(c));
        numLines += ISNL(c);
        wasSp = ISSP(c);
    }
    printf("lines, words: %zu, %zu\n", numLines, numWords);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// SIMD-3 ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <immintrin.h>
//AVX2 version comparing all 32 bytes of a chunk at once (8-bit lanes). Matches are accumulated in
//8-bit per-lane counters, which are flushed into the scalar totals after every block of chunks.
//Prints both totals to stdout.
void count_simd3(char * text, size_t len) {
    //An 8-bit lane counter gains at most 1 per chunk, so a block may hold at most 255 chunks
    //before the counters are flushed; with 256 a lane could reach 256 and wrap around to 0
    //(e.g. a block consisting only of newlines would contribute no lines at all).
    const int chunksPerBlock = 255;
    const size_t blockBytes = (size_t) 255 * 32;
    size_t numWords = 0;
    size_t numLines = 0;
    size_t wasSp = 1;
    const __m256i one = _mm256_set1_epi8(1);
    const __m256i nine = _mm256_set1_epi8(9);
    const __m256i ten = _mm256_set1_epi8(10);
    const __m256i thirteen = _mm256_set1_epi8(13);
    const __m256i thirtyTwo = _mm256_set1_epi8(32);
    size_t byte = 0;
    //`len - byte >= blockBytes` cannot wrap around (byte <= len always holds), unlike a
    //`byte < len - N` test, which misbehaves for len < N because size_t is unsigned
    while (len - byte >= blockBytes) {
        __m256i lineCounts = _mm256_set1_epi8(0);
        __m256i wordCounts = _mm256_set1_epi8(0);
        for (int i = 0; i < chunksPerBlock; ++i) {
            uint8_t * c = (uint8_t *) (text + byte);
            //unaligned load: malloc'd memory is not guaranteed to be 32-byte aligned
            __m256i a = _mm256_loadu_si256((const __m256i *) c);
            //sum line endings
            lineCounts = _mm256_add_epi8(lineCounts, _mm256_and_si256(_mm256_cmpeq_epi8(a, ten), one));
            //sum word boundaries
            __m256i sp = _mm256_or_si256(_mm256_or_si256(_mm256_cmpeq_epi8(a, nine), _mm256_cmpeq_epi8(a, ten)),
                                         _mm256_or_si256(_mm256_cmpeq_epi8(a, thirteen), _mm256_cmpeq_epi8(a, thirtyTwo)));
            //byte k starts a word if it is not a space and byte k-1 is; the byte shift works on each
            //128-bit half separately, so bytes 0 and 16 get a zero and are handled below
            __m256i w = _mm256_and_si256(_mm256_andnot_si256(sp, _mm256_slli_si256(sp, 1)), one);
            wordCounts = _mm256_add_epi8(wordCounts, w);
            //handle word boundary edge cases
            numWords += wasSp & (!ISSP(c[0]));
            numWords += ISSP(c[15]) & (!ISSP(c[16]));
            wasSp = ISSP(c[31]);
            byte += 32;
        }
        //sum SIMD lanes
        //(unaligned stores: a uint8_t array is not guaranteed to be 32-byte aligned)
        uint8_t lineLanes[32] = {0};
        _mm256_storeu_si256((__m256i *) lineLanes, lineCounts);
        for (int i = 0; i < 32; ++i) numLines += lineLanes[i];
        uint8_t wordLanes[32] = {0};
        _mm256_storeu_si256((__m256i *) wordLanes, wordCounts);
        for (int i = 0; i < 32; ++i) numWords += wordLanes[i];
    }
    //remaining bytes (less than one block), one at a time
    for (; byte < len; ++byte) {
        unsigned char c = text[byte];
        numWords += wasSp & (!ISSP(c));
        numLines += ISNL(c);
        wasSp = ISSP(c);
    }
    printf("lines, words: %zu, %zu\n", numLines, numWords);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// MAIN (DRIVER) ///
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <time.h>
//Runs `func(text, len)` five times, measuring the process CPU time consumed by each run, and
//prints the fastest run in milliseconds to stdout. Terminates the process if the clock cannot
//be read.
void time_func(void (* func) (char *, size_t), char * text, size_t len) {
    uint64_t ns = UINT64_MAX;
    for (int i = 0; i < 5; ++i) {
        struct timespec start, end;
        //the clock calls must not live inside assert(): with NDEBUG they would be compiled out
        //and the timestamps below would be read uninitialized
        int startFailed = clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start);
        func(text, len);
        int endFailed = clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end);
        if (startFailed || endFailed) {
            perror("clock_gettime");
            exit(EXIT_FAILURE);
        }
        //convert to 64 bits before scaling so the multiplication cannot overflow a narrow time_t
        uint64_t ns1 = (uint64_t) start.tv_sec * UINT64_C(1000000000) + (uint64_t) start.tv_nsec;
        uint64_t ns2 = (uint64_t) end.tv_sec * UINT64_C(1000000000) + (uint64_t) end.tv_nsec;
        uint64_t diff = ns2 - ns1;
        if (diff < ns) ns = diff; //keep the best (least disturbed) run
    }
    printf("TIME: %.3fms\n", (double) ns / 1000000.0);
    fflush(stdout);
}
//Benchmark driver: loads the file named by the single command-line argument and times every
//counting implementation on it, from slowest to fastest.
int main(int argc, const char ** argv) {
    if (argc != 2) { fprintf(stderr, "Usage: ./wc FILE\n"); return 1; }
    char * text = read_entire_file(argv[1]);
    //NOTE: strlen stops at the first NUL byte, so only the part of the file before an embedded
    //      NUL (if any) is processed
    size_t len = strlen(text);
    precompute(); //the table-driven counters depend on `lookup` being filled
    time_func(count_basic, text, len); //.300 seconds
    time_func(count_branchless, text, len); //.269 seconds
    time_func(count_original, text, len); //.234 seconds
    time_func(count_improved, text, len); //.184 seconds
    time_func(count_simd1, text, len); //.130 seconds
    time_func(count_simd2, text, len); //.035 seconds
    time_func(count_simd3, text, len); //.026 seconds
    free(text);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment