Skip to content

Instantly share code, notes, and snippets.

@rygorous
Created November 9, 2015 07:37
Embed
What would you like to do?
Bytestream<->bitstream transpose
#include <stdio.h>
#include <emmintrin.h>
#include <intrin.h>
// Bytestream -> bitstream transpose of one 128-byte block, using SSE2.
//
// bytes_in: 8 consecutive __m128i; bytes_in[r] holds source bytes
//           x[16*r .. 16*r+15] in memory order (x[0x00..0x7f] in total).
// bits_out: 8 consecutive __m128i receiving the bit planes:
//           bits_out[i], byte lane p, bit q = bit i of x[8*p + q].
//           So bits_out[i] holds bit i of every source byte, and each
//           output byte gathers 8 consecutive source bytes (LSB first).
//
// All eight inputs are loaded before any output is stored, so bytes_in and
// bits_out may point to the same buffer. Both pointers are dereferenced
// directly as __m128i (aligned accesses), so they must be 16-byte aligned.
static void serial2parallel(__m128i const *bytes_in, __m128i *bits_out)
{
// a..h hold the eight rows being permuted; t is scratch for the macros below.
__m128i a,b,c,d,e,f,g,h, t;
// load
a = bytes_in[0];
b = bytes_in[1];
c = bytes_in[2];
d = bytes_in[3];
e = bytes_in[4];
f = bytes_in[5];
g = bytes_in[6];
h = bytes_in[7];
// Algorithm note:
//
// In the final result, we produce 8 SIMD registers worth of contents
// where register i contains bit i of every source byte.
//
// We start out with (writing SIMD registers as an array of bytes in
// memory order)
//
// reg0 = { x[0x00], x[0x01], x[0x02], ..., x[0x0f] }
// reg1 = { x[0x10], x[0x11], x[0x12], ..., x[0x1f] }
// ...
//
// Our final form has lane 1 of reg0 starting with bit 0 of byte 8,
// lane 2 starting with bit 0 of byte 16, and so forth.
// So we first go to a form where
//
// reg0 = { x[0x00], x[0x08], x[0x10], ..., x[0x78] }
// reg1 = { x[0x01], x[0x09], x[0x11], ..., x[0x79] }
//
// An even-odd merge (interleave) on a pair of SIMD registers
// effectively concatenates the two registers (x = lanes 0..15,
// y = lanes 16..31) and then permutes the 32 lanes so that output
// lane i receives input lane rotate_right(i, 1), a 5-bit rotate
// (e.g. output lane 1 receives input lane 16, i.e. y[0]).
//
// Note that the transform we want amounts to concatenating all 8
// registers to one big array, then filling array index i from array
// index rotate_left(i, 3), a 7-bit rotate (e.g. index 1 receives x[8]).
//
// We have 128 source bytes total, so instead of rotating left by
// 3, we can also rotate right by 4. This is nice for us since
// rotating right happens to be easier with the primitives we're given:
// four rounds of even-odd merges (byte passes 0..3 below) do it.
//
// After this, we have the bytes in the desired order, but each byte
// is still straight from the source. The rest can be handled using the
// usual chunky-to-planar "checkerboard merge" operations. We need to do
// nibble, 2-bit and 1-bit group swaps, which can be done in any order. We
// pick them up along the way whenever the current row permutation is
// convenient.
//
// even_odd_merge(x, y): x <- low 8 bytes of x and y interleaved
// (x0,y0,x1,y1,...,x7,y7), y <- high 8 bytes interleaved
// (x8,y8,...,x15,y15). Clobbers t.
#define even_odd_merge(x,y) \
t = x; \
x = _mm_unpacklo_epi8(x, y); \
y = _mm_unpackhi_epi8(t, y)
// even_odd_then_bit_exch(x, y, dist, mask): even_odd_merge, then a delta
// swap that exchanges the bits of y selected by mask with the bits of x
// selected by (mask << dist). SSE2 has no 8-bit shifts, so 16-bit shifts
// are used; with mask 0x0f/0x33/0x55 and dist 4/2/1 every kept bit comes
// from (and returns to) its own byte, so nothing leaks across bytes.
// Clobbers t.
#define even_odd_then_bit_exch(x,y, dist,mask) \
even_odd_merge(x, y); \
t = _mm_srli_epi16(x, dist); \
t = _mm_xor_si128(t, y); \
t = _mm_and_si128(t, _mm_set1_epi8((signed char)mask)); \
y = _mm_xor_si128(y, t); \
t = _mm_slli_epi16(t, dist); \
x = _mm_xor_si128(x, t)
// byte pass 0 (pairs registers 4 apart: a/e, b/f, c/g, d/h)
even_odd_merge(a, e);
even_odd_merge(b, f);
even_odd_merge(c, g);
even_odd_merge(d, h);
// byte pass 1 and 4x4 bit pass
even_odd_then_bit_exch(a, c, 4, 0x0f);
even_odd_then_bit_exch(e, g, 4, 0x0f);
even_odd_then_bit_exch(b, d, 4, 0x0f);
even_odd_then_bit_exch(f, h, 4, 0x0f);
// byte pass 2 and 2x2 bit pass
even_odd_then_bit_exch(a, b, 2, 0x33);
even_odd_then_bit_exch(c, d, 2, 0x33);
even_odd_then_bit_exch(e, f, 2, 0x33);
even_odd_then_bit_exch(g, h, 2, 0x33);
// byte pass 3 and 1x1 bit pass
even_odd_then_bit_exch(a, e, 1, 0x55);
even_odd_then_bit_exch(b, f, 1, 0x55);
even_odd_then_bit_exch(c, g, 1, 0x55);
even_odd_then_bit_exch(d, h, 1, 0x55);
#undef even_odd_merge
#undef even_odd_then_bit_exch
// store
// After the passes, a, e, b, f, c, g, d, h hold bit planes 0..7
// respectively; storing in this order puts bit plane i in bits_out[i].
bits_out[0] = a;
bits_out[1] = e;
bits_out[2] = b;
bits_out[3] = f;
bits_out[4] = c;
bits_out[5] = g;
bits_out[6] = d;
bits_out[7] = h;
}
// Bitstream -> bytestream transpose of one 128-byte block, using SSE2;
// the exact inverse of serial2parallel.
//
// bits_in:   8 consecutive __m128i in serial2parallel's output layout:
//            bits_in[i], byte lane p, bit q holds bit i of byte x[8*p + q].
// bytes_out: 8 consecutive __m128i receiving x[0x00..0x7f] in memory
//            order (bytes_out[r] holds x[16*r .. 16*r+15]).
//
// All eight inputs are loaded before any output is stored, so bits_in and
// bytes_out may point to the same buffer. Both pointers are dereferenced
// directly as __m128i (aligned accesses), so they must be 16-byte aligned.
static void parallel2serial(__m128i const *bits_in, __m128i *bytes_out)
{
// a..h hold the eight rows being permuted; t is scratch for the macros below.
__m128i a,b,c,d,e,f,g,h, t;
// load
a = bits_in[0];
b = bits_in[1];
c = bits_in[2];
d = bits_in[3];
e = bits_in[4];
f = bits_in[5];
g = bits_in[6];
h = bits_in[7];
// Reverse permutation has our index rotates going in the right direction
// so it only needs byte even-odd merges.
// (Three merge rounds here, one per bit pass, versus four in
// serial2parallel.)
//
// even_odd_merge(x, y): x <- low 8 bytes of x and y interleaved
// (x0,y0,...,x7,y7), y <- high 8 bytes interleaved (x8,y8,...). Clobbers t.
#define even_odd_merge(x,y) \
t = x; \
x = _mm_unpacklo_epi8(x, y); \
y = _mm_unpackhi_epi8(t, y)
// bit_exch_then_even_odd(x, y, dist, mask): delta swap that exchanges the
// bits of y selected by mask with the bits of x selected by
// (mask << dist), then even_odd_merge(x, y). The 16-bit shifts are safe
// because the masks keep only bits that stay within their own byte.
// Clobbers t.
#define bit_exch_then_even_odd(x, y, dist,mask) \
t = _mm_srli_epi16(x, dist); \
t = _mm_xor_si128(t, y); \
t = _mm_and_si128(t, _mm_set1_epi8((signed char)mask)); \
y = _mm_xor_si128(y, t); \
t = _mm_slli_epi16(t, dist); \
x = _mm_xor_si128(x, t); \
even_odd_merge(x, y)
// 4x4 bit pass and byte pass 0
bit_exch_then_even_odd(a, e, 4, 0x0f);
bit_exch_then_even_odd(b, f, 4, 0x0f);
bit_exch_then_even_odd(c, g, 4, 0x0f);
bit_exch_then_even_odd(d, h, 4, 0x0f);
// 2x2 bit pass and byte pass 1
bit_exch_then_even_odd(a, c, 2, 0x33);
bit_exch_then_even_odd(e, g, 2, 0x33);
bit_exch_then_even_odd(b, d, 2, 0x33);
bit_exch_then_even_odd(f, h, 2, 0x33);
// 1x1 bit pass and byte pass 2
bit_exch_then_even_odd(a, b, 1, 0x55);
bit_exch_then_even_odd(c, d, 1, 0x55);
bit_exch_then_even_odd(e, f, 1, 0x55);
bit_exch_then_even_odd(g, h, 1, 0x55);
#undef even_odd_merge
#undef bit_exch_then_even_odd
// store
// After these passes the rows are already in source order: a..h hold
// x[0x00..0x0f] .. x[0x70..0x7f].
bytes_out[0] = a;
bytes_out[1] = b;
bytes_out[2] = c;
bytes_out[3] = d;
bytes_out[4] = e;
bytes_out[5] = f;
bytes_out[6] = g;
bytes_out[7] = h;
}
// Print eight 16-byte SSE registers as rows of hex bytes, in memory order.
//
// arr: 8 consecutive __m128i values; only read, never modified.
// Each row is printed as "[row] b0 b1 ... b15" followed by a newline, with
// each byte as two lowercase hex digits.
static void print128b(__m128i const *arr)
{
    for (int row = 0; row < 8; ++row)
    {
        // Spill the register to plain bytes with a standard SSE2 store rather
        // than the MSVC-only m128i_u8 union member, so GCC/Clang build too.
        unsigned char bytes[16];
        _mm_storeu_si128((__m128i *)bytes, arr[row]);

        printf("[%d]", row);
        for (int col = 0; col < 16; ++col)
            printf(" %02x", (unsigned)bytes[col]);
        printf("\n");
    }
}
// Print eight 16-byte SSE registers as rows of 8-character binary strings,
// one per byte, in memory order.
//
// arr: 8 consecutive __m128i values; only read, never modified.
// Each row is printed as "[row]" followed by 16 space-separated bit strings
// and a newline.
static void print128bin(__m128i const *arr)
{
    for (int row = 0; row < 8; ++row)
    {
        // Spill the register to plain bytes with a standard SSE2 store rather
        // than the MSVC-only m128i_u8 union member, so GCC/Clang build too.
        unsigned char bytes[16];
        _mm_storeu_si128((__m128i *)bytes, arr[row]);

        printf("[%d]", row);
        for (int col = 0; col < 16; ++col)
        {
            int v = bytes[col];
            char buf[9];
            buf[8] = 0;
            // print LE bits left->right (!); makes more sense
            // for the use case (bit masks for chars)
            for (int bit = 0; bit < 8; ++bit)
                buf[bit] = (char)('0' + ((v >> bit) & 1));
            printf(" %s", buf);
        }
        printf("\n");
    }
}
// Signature of a 128-byte transpose kernel (as used by main with
// serial2parallel / parallel2serial): reads 8 __m128i from indata and
// writes 8 __m128i to outdata.
typedef void transpose_func(__m128i const *indata, __m128i *outdata);

// Time `func` using the time-stamp counter (__rdtsc) and return the
// estimated ticks per call.
//
// Runs nOuter batches of nInner back-to-back calls, keeps the fastest
// batch (taking the minimum discards batches disturbed by interrupts or
// preemption), and divides it by nInner.
static double benchrun(transpose_func *func, __m128i const *indata, __m128i *outdata)
{
    static int const nOuter = 4096;
    static int const nInner = 4096;

    // Keep the full 64-bit tick difference. Narrowing it to int is
    // implementation-defined once a (preempted) batch exceeds INT_MAX ticks,
    // and a resulting negative value would be picked as the "minimum".
    unsigned long long best = ~0ULL;
    for (int outer = 0; outer < nOuter; ++outer)
    {
        unsigned long long start = __rdtsc();
        for (int inner = 0; inner < nInner; ++inner)
            func(indata, outdata);
        unsigned long long elapsed = __rdtsc() - start;
        if (elapsed < best)
            best = elapsed;
    }
    return (double)best / nInner;
}
// Demo / benchmark driver: fills a 128-byte block with 0x40..0xbf, prints
// it, times and prints serial2parallel (as bit strings), then times
// parallel2serial on that output and prints the result.
int main(void)
{
    __m128i ser[8], par[8];

    // Build the input bytes in a plain buffer and load each 16-byte row.
    // Writing ser[0].m128i_i8[i] for i >= 16 indexes past that member's
    // 16-element array (undefined behavior), and m128i_i8 is MSVC-only.
    unsigned char src[8 * 16];
    for (int i = 0; i < 8 * 16; ++i)
        src[i] = (unsigned char)(i + 0x40);
    for (int row = 0; row < 8; ++row)
        ser[row] = _mm_loadu_si128((__m128i const *)(src + 16 * row));

    printf("in:\n");
    print128b(ser);

    double s2ptime = benchrun(serial2parallel, ser, par);
    printf("\nout parallel: (~%.2f cycles each)\n", s2ptime);
    print128bin(par);

    // parallel2serial inverts serial2parallel, so this should print the
    // input bytes again.
    double p2stime = benchrun(parallel2serial, par, ser);
    printf("\nout serial: (~%.2f cycles each)\n", p2stime);
    print128b(ser);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment