Skip to content

Instantly share code, notes, and snippets.

@vurtun
Last active December 14, 2021 17:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vurtun/d08017d9b984dbee738da014c1203c2c to your computer and use it in GitHub Desktop.
Save vurtun/d08017d9b984dbee738da014c1203c2c to your computer and use it in GitHub Desktop.
// Background on SIMD index-of-min/max search: http://0x80.pl/notesen/2018-10-03-simd-index-of-min.html
/* cpy3(dst, src): copy src[0], src[1], src[2] into dst[0], dst[1], dst[2],
 * in that index order. Both arguments are evaluated several times, so pass
 * expressions without side effects. */
#define cpy3(dst, src) \
    ((dst)[0] = (src)[0], \
     (dst)[1] = (src)[1], \
     (dst)[2] = (src)[2])
/* dot3(u, v): three-component dot product, summed left to right as
 * (u[0]*v[0] + u[1]*v[1]) + u[2]*v[2]. Arguments are evaluated several times. */
#define dot3(u, v) \
    ((u)[0] * (v)[0] + \
     (u)[1] * (v)[1] + \
     (u)[2] * (v)[2])
/* alignto(n): request n-byte alignment on a variable declaration, spelled for
 * whichever compiler is detected: GCC/Clang attribute first, then MSVC
 * __declspec, otherwise the C11 _Alignas specifier. */
#if defined(__clang__) || defined(__GNUC__)
#  define alignto(n) __attribute__((aligned(n)))
#elif defined(_MSC_VER)
#  define alignto(n) __declspec(align(n))
#else
#  define alignto(n) _Alignas(n)
#endif
/* 16-byte alignment, suitable for 128-bit vector loads/stores. */
#define sse_align alignto(16)
/* 32-byte alignment, suitable for 256-bit vector loads/stores. */
#define avx_align alignto(32)
/*
 * Minimal 4-lane SIMD layer used by polyhedron_support().
 *   flt4 / int4            four float / four int lanes
 *   flt4_flt, int4_int     broadcast a scalar to all lanes
 *   int4_set(x,y,z,w)      lanes 0..3 = x,y,z,w
 *   flt4_ld3(p, w)         lanes = p[0], p[1], p[2], w (no alignment needed)
 *   flt4_str / int4_str    unaligned store of all four lanes
 *   flt4_mul / flt4_add    lane-wise arithmetic
 *   flt4_max               lane-wise maximum
 *   flt4_cmp_gt(a, b)      int4 mask: all bits set where a > b, else 0
 *   int4_blend(a, b, m)    per lane: m ? b : a (m must be an all-ones/zero mask)
 *   flt4_zip_lo32/hi32     interleave low/high 32-bit halves: {a0,b0,a1,b1} / {a2,b2,a3,b3}
 *   flt4_zip_lo64/hi64     interleave low/high 64-bit halves: {a0,a1,b0,b1} / {a2,a3,b2,b3}
 */
#if defined(__SSE2__) || defined(__x86_64__) || defined(_M_X64) || \
    (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#include <emmintrin.h>
#include <xmmintrin.h>
#include <smmintrin.h>
#define flt4 __m128
#define flt4_flt(a) _mm_set_ps1(a)
#define flt4_str(d,r) _mm_storeu_ps((float*)(d), (r))
#define flt4_max(a,b) _mm_max_ps((a), (b))
#define flt4_mul(a,b) _mm_mul_ps((a), (b))
#define flt4_add(a,b) _mm_add_ps((a), (b))
#define flt4_cmp_gt(a,b) _mm_castps_si128(_mm_cmpgt_ps((a), (b)))
#define flt4_zip_lo32(a,b) _mm_unpacklo_ps((a), (b))
#define flt4_zip_hi32(a,b) _mm_unpackhi_ps((a), (b))
#define flt4_zip_lo64(a,b) _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(a), _mm_castps_pd(b)))
#define flt4_zip_hi64(a,b) _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(a), _mm_castps_pd(b)))
#define int4 __m128i
#define int4_int(a) _mm_set1_epi32(a)
#define int4_set(x,y,z,w) _mm_setr_epi32((x), (y), (z), (w))
#define int4_add(a,b) _mm_add_epi32((a), (b))
#define int4_str(d,r) _mm_storeu_si128((__m128i*)(d), (r))
/* SSE2 select. _mm_blendv_epi8 is SSE4.1 and fails to build for the default
 * x86-64 target; with a full-lane mask this and/andnot/or form is equivalent. */
static inline int4
int4_blend(int4 a, int4 b, int4 m) {
    return _mm_or_si128(_mm_and_si128(m, b), _mm_andnot_si128(m, a));
}
static inline flt4
flt4_ld3(const float *val, float w) {
    return _mm_setr_ps(val[0], val[1], val[2], w);
}
#elif defined(__aarch64__)
/* vzip1q_f32 / vzip1q_f64 / vreinterpretq_f64_f32 are A64-only, so 32-bit ARM
 * takes the scalar fallback below instead. */
#include <arm_neon.h>
#define flt4 float32x4_t
#define flt4_flt(a) vdupq_n_f32(a)
#define flt4_str(d,r) vst1q_f32((float*)(d), (r))
#define flt4_max(a,b) vmaxnmq_f32((a), (b))
#define flt4_mul(a,b) vmulq_f32((a), (b))
#define flt4_add(a,b) vaddq_f32((a), (b))
/* vcgtq_f32 yields uint32x4_t; reinterpret so the mask has type int4. */
#define flt4_cmp_gt(a,b) vreinterpretq_s32_u32(vcgtq_f32((a), (b)))
#define flt4_zip_lo32(a,b) vzip1q_f32((a), (b))
#define flt4_zip_hi32(a,b) vzip2q_f32((a), (b))
#define flt4_zip_lo64(a,b) vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(a), vreinterpretq_f64_f32(b)))
#define flt4_zip_hi64(a,b) vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(a), vreinterpretq_f64_f32(b)))
#define int4 int32x4_t
#define int4_int(a) vdupq_n_s32(a)
#define int4_add(a,b) vaddq_s32((a), (b))
/* vbslq_s32 takes its selector as uint32x4_t. */
#define int4_blend(a,b,m) vbslq_s32(vreinterpretq_u32_s32(m), (b), (a))
#define int4_str(d,r) vst1q_s32((int*)(d), (r))
static inline int4
int4_set(int x, int y, int z, int w) {
    const int v[4] = {x, y, z, w};
    return vld1q_s32(v);
}
static inline flt4
flt4_ld3(const float *val, float w) {
    const float v[4] = {val[0], val[1], val[2], w};
    return vld1q_f32(v);
}
#else
/* Portable scalar emulation with the same lane semantics as the SSE2 path. */
typedef struct { float f[4]; } flt4;
typedef struct { int i[4]; } int4;
static inline flt4 flt4_flt(float a) { flt4 r = {{a, a, a, a}}; return r; }
static inline int4 int4_int(int a) { int4 r = {{a, a, a, a}}; return r; }
static inline int4 int4_set(int x, int y, int z, int w) { int4 r = {{x, y, z, w}}; return r; }
static inline flt4
flt4_ld3(const float *val, float w) {
    flt4 r = {{val[0], val[1], val[2], w}};
    return r;
}
static inline void flt4_str(float *d, flt4 r) { for (int k = 0; k < 4; ++k) d[k] = r.f[k]; }
static inline void int4_str(int *d, int4 r) { for (int k = 0; k < 4; ++k) d[k] = r.i[k]; }
static inline flt4
flt4_mul(flt4 a, flt4 b) {
    flt4 r;
    for (int k = 0; k < 4; ++k) r.f[k] = a.f[k] * b.f[k];
    return r;
}
static inline flt4
flt4_add(flt4 a, flt4 b) {
    flt4 r;
    for (int k = 0; k < 4; ++k) r.f[k] = a.f[k] + b.f[k];
    return r;
}
/* Like _mm_max_ps: yields b whenever a > b is false (including NaN). */
static inline flt4
flt4_max(flt4 a, flt4 b) {
    flt4 r;
    for (int k = 0; k < 4; ++k) r.f[k] = a.f[k] > b.f[k] ? a.f[k] : b.f[k];
    return r;
}
static inline int4
flt4_cmp_gt(flt4 a, flt4 b) {
    int4 r;
    for (int k = 0; k < 4; ++k) r.i[k] = a.f[k] > b.f[k] ? -1 : 0;
    return r;
}
static inline int4
int4_add(int4 a, int4 b) {
    int4 r;
    for (int k = 0; k < 4; ++k) r.i[k] = a.i[k] + b.i[k];
    return r;
}
static inline int4
int4_blend(int4 a, int4 b, int4 m) {
    int4 r;
    for (int k = 0; k < 4; ++k) r.i[k] = m.i[k] ? b.i[k] : a.i[k];
    return r;
}
static inline flt4 flt4_zip_lo32(flt4 a, flt4 b) { flt4 r = {{a.f[0], b.f[0], a.f[1], b.f[1]}}; return r; }
static inline flt4 flt4_zip_hi32(flt4 a, flt4 b) { flt4 r = {{a.f[2], b.f[2], a.f[3], b.f[3]}}; return r; }
static inline flt4 flt4_zip_lo64(flt4 a, flt4 b) { flt4 r = {{a.f[0], a.f[1], b.f[0], b.f[1]}}; return r; }
static inline flt4 flt4_zip_hi64(flt4 a, flt4 b) { flt4 r = {{a.f[2], a.f[3], b.f[2], b.f[3]}}; return r; }
#endif
/*
 * Support point of a vertex cloud: the vertex with the largest dot product
 * against `dir`.
 *
 * support : out, receives the three floats of the chosen vertex
 * dir     : search direction, three floats
 * verts   : `cnt` vertices stored as consecutive x,y,z floats (3*cnt floats)
 * cnt     : number of vertices
 *
 * Returns the index of the chosen vertex. If several vertices share the
 * maximal dot product, the lowest index is returned. All comparisons are
 * strict '>', so a vertex whose dot product is NaN never replaces the current
 * candidate. If cnt <= 0, returns -1 and leaves `support` untouched.
 */
extern int
polyhedron_support(float* restrict support,
                   const float* restrict dir,
                   const float* restrict verts, int cnt) {
    if (cnt <= 0) {
        return -1; /* nothing to read; previously verts[0..2] was read out of bounds */
    }
    int imax = 0;
    float dmax = dot3(verts, dir);
    const int4 inc = int4_int(4);
    const flt4 dx = flt4_flt(dir[0]);
    const flt4 dy = flt4_flt(dir[1]);
    const flt4 dz = flt4_flt(dir[2]);
    /* Per-lane running maximum and the first vertex index that reached it;
     * every lane starts from vertex 0. */
    int4 idx = int4_set(0, 1, 2, 3);
    int4 max_idx = int4_int(0);
    flt4 dmax4 = flt4_flt(dmax);
    const float *p = verts;
    int i = 0;

    /* Four vertices per iteration. `cnt - i > 3` cannot overflow (i <= cnt),
     * unlike `i + 3 < cnt`; stepping `p` avoids an int `i * 3` product. */
    for (; cnt - i > 3; i += 4, p += 12) {
        /* Load four xyz triples (w padded with 0) and transpose AoS -> SoA. */
        flt4 a = flt4_ld3(p + 0, 0.0f);
        flt4 b = flt4_ld3(p + 3, 0.0f);
        flt4 c = flt4_ld3(p + 6, 0.0f);
        flt4 d = flt4_ld3(p + 9, 0.0f);
        flt4 ab_xxyy = flt4_zip_lo32(a, b);
        flt4 cd_xxyy = flt4_zip_lo32(c, d);
        flt4 ab_zzww = flt4_zip_hi32(a, b);
        flt4 cd_zzww = flt4_zip_hi32(c, d);
        flt4 x = flt4_zip_lo64(ab_xxyy, cd_xxyy);
        flt4 y = flt4_zip_hi64(ab_xxyy, cd_xxyy);
        flt4 z = flt4_zip_lo64(ab_zzww, cd_zzww);
        flt4 dot = flt4_add(flt4_mul(dx, x), flt4_add(flt4_mul(dy, y), flt4_mul(dz, z)));
        /* Strictly greater keeps the earliest index per lane on ties. */
        int4 gt = flt4_cmp_gt(dot, dmax4);
        max_idx = int4_blend(max_idx, idx, gt);
        dmax4 = flt4_max(dot, dmax4);
        idx = int4_add(idx, inc);
    }
    /* Scalar tail for the remaining 0..3 vertices; strict '>' matches the lanes. */
    for (; i < cnt; ++i, p += 3) {
        float dot = dot3(p, dir);
        if (dot > dmax) {
            dmax = dot;
            imax = i;
        }
    }
    /* Merge the four lane winners into the scalar result, preferring the
     * lower vertex index when values are equal. */
    {
        float vals[4];
        int indexes[4];
        flt4_str(vals, dmax4);
        int4_str(indexes, max_idx);
        for (int k = 0; k < 4; ++k) {
            if (vals[k] > dmax || (vals[k] == dmax && indexes[k] < imax)) {
                dmax = vals[k];
                imax = indexes[k];
            }
        }
    }
    /* Widen before scaling so imax * 3 cannot overflow int. */
    cpy3(support, verts + (long long)imax * 3);
    return imax;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment