Created
December 12, 2020 22:09
-
-
Save Voltara/c67ab94694370f791ba182192eb7c585 to your computer and use it in GitHub Desktop.
Advent of Code 2020 Day 11 AVX2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Build: clang++ -O3 -march=native -std=c++17 day11-avx2.cpp
#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <tuple>
#include <utility>
#include <vector>

#include <x86intrin.h>
// Number of times the whole solution is re-run; the reported time is the
// average over this many runs.
constexpr int ITERATIONS{1000};
// One grid row of up to 128 cells, stored as 128 four-bit fields packed into
// two 256-bit AVX2 registers (512 bits in total).
//
// Cell x occupies bits [4x, 4x+3] of the 512-bit value: 64-bit word x/16,
// nibble x%16. The same layout serves both as a 0/1 mask (only the low bit of
// a field set) and as a vector of small per-cell counters. operator+ adds
// whole 64-bit lanes, so callers must keep every field below 16 to avoid
// carrying into the neighbouring field.
struct row {
    std::array<__m256i, 2> R;

    // All fields zero.
    row() : R() {
    }

    // Build a mask from text: field x is 1 where s[x] == 'L', 0 elsewhere.
    // Only the first 128 characters can be represented; anything beyond that
    // is ignored instead of being written past the end of the storage.
    row(const char *s, int dimx) : R() {
        std::uint64_t words[8] = { };
        static_assert(sizeof(words) == sizeof(std::array<__m256i, 2>),
                      "row storage must be exactly eight 64-bit words");
        const int limit = dimx < 128 ? dimx : 128;
        for (int x = 0; x < limit; x++) {
            if (s[x] != 'L') continue;
            words[x >> 4] |= std::uint64_t(1) << ((x & 15) << 2);
        }
        // memcpy instead of pointer punning: no aliasing or const issues.
        std::memcpy(R.data(), words, sizeof words);
    }

    // Number of set bits in the whole row. For a 0/1 mask this is the number
    // of fields that are set.
    int count() const {
        std::uint64_t words[8];
        std::memcpy(words, R.data(), sizeof words);
        int n = 0;
        for (const std::uint64_t w : words) {
            // popcountll: the argument is 64 bits wide even where long is not.
            n += __builtin_popcountll(w);
        }
        return n;
    }

    // Set 4-bit fields to 1 if zero, or 0 otherwise
    row is_zero() const {
        row r = ~*this;
        for (auto &m : r.R) {
            // AND-fold the four inverted bits of each field into its low bit.
            m = _mm256_and_si256(m, _mm256_srli_epi64(m, 1));
            m = _mm256_and_si256(m, _mm256_srli_epi64(m, 2));
            m = _mm256_and_si256(m, _mm256_set1_epi8(0x11));
        }
        return r;
    }

    // Set 4-bit fields to 1 if the field value is at most 4, or 0 otherwise.
    // Per field this computes (15 - v) - 3 = 12 - v and extracts bit 3, which
    // is set exactly when 12 - v >= 8. Every field must be <= 12, otherwise
    // the subtraction borrows from the neighbouring field.
    row is_small() const {
        row r = ~*this;
        for (auto &m : r.R) {
            m = _mm256_sub_epi64(m, _mm256_set1_epi8(0x33));
            m = _mm256_srli_epi64(m, 3);
            m = _mm256_and_si256(m, _mm256_set1_epi8(0x11));
        }
        return r;
    }

    row operator+ (const row &o) const { row r = *this; return r += o; }
    row operator& (const row &o) const { row r = *this; return r &= o; }
    row operator| (const row &o) const { row r = *this; return r |= o; }
    row operator<< (int s) const { row r = *this; return r <<= s; }
    row operator>> (int s) const { row r = *this; return r >>= s; }

    // Bitwise complement of all 512 bits.
    row operator ~ () const {
        row r;
        const __m256i ones = _mm256_set1_epi32(-1);
        for (std::size_t i = 0; i < R.size(); i++) {
            r.R[i] = _mm256_xor_si256(R[i], ones);
        }
        return r;
    }

    // Logical right shift of the whole 512-bit value by 4 bits (one field).
    // Each 64-bit lane is shifted right by 4, and the nibble that falls out of
    // a lane is carried into the top of the lane below it; the topmost lane is
    // zero-filled.
    void shift_right_4() {
        auto l = _mm256_srli_epi64(R[0], 4), cl = _mm256_slli_epi64(R[0], 60);
        auto h = _mm256_srli_epi64(R[1], 4), ch = _mm256_slli_epi64(R[1], 60);
        // ml = [cl.high, ch.low], mh = [ch.high, zero]: the "next 128 bits up"
        // for each 128-bit lane, so alignr can move the carries down 8 bytes.
        auto ml = _mm256_permute2x128_si256(ch, cl, _MM_SHUFFLE(0,0,0,3));
        auto mh = _mm256_permute2x128_si256(ch, ch, _MM_SHUFFLE(3,0,0,3));
        R[0] = _mm256_or_si256(l, _mm256_alignr_epi8(ml, cl, 8));
        R[1] = _mm256_or_si256(h, _mm256_alignr_epi8(mh, ch, 8));
    }

    // Logical left shift of the whole 512-bit value by N bits (N a multiple
    // of 4). Shifts left by N rounded up to whole bytes, then, when N is an
    // odd multiple of 4, back right by one field. The byte count is a template
    // parameter because alignr needs a compile-time immediate.
    template<int N> void shift_left() {
        static_assert(N > 0 && N <= 128 && N % 4 == 0,
                      "shift must be a positive multiple of 4, at most 128");
        constexpr int s = 16 - (N + 4) / 8;
        // m0 = [zero, R0.low], m1 = [R0.high, R1.low]: the "next 128 bits
        // down" for each 128-bit lane.
        auto m0 = _mm256_permute2x128_si256(R[0], R[0], _MM_SHUFFLE(0,0,3,0));
        auto m1 = _mm256_permute2x128_si256(R[0], R[1], _MM_SHUFFLE(0,2,0,1));
        R[0] = _mm256_alignr_epi8(R[0], m0, s);
        R[1] = _mm256_alignr_epi8(R[1], m1, s);
        if constexpr ((N & 4) != 0) shift_right_4();
    }

    // Logical right shift of the whole 512-bit value by N bits (N a multiple
    // of 4): whole bytes via alignr, plus one extra field when N is an odd
    // multiple of 4.
    template<int N> void shift_right() {
        static_assert(N > 0 && N <= 128 && N % 4 == 0,
                      "shift must be a positive multiple of 4, at most 128");
        constexpr int s = N / 8;
        // m0 = [R0.high, R1.low], m1 = [R1.high, zero]: the "next 128 bits
        // up" for each 128-bit lane.
        auto m0 = _mm256_permute2x128_si256(R[1], R[0], _MM_SHUFFLE(0,0,0,3));
        auto m1 = _mm256_permute2x128_si256(R[1], R[1], _MM_SHUFFLE(3,0,0,3));
        R[0] = _mm256_alignr_epi8(m0, R[0], s);
        R[1] = _mm256_alignr_epi8(m1, R[1], s);
        if constexpr ((N & 4) != 0) shift_right_4();
    }

    // Runtime-dispatched left shift. Only multiples of 4 from 4 to 72 bits
    // are supported; any other amount aborts the program.
    row& operator<<= (int sl) {
        switch (sl) {
            case  4: shift_left< 4>(); break;
            case  8: shift_left< 8>(); break;
            case 12: shift_left<12>(); break;
            case 16: shift_left<16>(); break;
            case 20: shift_left<20>(); break;
            case 24: shift_left<24>(); break;
            case 28: shift_left<28>(); break;
            case 32: shift_left<32>(); break;
            case 36: shift_left<36>(); break;
            case 40: shift_left<40>(); break;
            case 44: shift_left<44>(); break;
            case 48: shift_left<48>(); break;
            case 52: shift_left<52>(); break;
            case 56: shift_left<56>(); break;
            case 60: shift_left<60>(); break;
            case 64: shift_left<64>(); break;
            case 68: shift_left<68>(); break;
            case 72: shift_left<72>(); break;
            default: std::abort();
        }
        return *this;
    }

    // Runtime-dispatched right shift. Only multiples of 4 from 4 to 72 bits
    // are supported; any other amount aborts the program.
    row& operator>>= (int sr) {
        switch (sr) {
            case  4: shift_right< 4>(); break;
            case  8: shift_right< 8>(); break;
            case 12: shift_right<12>(); break;
            case 16: shift_right<16>(); break;
            case 20: shift_right<20>(); break;
            case 24: shift_right<24>(); break;
            case 28: shift_right<28>(); break;
            case 32: shift_right<32>(); break;
            case 36: shift_right<36>(); break;
            case 40: shift_right<40>(); break;
            case 44: shift_right<44>(); break;
            case 48: shift_right<48>(); break;
            case 52: shift_right<52>(); break;
            case 56: shift_right<56>(); break;
            case 60: shift_right<60>(); break;
            case 64: shift_right<64>(); break;
            case 68: shift_right<68>(); break;
            case 72: shift_right<72>(); break;
            default: std::abort();
        }
        return *this;
    }

    // Lane-wise 64-bit addition; acts as a per-field add as long as no field
    // overflows 4 bits.
    row& operator += (const row &o) {
        R[0] = _mm256_add_epi64(R[0], o.R[0]);
        R[1] = _mm256_add_epi64(R[1], o.R[1]);
        return *this;
    }

    row& operator &= (const row &o) {
        R[0] = _mm256_and_si256(R[0], o.R[0]);
        R[1] = _mm256_and_si256(R[1], o.R[1]);
        return *this;
    }

    row& operator |= (const row &o) {
        R[0] = _mm256_or_si256(R[0], o.R[0]);
        R[1] = _mm256_or_si256(R[1], o.R[1]);
        return *this;
    }

    // True when any of the 512 bits differ.
    bool operator != (const row &o) const {
        // movemask yields -1 (all 32 byte flags set) only when every lane of
        // the comparison matched.
        const int lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(R[0], o.R[0]));
        const int hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(R[1], o.R[1]));
        return lo != -1 || hi != -1;
    }

    // True when all 512 bits are zero.
    // A signed "greater than zero" lane compare would miss lanes whose top
    // bit is set, so test the bits directly instead.
    bool operator ! () const {
        const __m256i any = _mm256_or_si256(R[0], R[1]);
        return _mm256_testz_si256(any, any) != 0;
    }
};
std::pair<int,int> solve(char *in) { | |
char *nl = strchr(in, '\n'); | |
if (!nl || nl == in) abort(); | |
int dimx = nl - in; | |
if (dimx > 124) abort(); | |
std::vector<row> G; | |
for (int y = 0; in[dimx]; y++) { | |
if (in[dimx] != '\n') abort(); | |
G.emplace_back(in, dimx); | |
in += dimx + 1; | |
} | |
// Neighbor counts | |
std::vector<row> N(G.size()); | |
// Seat locations & inverses | |
std::vector<row> Seat = G, SeatI; | |
for (auto &r : G) { | |
SeatI.push_back(~r); | |
} | |
// Find horizontal shifts for each row, precompute masks | |
std::vector<int> Shift; | |
std::vector<row> ShiftMaskL, ShiftMaskR; | |
for (int i = 0; i < G.size(); i++) { | |
auto gl = G[i], sl = gl, gr = gl, sr = gl; | |
int shift = 0; | |
while (!!sl) { | |
shift += 4; | |
sl >>= 4, sr <<= 4; | |
auto ml = gl & sl; | |
if (!ml) continue; | |
auto mr = gr & sr; | |
sl &= ~ml, gl &= ~ml; | |
sr &= ~mr, gr &= ~mr; | |
Shift.push_back(shift); | |
ShiftMaskL.push_back(ml); | |
ShiftMaskR.push_back(mr); | |
shift = 0; | |
} | |
Shift.push_back(0); | |
ShiftMaskL.emplace_back(); | |
ShiftMaskR.emplace_back(); | |
} | |
// Update grid based on neighbor counts | |
auto update = [&]() { | |
bool diff = false; | |
for (int i = 0; i < G.size(); i++) { | |
auto t = (N[i].is_small() & G[i]) | | |
(N[i].is_zero() & Seat[i]); | |
diff |= (t != G[i]); | |
G[i] = t; | |
} | |
return diff; | |
}; | |
// Count all occupied seats | |
auto count = [&]() { | |
int n = 0; | |
for (auto &r : G) n += r.count(); | |
return n; | |
}; | |
// Part 1 | |
for (bool diff = true; diff; diff = update()) { | |
N[0] = G[0] + (G[0] << 4) + (G[0] >> 4); | |
row t = N[0]; | |
for (int i = 1; i < G.size(); i++) { | |
N[i] = t; | |
t = G[i] + (G[i] << 4) + (G[i] >> 4); | |
N[i] += t, N[i-1] += t; | |
} | |
} | |
int part1 = count(); | |
// Part 2 | |
G = Seat; | |
for (bool diff = true; diff; diff = update()) { | |
std::fill(N.begin(), N.end(), row{}); | |
// Propagate distant neighbors vertically/diagonally | |
auto vertical = [&](int start, int end, int incr) { | |
row UL, U, UR; | |
for (int i = start; i != end; i += incr) { | |
UL = UL >>= 4, UR <<= 4; | |
N[i] += UL + U + UR; | |
UL = G[i] | (UL & SeatI[i]); | |
U = G[i] | (U & SeatI[i]); | |
UR = G[i] | (UR & SeatI[i]); | |
} | |
}; | |
vertical(0, G.size(), 1); // Top-down | |
vertical(G.size() - 1, -1, -1); // Bottom-up | |
// Propagate horizontally | |
for (int i = 0, sidx = 0; i < G.size(); i++, sidx++) { | |
row L = G[i], R = G[i]; | |
for (auto s = Shift[sidx]; s; s = Shift[++sidx]) { | |
L >>= s, R <<= s; | |
N[i] += (L & ShiftMaskL[sidx]); | |
N[i] += (R & ShiftMaskR[sidx]); | |
} | |
} | |
} | |
int part2 = count(); | |
return { part1, part2 }; | |
} | |
int main(int argc, char **argv) { | |
int part1 = 0, part2 = 0; | |
FILE *fp = (argc > 1) ? fopen(argv[1], "r") : stdin; | |
if (!fp) perror(argv[1]), abort(); | |
std::array<char, 16384> buf = { }; | |
fread(&buf[0], buf.size() - 128, 1, fp); | |
if (fp != stdin) fclose(fp); | |
auto t0 = std::chrono::steady_clock::now(); | |
for (int i = 0; i < ITERATIONS; i++) { | |
std::tie(part1, part2) = solve(&buf[0]); | |
} | |
auto elapsed = std::chrono::steady_clock::now() - t0; | |
double ns = std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed).count(); | |
ns /= ITERATIONS; | |
double us = ns * 1e-3; | |
printf("Part 1: %d\n", part1); | |
printf("Part 2: %d\n", part2); | |
printf("Average time (%d iterations): %.f μs\n", ITERATIONS, us); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment