Skip to content

Instantly share code, notes, and snippets.

@Voltara
Created December 4, 2018 07:06
Show Gist options
  • Save Voltara/cf417ac4d2e31f85c3cd9b43ccb9338b to your computer and use it in GitHub Desktop.
Save Voltara/cf417ac4d2e31f85c3cd9b43ccb9338b to your computer and use it in GitHub Desktop.
Advent of Code Day 2 Part 2
// clang++ -O3 -march=native -std=c++17 hashsimd.cpp
#include <algorithm>
#include <array>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
#include <inttypes.h>
#include <x86intrin.h>
static uint64_t hash(const __m256i &m);
// Advent of Code 2018 day 2 part 2: find the two input words that differ in
// exactly one position and print their common characters.
//
// Strategy: hash two disjoint halves of every word, radix-sort the hashes,
// keep only words that share a hash with some other word, then brute-force
// compare the (small) candidate set with 32-byte SIMD compares.
int main() {
    // Read the input, one word per 32-byte AVX vector, zero-padded.
    // Assumes every non-empty line is at most 26 characters long (the puzzle
    // input uses exactly 26); anything past column 26 is discarded.
    std::vector<__m256i> v;
    char buf[32];
    for (;;) {
        // Clear the whole buffer first so a short line never inherits bytes
        // left over from the previous line.
        std::fill(buf, buf + sizeof(buf), '\0');
        if (!fgets(buf, sizeof(buf), stdin)) break;
        buf[std::strcspn(buf, "\r\n")] = 0; // strip the line terminator
        buf[26] = 0;                        // for safety
        if (buf[0] == 0) continue;          // skip blank lines
        // Unaligned load: 'buf' only has char alignment, so dereferencing it
        // as a __m256i * (which requires 32-byte alignment) would be UB.
        v.push_back(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(buf)));
    }

    // Hash two disjoint halves of each word: the even-indexed bytes (mask
    // 0x00ff per 16-bit lane) and the odd-indexed bytes (mask 0xff00). If two
    // words differ in exactly one position, the half that does not contain
    // that position is identical, so the pair shares at least one hash.
    // Each entry packs the 32-bit hash in the high half and the word index
    // in the low half.
    std::vector<uint64_t> hv;
    hv.reserve(2 * v.size());
    for (size_t i = 0; i < v.size(); i++) {
        uint64_t h0 = hash(_mm256_and_si256(v[i], _mm256_set1_epi16(0x00ff)));
        uint64_t h1 = hash(_mm256_and_si256(v[i], _mm256_set1_epi16(0xff00)));
        hv.emplace_back((h0 << 32) | i);
        hv.emplace_back((h1 << 32) | i);
    }

    // Sort 'hv' by hash with a 3-pass LSD radix (histogram) sort on 11-bit
    // digits: bits 32-42, 43-53 and 54-63. The mask must be 11 bits wide
    // (0x7ff) to match the 11-bit digit stride and the 2048-entry histogram;
    // a narrower mask skips bits, and entries with equal hashes may then
    // end up non-adjacent.
    constexpr uint64_t kDigitMask = 0x7ff;
    std::vector<uint64_t> hv_tmp(hv.size());
    std::array<uint32_t, kDigitMask + 1> hist;
    for (int shift : {32, 43, 54}) {
        hist.fill(0);
        for (auto h : hv) hist[(h >> shift) & kDigitMask]++;
        // Exclusive prefix sum: hist[d] becomes the first output slot for digit d.
        uint32_t total = 0;
        for (auto &n : hist) total += std::exchange(n, total);
        for (auto h : hv) hv_tmp[hist[(h >> shift) & kDigitMask]++] = h;
        hv.swap(hv_tmp);
    }

    // Collect the word indices of every entry whose hash equals its neighbour's.
    std::vector<uint32_t> idxv;
    for (size_t i = 1; i < hv.size(); i++) {
        if ((hv[i - 1] >> 32) == (hv[i] >> 32)) {
            idxv.push_back(static_cast<uint32_t>(hv[i - 1]));
            idxv.push_back(static_cast<uint32_t>(hv[i]));
        }
    }
    // Sort and deduplicate the index list; this also keeps the candidates
    // in their original relative order.
    std::sort(idxv.begin(), idxv.end());
    idxv.erase(std::unique(idxv.begin(), idxv.end()), idxv.end());

    // Keep only the candidate words.
    std::vector<__m256i> vv;
    vv.reserve(idxv.size());
    for (auto idx : idxv) vv.push_back(v[idx]);

    for (auto x = vv.begin(); x != vv.end(); ++x) {
        for (auto y = x + 1; y != vv.end(); ++y) {
            // Byte-wise compare; bit k of cmp_ne is set when byte k differs.
            const unsigned cmp_ne = ~static_cast<unsigned>(
                _mm256_movemask_epi8(_mm256_cmpeq_epi8(*x, *y)));
            // Lowest differing byte, isolated as a single-bit mask.
            const unsigned diff1 = _blsi_u32(cmp_ne);
            // Exactly one difference: the mask is non-zero and has a single
            // bit set. Identical duplicate words give cmp_ne == 0 and must
            // not match (and _bit_scan_forward(0) is undefined).
            if (cmp_ne != 0 && diff1 == cmp_ne) {
                const char *s = reinterpret_cast<const char *>(&*x);
                // Index of the differing byte; bytes 26..31 are zero in
                // every word, so it is < 26 and 's' stays NUL-terminated.
                const int diff_idx = _bit_scan_forward(diff1);
                // Output the word with the differing character removed.
                printf("Part 2: %.*s%s\n", diff_idx, s, s + diff_idx + 1);
                return 0;
            }
        }
    }
    printf("Part 2 solution not found\n");
    return 0;
}
// Hash all 32 bytes of 'm' with the SSE4.2 CRC32 instruction (CRC-32C),
// feeding the vector as four 64-bit lanes, lane 0 first, from a zero seed.
// The instruction yields a 32-bit CRC, so the upper 32 bits of the result
// are zero; main() relies on that when packing 'hash << 32 | index'.
uint64_t hash(const __m256i &m) {
    // Spill the vector to a plain array with an unaligned store rather than
    // reading the __m256i object through a reinterpret_cast<const uint64_t *>,
    // which accesses it through an unrelated type (strict aliasing).
    uint64_t lanes[4];
    _mm256_storeu_si256(reinterpret_cast<__m256i *>(lanes), m);
    uint64_t h = 0;
    for (uint64_t lane : lanes) {
        h = _mm_crc32_u64(h, lane);
    }
    return h;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment