Skip to content

Instantly share code, notes, and snippets.

@hugbug
Created September 23, 2017 19:09
Show Gist options
  • Save hugbug/fd7d95d53e3a2ca4aafb0f811d929bfc to your computer and use it in GitHub Desktop.
Save hugbug/fd7d95d53e3a2ca4aafb0f811d929bfc to your computer and use it in GitHub Desktop.
yEnc SSE Decoder Test
#include <vector>
#include <iostream>
static size_t do_decode_scalar(const unsigned char* src, unsigned char* dest, size_t len, char* state, bool isRaw);
static size_t do_decode_sse(const unsigned char* src, unsigned char* dest, size_t len, char* state, bool isRaw);
using decode_func = size_t(*)(const unsigned char* src, unsigned char* dest, size_t len, char* state, bool isRaw);
typedef struct
{
const char* source;
const char* expected;
} test_pair;
std::vector<test_pair> tests = {
{"7c 8b 9c 4b 44 31 2a 84 98 9d 3b 2b 37 2a 2a 2a 2a 2a 2a 2a dd b5 9e 4c bb 78 2a b4 29 69 30 2a 2a 8a ed 2c 8a 04 14 98 1e 8c 2e 73 3e 5a 4b 2a 4a 4a 2a 2a 2a 2a 2a 2a 2b 2a 2a 2a 79 9a 8f 98 6b 7e 80 57 8c 9f 93 96 8e 57 8c 99 a2 57 76 93 98 9f a2 57 8e 93 9d 95 5b 58 a0 8e 93 2a 1a 7d e6 a2 66 66 66 4a 79 9c 8b 8d 96 8f 4a 80 77 4a 80 93 9c 9e 9f 8b 96 6c 99 a2 4a 6e 93 9d 95 4a ",
"52 61 72 21 1a 07 00 5a 6e 73 11 01 0d 00 00 00 00 00 00 00 b3 8b 74 22 91 4e 00 8a ff 3f 06 00 00 60 c3 02 60 da ea 6e f4 62 04 49 14 30 21 00 20 20 00 00 00 00 00 00 01 00 00 00 4f 70 65 6e 41 54 56 2d 62 75 69 6c 64 2d 62 6f 78 2d 4c 69 6e 75 78 2d 64 69 73 6b 31 2e 76 64 69 00 f0 53 bc 78 3c 3c 3c 20 4f 72 61 63 6c 65 20 56 4d 20 56 69 72 74 75 61 6c 42 6f 78 20 44 69 73 6b 20 "},
{"0f 1a b1 a4 0c d4 15 2a 61 47 c8 bf bf e4 d3 e9 e2 b2 1f e0 99 1d 79 9a 38 26 c0 8b 3d 40 42 cf c6 f9 85 34 8d 9c f2 55 ce 16 ec 4d 38 29 3d 4d 22 d8 bb cc ce 2a 91 c9 93 87 6f 0f fb 5b 2c d7 90 3c 22 4a ac a7 1a 57 1a bb 6b 64 23 e0 87 8f b2 3d 7d 94 30 c5 eb 2f cb e5 78 35 8e bc d0 0b 57 15 58 69 e3 9d fc f3 da 6b c1 07 3d 4d d2 6a 60 6f 43 a4 3d 4a 81 dc b7 ca 04 8a c1 f6 8d b8 ",
"e5 f0 87 7a e2 aa eb 00 37 1d 9e 95 95 ba a9 bf b8 88 f5 b6 6f f3 4f 70 0e fc 96 61 d6 18 a5 9c cf 5b 0a 63 72 c8 2b a4 ec c2 23 0e ff e3 f8 ae 91 a2 a4 00 67 9f 69 5d 45 e5 d1 31 02 ad 66 12 f8 20 82 7d f0 2d f0 91 41 3a f9 b6 5d 65 88 13 6a 06 9b c1 05 a1 bb 4e 0b 64 92 a6 e1 2d eb 2e 3f b9 73 d2 c9 b0 41 97 dd e3 a8 40 36 45 19 7a e0 57 b2 8d a0 da 60 97 cc 63 8e "}
};
std::vector<unsigned char> unhex(const char* hexstr)
{
std::vector<unsigned char> result;
for (const char* p = hexstr; *p; p += 3)
{
char *endptr;
long val = strtol(p, &endptr, 16);
result.push_back((unsigned char)val);
}
return result;
}
void print_vector(const char* title, std::vector<unsigned char> vec)
{
printf("%s:\n", title);
for (unsigned char ch : vec)
{
printf("%02x ", ch);
}
printf("\n");
}
void test(decode_func func)
{
int num = 0;
for (test_pair pair : tests)
{
std::vector<unsigned char> source = unhex(pair.source);
std::vector<unsigned char> expected = unhex(pair.expected);
std::vector<unsigned char> result;
result.resize(source.size());
size_t res_len = func(source.data(), result.data(), source.size(), nullptr, false);
result.resize(res_len);
bool ok = result == expected;
printf("\nTest %i: %s\n", ++num, ok ? "OK" : "FAILURE");
if (!ok)
{
print_vector("Source", source);
print_vector("Expected", expected);
print_vector("Result", result);
}
}
}
int main()
{
printf("Testing scalar\n");
test((decode_func)do_decode_scalar);
printf("\n");
printf("Testing sse\n");
test((decode_func)do_decode_sse);
}
/***************************************************/
// from "yencode.cc":
// https://github.com/animetosho/node-yencode/blob/3b97c8b04d73c245a1b9eb488d412fe804b57656/yencode.cc
#include <stdlib.h>
// MSVC compatibility
#if (defined(_M_IX86_FP) && _M_IX86_FP == 2) || defined(_M_X64)
#define __SSE2__ 1
#define __SSSE3__ 1
//#define __SSE4_1__ 1
#if defined(_MSC_VER) && _MSC_VER >= 1600
#define X86_PCLMULQDQ_CRC 1
#endif
#endif
#ifdef _MSC_VER
#define __BYTE_ORDER__ 1234
#define __ORDER_BIG_ENDIAN__ 4321
#include <intrin.h>
#endif
#if defined(__x86_64__) || defined(__i386__)
#if !defined(X86_PCLMULQDQ_CRC) && defined(__PCLMUL__) && defined(__SSSE3__) && defined(__SSE4_1__)
#define X86_PCLMULQDQ_CRC 1
#endif
#endif
// combine two 8-bit ints into a 16-bit one
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define UINT16_PACK(a, b) (((a) << 8) | (b))
#define UINT32_PACK(a, b, c, d) (((a) << 24) | ((b) << 16) | ((c) << 8) | (d))
#define UINT32_16_PACK(a, b) (((a) << 16) | (b))
#else
#define UINT16_PACK(a, b) ((a) | ((b) << 8))
#define UINT32_PACK(a, b, c, d) ((a) | ((b) << 8) | ((c) << 16) | ((d) << 24))
#define UINT32_16_PACK(a, b) ((a) | ((b) << 16))
#endif
#ifdef __SSE2__
#include <emmintrin.h>
#define XMM_SIZE 16 /*== (signed int)sizeof(__m128i)*/
#ifdef _MSC_VER
#define ALIGN_32(v) __declspec(align(32)) v
#else
#define ALIGN_32(v) v __attribute__((aligned(32)))
#endif
#ifdef __SSSE3__
#include <tmmintrin.h>
#endif
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif
#ifdef __POPCNT__
#include <nmmintrin.h>
#endif
/*
#ifdef __AVX2__
#include <immintrin.h>
#endif
*/
#if defined(__tune_core2__) || defined(__tune_atom__)
/* on older Intel CPUs, plus first gen Atom, it is faster to store XMM registers in half */
# define STOREU_XMM(dest, xmm) \
_mm_storel_epi64((__m128i*)(dest), xmm); \
_mm_storeh_pi(((__m64*)(dest) +1), _mm_castsi128_ps(xmm))
#else
# define STOREU_XMM(dest, xmm) \
_mm_storeu_si128((__m128i*)(dest), xmm)
#endif
#endif
// slightly faster version which improves the worst case scenario significantly; since worst case doesn't happen often, overall speedup is relatively minor
// requires PSHUFB (SSSE3) instruction, but will use POPCNT (SSE4.2 (or AMD's ABM, but Phenom doesn't support SSSE3 so doesn't matter)) if available (these only seem to give minor speedups, so considered optional)
#ifdef __SSSE3__
ALIGN_32(__m128i _shufLUT[258]); // +2 for underflow guard entry
__m128i* shufLUT = _shufLUT+2;
ALIGN_32(__m128i shufMixLUT[256]);
#ifndef __POPCNT__
// table from http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetTable
static const unsigned char BitsSetTable256[256] =
{
# define B2(n) n, n+1, n+1, n+2
# define B4(n) B2(n), B2(n+1), B2(n+1), B2(n+2)
# define B6(n) B4(n), B4(n+1), B4(n+1), B4(n+2)
B6(0), B6(1), B6(1), B6(2)
#undef B2
#undef B4
#undef B6
};
#endif
#endif
// state var: refers to the previous state - only used for incremental processing
// 0: previous characters are `\r\n` OR there is no previous character
// 1: previous character is `=`
// 2: previous character is `\r`
// 3: previous character is none of the above
static size_t do_decode_scalar_raw(const unsigned char* src, unsigned char* dest, size_t len, char* state) {
unsigned char *p = dest; // destination pointer
unsigned long i = 0; // input position
unsigned char c; // input character
if(len < 1) return 0;
if(state) switch(*state) {
case 1:
c = src[i];
*p++ = c - 42 - 64;
i++;
if(c == '\r' && i < len) {
*state = 2;
// fall through to case 2
} else {
*state = 3;
break;
}
case 2:
if(src[i] != '\n') break;
i++;
*state = 0; // now `\r\n`
if(len <= i) return 0;
case 0:
// skip past first dot
if(src[i] == '.') i++;
} else // treat as *state == 0
if(src[i] == '.') i++;
for(; i + 2 < len; i++) {
c = src[i];
switch(c) {
case '\r':
// skip past \r\n. sequences
//i += (*(uint16_t*)(src + i + 1) == UINT16_PACK('\n', '.')) << 1;
if(*(uint16_t*)(src + i + 1) == UINT16_PACK('\n', '.'))
i += 2;
case '\n':
continue;
case '=':
c = src[i+1];
*p++ = c - 42 - 64;
i += (c != '\r'); // if we have a \r, reprocess character to deal with \r\n. case
continue;
default:
*p++ = c - 42;
}
}
if(state) *state = 3;
if(i+1 < len) { // 2nd last char
c = src[i];
switch(c) {
case '\r':
if(state && src[i+1] == '\n') {
*state = 0;
return p - dest;
}
case '\n':
break;
case '=':
c = src[i+1];
*p++ = c - 42 - 64;
i += (c != '\r');
break;
default:
*p++ = c - 42;
}
i++;
}
// do final char; we process this separately to prevent an overflow if the final char is '='
if(i < len) {
c = src[i];
if(c != '\n' && c != '\r' && c != '=') {
*p++ = c - 42;
} else if(state) {
if(c == '=') *state = 1;
else if(c == '\r') *state = 2;
else *state = 3;
}
}
return p - dest;
}
static size_t do_decode_scalar(const unsigned char* src, unsigned char* dest, size_t len, char* state, bool isRaw) {
if(isRaw) return do_decode_scalar_raw(src, dest, len, state);
unsigned char *p = dest; // destination pointer
unsigned long i = 0; // input position
unsigned char c; // input character
if(len < 1) return 0;
if(state && *state == 1) {
*p++ = src[i] - 42 - 64;
i++;
*state = 3;
}
/*for(i = 0; i < len - 1; i++) {
c = src[i];
if(c == '\n' || c == '\r') continue;
unsigned char isEquals = (c == '=');
i += isEquals;
*p++ = src[i] - (42 + (isEquals << 6));
}*/
for(; i+1 < len; i++) {
c = src[i];
switch(c) {
case '\n': case '\r': continue;
case '=':
i++;
c = src[i] - 64;
}
*p++ = c - 42;
}
if(state) *state = 3;
// do final char; we process this separately to prevent an overflow if the final char is '='
if(i < len) {
c = src[i];
if(c != '\n' && c != '\r' && c != '=') {
*p++ = c - 42;
} else
if(state) *state = (c == '=' ? 1 : 3);
}
return p - dest;
}
#ifdef __SSE2__
#ifdef __SSSE3__
ALIGN_32(__m64 unshufLUT[256]);
uint8_t eqFixLUT[256];
ALIGN_32(__m64 eqSubLUT[256]);
ALIGN_32(static const uint8_t _pshufb_combine_table[272]) = {
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,
0x00,0x01,0x02,0x03,0x04,0x05,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,
0x00,0x01,0x02,0x03,0x04,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,0x80,
0x00,0x01,0x02,0x03,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,0x80,0x80,
0x00,0x01,0x02,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,0x80,0x80,0x80,
0x00,0x01,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,0x80,0x80,0x80,0x80,
0x00,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,0x80,0x80,0x80,0x80,0x80,
0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x80,0x80,0x80,0x80,0x80,0x80,0x80,0x80,
};
static const __m128i* pshufb_combine_table = (const __m128i*)_pshufb_combine_table;
#endif
static size_t do_decode_sse(const unsigned char* src, unsigned char* dest, size_t len, char* state, bool isRaw) {
if(len <= sizeof(__m128i)*2) return do_decode_scalar(src, dest, len, state, isRaw);
unsigned char *p = dest; // destination pointer
unsigned long i = 0; // input position
unsigned char escFirst = 0; // input character; first char needs escaping
unsigned int nextMask = 0;
char tState = 0;
char* pState = state ? state : &tState;
if((uintptr_t)src & ((sizeof(__m128i)-1))) {
// find source memory alignment
unsigned char* aSrc = (unsigned char*)(((uintptr_t)src + (sizeof(__m128i)-1)) & ~(sizeof(__m128i)-1));
i = aSrc - src;
p += do_decode_scalar(src, dest, i, pState, isRaw);
}
// handle finicky case of \r\n. straddled across initial boundary
if(*pState == 0 && i+1 < len && src[i] == '.')
nextMask = 1;
else if(*pState == 2 && i+2 < len && *(uint16_t*)(src + i) == UINT16_PACK('\n','.'))
nextMask = 2;
escFirst = *pState == 1;
// our algorithm may perform an aligned load on the next part, of which we consider 2 bytes (for \r\n. sequence checking)
for(; i + (sizeof(__m128i)+1) < len; i += sizeof(__m128i)) {
__m128i data = _mm_load_si128((__m128i *)(src + i));
// search for special chars
__m128i cmpSkip = _mm_or_si128(
_mm_cmpeq_epi8(data, _mm_set1_epi8('\r')),
_mm_cmpeq_epi8(data, _mm_set1_epi8('\n'))
),
cmpEq = _mm_cmpeq_epi8(data, _mm_set1_epi8('=')),
cmp = _mm_or_si128(
cmpSkip,
cmpEq
);
unsigned int mask = _mm_movemask_epi8(cmp); // not the most accurate mask if we have invalid sequences; we fix this up later
__m128i oData;
if(escFirst) { // TODO: should be possible to eliminate branch by storing vectors adjacently
// first byte needs escaping due to preceeding = in last loop iteration
oData = _mm_sub_epi8(data, _mm_set_epi8(42,42,42,42,42,42,42,42,42,42,42,42,42,42,42,42+64));
} else {
oData = _mm_sub_epi8(data, _mm_set1_epi8(42));
}
mask &= ~escFirst;
mask |= nextMask;
if (mask != 0) {
// a spec compliant encoder should never generate sequences: ==, =\n and =\r, but we'll handle them to be spec compliant
// the yEnc specification requires any character following = to be unescaped, not skipped over, so we'll deal with that
#define LOAD_HALVES(a, b) _mm_castps_si128(_mm_loadh_pi( \
_mm_castsi128_ps(_mm_loadl_epi64((__m128i*)(a))), \
(b) \
))
// firstly, resolve invalid sequences of = to deal with cases like '===='
unsigned int maskEq = _mm_movemask_epi8(cmpEq);
unsigned int tmp = eqFixLUT[(maskEq&0xff) & ~escFirst];
maskEq = (eqFixLUT[(maskEq>>8) & ~((tmp&0x80)>>7)] << 8) | tmp;
// next, eliminate anything following a `=` from the special char mask; this eliminates cases of `=\r` so that they aren't removed
mask &= ~(maskEq << 1);
// unescape chars following `=`
__m128i eqVec = LOAD_HALVES(eqSubLUT + (maskEq&0xff), eqSubLUT + (maskEq>>8));
oData = _mm_sub_epi8(
oData,
_mm_slli_si128(eqVec, 1)
//_mm_and_si128(_mm_slli_si128(cmpEq, 1), _mm_set1_epi8(64))
);
// handle \r\n. sequences
// RFC3977 requires the first dot on a line to be stripped, due to dot-stuffing
if(isRaw) {
#ifdef __SSSE3__
# define ALIGNR _mm_alignr_epi8
#else
# define ALIGNR(a, b, i) _mm_or_si128(_mm_slli_si128(a, sizeof(__m128i)-(i)), _mm_srli_si128(b, i))
#endif
__m128i nextData = _mm_load_si128((__m128i *)(src + i) + 1);
// find instances of \r\n
__m128i tmpData = ALIGNR(nextData, data, 1);
__m128i cmp1 = _mm_cmpeq_epi16(data, _mm_set1_epi16(0x0a0d));
__m128i cmp2 = _mm_cmpeq_epi16(tmpData, _mm_set1_epi16(0x0a0d));
// trim matches to just the \n
cmp1 = _mm_and_si128(cmp1, _mm_set1_epi16(0xff00));
cmp2 = _mm_and_si128(cmp2, _mm_set1_epi16(0xff00));
// merge the two comparisons
cmp1 = _mm_or_si128(_mm_srli_si128(cmp1, 1), cmp2);
// then check if there's a . after any of these instances
tmpData = ALIGNR(nextData, data, 2);
tmpData = _mm_cmpeq_epi8(tmpData, _mm_set1_epi8('.'));
// grab bit-mask of matched . characters and OR with mask
unsigned int killDots = _mm_movemask_epi8(_mm_and_si128(tmpData, cmp1));
mask |= (killDots << 2) & 0xffff;
nextMask = killDots >> 14;
#undef ALIGNR
}
escFirst = (maskEq >> (sizeof(__m128i)-1));
// all that's left is to 'compress' the data (skip over masked chars)
#ifdef __SSSE3__
# ifdef __POPCNT__
unsigned char skipped = _mm_popcnt_u32(mask & 0xff);
# else
unsigned char skipped = BitsSetTable256[mask & 0xff];
# endif
// lookup compress masks and shuffle
// load up two halves
__m128i shuf = LOAD_HALVES(unshufLUT + (mask&0xff), unshufLUT + (mask>>8));
// offset upper half by 8
shuf = _mm_add_epi8(shuf, _mm_set_epi32(0x08080808, 0x08080808, 0, 0));
// shift down upper half into lower
shuf = _mm_shuffle_epi8(shuf, _mm_load_si128(pshufb_combine_table + skipped));
// shuffle data
oData = _mm_shuffle_epi8(oData, shuf);
STOREU_XMM(p, oData);
// increment output position
# ifdef __POPCNT__
p += XMM_SIZE - _mm_popcnt_u32(mask);
# else
p += XMM_SIZE - skipped - BitsSetTable256[mask >> 8];
# endif
#else
ALIGN_32(uint32_t mmTmp[4]);
_mm_store_si128((__m128i*)mmTmp, oData);
for(int j=0; j<4; j++) {
if(mask & 0xf) {
unsigned char* pMmTmp = (unsigned char*)(mmTmp + j);
*p = pMmTmp[0];
p += (mask & 1);
*p = pMmTmp[1];
p += (mask & 2) != 0;
*p = pMmTmp[2];
p += (mask & 4) != 0;
*p = pMmTmp[3];
p += (mask & 8) != 0;
} else {
*(uint32_t*)p = mmTmp[j];
p += 4;
}
mask >>= 4;
}
#endif
#undef LOAD_HALVES
} else {
STOREU_XMM(p, oData);
p += XMM_SIZE;
escFirst = 0;
nextMask = 0;
}
}
if(escFirst) *pState = 1; // escape next character
else if(nextMask == 1) *pState = 0; // next character is '.', where previous two were \r\n
else if(nextMask == 2) *pState = 2; // next characters are '\n.', previous is \r
else *pState = 3;
// end alignment
if(i < len) {
p += do_decode_scalar(src + i, p, len - i, pState, isRaw);
}
return p - dest;
}
#endif
@hugbug
Copy link
Author

hugbug commented Sep 23, 2017

Results

Testing scalar

Test 1: OK

Test 2: OK

Testing sse

Test 1: OK

Test 2: FAILURE
Source:
0f 1a b1 a4 0c d4 15 2a 61 47 c8 bf bf e4 d3 e9 e2 b2 1f e0 99 1d 79 9a 38 26 c0 8b 3d 40 42 cf c6 f9 85 34 8d 9c f2 55 ce 16 ec 4d 38 29 3d 4d 22 d8 bb cc ce 2a 91 c9 93 87 6f 0f fb 5b 2c d7 90 3c 22 4a ac a7 1a 57 1a bb 6b 64 23 e0 87 8f b2 3d 7d 94 30 c5 eb 2f cb e5 78 35 8e bc d0 0b 57 15 58 69 e3 9d fc f3 da 6b c1 07 3d 4d d2 6a 60 6f 43 a4 3d 4a 81 dc b7 ca 04 8a c1 f6 8d b8
Expected:
e5 f0 87 7a e2 aa eb 00 37 1d 9e 95 95 ba a9 bf b8 88 f5 b6 6f f3 4f 70 0e fc 96 61 d6 18 a5 9c cf 5b 0a 63 72 c8 2b a4 ec c2 23 0e ff e3 f8 ae 91 a2 a4 00 67 9f 69 5d 45 e5 d1 31 02 ad 66 12 f8 20 82 7d f0 2d f0 91 41 3a f9 b6 5d 65 88 13 6a 06 9b c1 05 a1 bb 4e 0b 64 92 a6 e1 2d eb 2e 3f b9 73 d2 c9 b0 41 97 dd e3 a8 40 36 45 19 7a e0 57 b2 8d a0 da 60 97 cc 63 8e
Result:
e5 f0 87 7a e2 aa eb 00 37 1d 9e 95 95 ba a9 bf b8 b8 b8 b8 b8 b8 b8 b8 0e 0e 0e 0e 0e 0e 0e 9c 9c 9c 9c 9c 9c 9c 9c a4 a4 a4 a4 a4 a4 a4 f8 ae 91 a2 a4 00 67 9f 69 5d 45 e5 d1 31 02 ad 66 12 f8 20 82 7d f0 2d f0 91 41 3a f9 b6 5d 65 88 88 88 88 88 88 88 a1 a1 a1 a1 a1 a1 a1 a1 2d 2d 2d 2d 2d 2d 2d 2d b0 b0 b0 b0 b0 b0 b0 36 45 19 7a e0 57 b2 8d a0 da 60 97 cc 63 8e

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment