Skip to content

Instantly share code, notes, and snippets.

@TOTBWF
Last active December 2, 2022 00:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TOTBWF/52969a650b33f6393f863de8aec75735 to your computer and use it in GitHub Desktop.
Save TOTBWF/52969a650b33f6393f863de8aec75735 to your computer and use it in GitHub Desktop.
Advent of Code 2022 Day 1 with excessive SIMD
#include <fcntl.h>
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
/*
 * Parse a 4-character ASCII decimal number stored in bytes 0..3 of `input`.
 *
 * Bytes 4..15 are multiplied by a zero weight, so whatever follows the four
 * digits (a newline, the next line, ...) does not affect the result.
 * The four low bytes are expected to be '0'..'9'; nothing is validated.
 *
 * Returns the numeric value, e.g. "1234" -> 1234.
 */
uint32_t parse_4_digits(const __m128i input) {
    // Byte weights for the first reduction: [d0,d1,d2,d3] -> [d0*10+d1, d2*10+d3].
    const __m128i pair_weights = _mm_setr_epi8(10, 1, 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    // 16-bit weights for the second reduction: [hi, lo] -> hi*100 + lo.
    const __m128i quad_weights = _mm_setr_epi16(100, 1, 0, 0, 0, 0, 0, 0);

    // Turn ASCII characters into digit values by subtracting '0' from every byte.
    const __m128i digit_values = _mm_subs_epi8(input, _mm_set1_epi8('0'));

    // Multiply-and-add-adjacent twice: bytes -> 16-bit pairs -> one 32-bit total.
    const __m128i pairs = _mm_maddubs_epi16(digit_values, pair_weights);
    const __m128i total = _mm_madd_epi16(pairs, quad_weights);

    // The finished number sits in the lowest 32-bit lane.
    return _mm_cvtsi128_si32(total);
}
/*
 * Parse a 5-character ASCII decimal number stored in bytes 0..4 of `input`.
 *
 * Bytes 5..15 are discarded by the shuffle, so trailing data is ignored.
 * The five low bytes are expected to be '0'..'9'; nothing is validated.
 *
 * Returns the numeric value, e.g. "12345" -> 12345.
 */
uint32_t parse_5_digits(const __m128i input) {
    // A shuffle index with the high bit set makes _mm_shuffle_epi8 write a zero byte.
    const char zeroed = (char)0x80;

    // Keep digits 0..3 in place and move the 5th digit to byte 7, so that after the
    // two pairwise reductions it ends up alone in the second 32-bit lane.
    const __m128i layout = _mm_setr_epi8(0, 1, 2, 3, zeroed, zeroed, zeroed, 4,
                                         zeroed, zeroed, zeroed, zeroed,
                                         zeroed, zeroed, zeroed, zeroed);
    // Byte weights: leading four digits pair up as in the 4-digit parser, the
    // relocated 5th digit passes through with weight 1.
    const __m128i pair_weights = _mm_setr_epi8(10, 1, 10, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0);
    // 16-bit weights: the leading four digits get an extra factor of 10 to make
    // room for the trailing digit, which again keeps weight 1.
    const __m128i quad_weights = _mm_setr_epi16(1000, 10, 0, 1, 0, 0, 0, 0);

    // ASCII -> digit values, then rearrange into the layout described above.
    const __m128i digit_values = _mm_subs_epi8(input, _mm_set1_epi8('0'));
    const __m128i arranged = _mm_shuffle_epi8(digit_values, layout);

    // Lane 0 now holds (first four digits) * 10, lane 1 holds the fifth digit.
    const __m128i pairs = _mm_maddubs_epi16(arranged, pair_weights);
    const __m128i partial = _mm_madd_epi16(pairs, quad_weights);

    // Horizontal add folds lane 0 and lane 1 together into lane 0.
    const __m128i total = _mm_hadd_epi32(partial, partial);
    return _mm_cvtsi128_si32(total);
}
/*
 * Insert `x` into a 4-lane queue of 32-bit values, keeping the largest values.
 *
 * The insertion position is derived purely from HOW MANY lanes compare less
 * than `x` (signed 32-bit comparison), so the queue is expected to be kept in
 * ascending order from lane 0 to lane 3 -- which this function preserves when
 * it is the only writer. If k lanes are smaller than x:
 *   - k == 0: the queue is left untouched;
 *   - k >= 1: lane 0 is dropped, lanes 1..k-1 move down one slot, and x is
 *             written to lane k-1.
 *
 * pqueue: in/out pointer to the queue vector.
 * x:      value to insert.
 */
void pqueue_push(__m128i* pqueue, uint32_t x) {
    // One _mm_shuffle_epi8 control per possible count 1..4. A control byte of
    // -128 (0x80, high bit set) zeroes that destination byte; those four zeroed
    // bytes are exactly the lane where x will be placed.
    static const signed char shift_controls[4][16] = {
        { -128, -128, -128, -128,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15 },
        {  4,  5,  6,  7, -128, -128, -128, -128,  8,  9, 10, 11, 12, 13, 14, 15 },
        {  4,  5,  6,  7,  8,  9, 10, 11, -128, -128, -128, -128, 12, 13, 14, 15 },
        {  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, -128, -128, -128, -128 },
    };

    const __m128i queue = *pqueue;
    const __m128i broadcast_x = _mm_set1_epi32(x);

    // Each lane that is < x contributes four set bits to the byte mask.
    const __m128i below = _mm_cmplt_epi32(queue, broadcast_x);
    const int lanes_below = _mm_popcnt_u32(_mm_movemask_epi8(below)) / 4;
    if (lanes_below == 0) {
        return;
    }

    const __m128i control =
        _mm_loadu_si128((const __m128i *)shift_controls[lanes_below - 1]);

    // Bytes whose control is negative form the empty slot; keep x only there.
    const __m128i slot = _mm_cmplt_epi8(control, _mm_setzero_si128());
    const __m128i placed_x = _mm_and_si128(broadcast_x, slot);

    *pqueue = _mm_or_si128(placed_x, _mm_shuffle_epi8(queue, control));
}
/*
 * Return the sum of lanes 1, 2 and 3 of the queue (lane 0 is excluded).
 *
 * With the queue kept in ascending order by pqueue_push, these are the three
 * largest values stored.
 *
 * All intermediates are integer vectors (__m128i). Declaring them as the float
 * vector type __m128 only compiles when the compiler permits lax vector
 * conversions, and is rejected otherwise.
 */
int pqueue_top_3(__m128i pqueue) {
    // Clear lane 0 so it does not contribute to the horizontal sums below.
    const __m128i mask = _mm_setr_epi32(0, -1, -1, -1);
    const __m128i pqueue_top = _mm_and_si128(mask, pqueue);
    // First pass: lane 0 = 0 + q1, lane 1 = q2 + q3.
    const __m128i pair_sums = _mm_hadd_epi32(pqueue_top, pqueue_top);
    // Second pass: lane 0 = q1 + q2 + q3.
    return _mm_cvtsi128_si32(_mm_hadd_epi32(pair_sums, pair_sums));
}
/*
 * Return the 32-bit value held in lane 3 of the queue.
 * With the queue kept in ascending order by pqueue_push, this is the largest
 * value stored.
 */
int pqueue_top(__m128i pqueue) {
    // Broadcast lane 3 into every lane, then read it back out of lane 0.
    const __m128i highest = _mm_shuffle_epi32(pqueue, _MM_SHUFFLE(3, 3, 3, 3));
    return _mm_cvtsi128_si32(highest);
}
/*
 * Scalar fallback parser: accumulates the leading decimal digits found in
 * [p, line_end) and stops at the first non-digit byte. Used for lines the
 * SIMD parsers cannot take (wrong length, or too close to the end of the file
 * for a 16-byte load).
 */
static uint32_t parse_digits_scalar(const char *p, const char *line_end) {
    uint32_t value = 0;
    for (; p < line_end && *p >= '0' && *p <= '9'; ++p) {
        value = value * 10u + (uint32_t)(*p - '0');
    }
    return value;
}

/*
 * Reads "day-1.txt" from the current directory. The file is a sequence of
 * groups of decimal numbers, one number per line, groups separated by an empty
 * line. Prints the largest group sum and the sum of the three largest group
 * sums.
 *
 * Returns 0 on success, 1 if the file cannot be opened, inspected or mapped.
 * An empty file is treated as having no groups (both sums print as 0).
 */
int main(void) {
    const int input_fd = open("day-1.txt", O_RDONLY);
    if (input_fd < 0) {
        perror("open day-1.txt");
        return 1;
    }

    struct stat statbuf;
    if (fstat(input_fd, &statbuf) != 0) {
        perror("fstat day-1.txt");
        close(input_fd);
        return 1;
    }

    __m128i pqueue = _mm_set1_epi32(0);
    const size_t size = (size_t)statbuf.st_size;

    // mmap rejects a zero length, so only map (and scan) non-empty files.
    if (size > 0) {
        char *base = mmap(NULL, size, PROT_READ, MAP_PRIVATE, input_fd, 0);
        if (base == MAP_FAILED) {
            perror("mmap day-1.txt");
            close(input_fd);
            return 1;
        }

        const char *cursor = base;
        const char *const end = base + size;
        uint32_t current = 0;

        // Walk the file line by line using the real file size as the bound;
        // nothing past `end` is ever read.
        while (cursor < end) {
            const size_t remaining = (size_t)(end - cursor);
            const char *newline = memchr(cursor, '\n', remaining);
            const char *line_end = newline ? newline : end;
            const size_t len = (size_t)(line_end - cursor);

            if (len == 0) {
                // Empty line: the current group is complete.
                pqueue_push(&pqueue, current);
                current = 0;
            } else if ((len == 4 || len == 5) && remaining >= sizeof(__m128i)) {
                // Fast path: a full 16-byte load stays inside the file.
                const __m128i chunk = _mm_loadu_si128((const __m128i *)cursor);
                current += (len == 4) ? parse_4_digits(chunk) : parse_5_digits(chunk);
            } else {
                current += parse_digits_scalar(cursor, line_end);
            }

            cursor = newline ? newline + 1 : end;
        }

        // The last group is normally not followed by an empty line, so flush it
        // here. If it was already flushed, `current` is 0 and this is a no-op.
        pqueue_push(&pqueue, current);

        munmap(base, size);
    }

    close(input_fd);

    printf("Sum of Top : %d\nSum of Top 3: %d\n", pqueue_top(pqueue), pqueue_top_3(pqueue));
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment