Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
A very fast nlerp variant with precision tweaks so that its results closely approximate slerp, plus SSE2 and AVX2 implementations.
#include <stdio.h>
#include <math.h>
#include <immintrin.h>
#include <chrono>
#include <type_traits>
#include <vector>
#ifdef IACA
#include <iacaMarks.h>
#else
#define IACA_START
#define IACA_END
#endif
#ifndef _MM_ALIGN
#define _MM_ALIGN __attribute__((aligned(16)))
#endif
#define FORCEINLINE __attribute__((always_inline))
// Quaternion stored as four packed floats: vector part (x, y, z) followed by the
// scalar part w. The SIMD kernels load one Q as one 4-float vector, so the member
// order and the absence of padding matter.
struct Q
{
    float x;
    float y;
    float z;
    float w;
};
// 4D dot product of two quaternions, accumulated in x, y, z, w order.
// For unit quaternions this is the cosine of the angle between them on the 4D sphere.
float dot(Q l, Q r)
{
    float sum = l.x * r.x;
    sum += l.y * r.y;
    sum += l.z * r.z;
    sum += l.w * r.w;
    return sum;
}
// Returns q scaled to unit length (q / |q|).
// q must be non-zero: a zero input gives 1/0 = inf, and 0 * inf = NaN components.
Q unit(Q q)
{
    const float invLen = 1 / sqrtf(dot(q, q));
    q.x *= invLen;
    q.y *= invLen;
    q.z *= invLen;
    q.w *= invLen;
    return q;
}
// Component-wise weighted sum l * lt + r * rt. No normalization is applied.
Q lerp(Q l, Q r, float lt, float rt)
{
    Q out;
    out.x = l.x * lt + r.x * rt;
    out.y = l.y * lt + r.y * rt;
    out.z = l.z * lt + r.z * rt;
    out.w = l.w * lt + r.w * rt;
    return out;
}
// Classic linear interpolation (1 - t) * l + t * r. No normalization or sign flip.
Q lerp(Q l, Q r, float t)
{
    const float leftWeight = 1 - t;
    return lerp(l, r, leftWeight, t);
}
// Normalized linear interpolation: unit((1 - t) * l + (+/-t) * r).
// r's weight is negated unless dot(l, r) > 0 (so a dot of exactly 0 also flips).
// This blends toward -r, which represents the same rotation as r and lies
// nearer to l on the 4D sphere.
Q nlerp(Q l, Q r, float t)
{
    const bool sameHemisphere = dot(l, r) > 0;
    return unit(lerp(l, r, 1 - t, sameHemisphere ? t : -t));
}
// Spherical linear interpolation from l toward r along the shortest arc, t in [0, 1].
//
// If dot(l, r) <= 0, r is replaced by -r (same rotation, nearer to l). The angle
// used for the sine weights must then be the angle between l and -r, i.e.
// acos(|dot|). Using acos(dot) with a negated weight mixes the long-arc weights
// with the short-arc target, giving non-unit, off-path results.
//
// The result is not renormalized. It has unit length up to rounding; in the
// nearly-parallel fallback (plain lerp weights) it is only approximately unit.
Q slerp(Q l, Q r, float t)
{
    float ca = dot(l, r);
    // Cosine of the angle between l and the nearer of r / -r; the angle is in [0, pi/2].
    float d = fabsf(ca);
    float lt, rt;
    if (d < 0.99999f)
    {
        float a = acosf(d);
        float rsa = 1 / sinf(a);
        lt = sinf((1 - t) * a) * rsa;
        rt = sinf(t * a) * rsa;
    }
    else
    {
        // Nearly parallel: sin(a) -> 0 makes the sine weights ill-conditioned,
        // so fall back to linear weights (they converge to the same values).
        lt = 1 - t;
        rt = t;
    }
    Q result = lerp(l, r, lt, ca > 0 ? rt : -rt);
    return result;
}
// Normalized slerp: the same sine weights as slerp but without the common
// 1 / sin(a) factor. That factor only scales the result's length, so the
// result is renormalized with unit() instead.
//
// As in slerp, the shortest arc is taken. When dot(l, r) <= 0, the weights use
// the angle acos(|dot|) between l and -r, matching the negated weight on r.
Q nslerp(Q l, Q r, float t)
{
    float ca = dot(l, r);
    float d = fabsf(ca);
    float lt, rt;
    if (d < 0.99999f)
    {
        float a = acosf(d);
        lt = sinf((1 - t) * a);
        rt = sinf(t * a);
    }
    else
    {
        // Nearly parallel: the sine weights both tend to 0; use linear weights instead.
        lt = 1 - t;
        rt = t;
    }
    Q result = unit(lerp(l, r, lt, ca > 0 ? rt : -rt));
    return result;
}
// Fast nlerp that approximates slerp. The parameter is first warped as
//   ot = t + t * (t - 0.5) * (t - 1) * k,
// a correction that vanishes at t = 0, 0.5 and 1. Its strength k is a quadratic
// in |dot(l, r)|. Plain nlerp (with the usual sign flip on r) is then applied at ot.
// The polynomial coefficients are presumably fitted against slerp offline.
Q fnlerp(Q l, Q r, float t)
{
    const float ca = dot(l, r);
    const float d = fabsf(ca);
    const float k = 0.931872f + d * (-1.25654f + d * 0.331442f);
    const float warp = t * (t - 0.5f) * (t - 1) * k;
    const float ot = t + warp;
    const float rt = ca > 0 ? ot : -ot;
    return unit(lerp(l, r, 1 - ot, rt));
}
// Higher-order variant of fnlerp. The warp strength also depends on t:
//   k = A(d) * (t - 0.5)^2 + B(d),  with d = |dot(l, r)|,
// where A is a cubic and B a quadratic in d (coefficients presumably fitted
// against slerp). Then ot = t + t * (t - 0.5) * (t - 1) * k, followed by nlerp at ot.
Q onlerp(Q l, Q r, float t)
{
    const float ca = dot(l, r);
    const float d = fabsf(ca);
    const float A = 1.0904f + d * (-3.2452f + d * (3.55645f - d * 1.43519f));
    const float B = 0.848013f + d * (-1.06021f + d * 0.215638f);
    const float th = t - 0.5f;
    const float k = A * th * th + B;
    const float ot = t + t * th * (t - 1) * k;
    const float rt = ca > 0 ? ot : -ot;
    return unit(lerp(l, r, 1 - ot, rt));
}
// SSE2 nlerp of four independent quaternion pairs:
//   result[i] = unit((1 - t_[i]) * l[i] + (+/-t_[i]) * r[i]),  i = 0..3.
//
// The inputs are transposed into structure-of-arrays form (one register per
// component x, y, z, w), so each arithmetic step handles all four pairs. The
// results are transposed back before storing.
//
// Loads and stores are unaligned because Q only guarantees 4-byte alignment.
// On Nehalem-and-later x86 CPUs, movups on aligned data costs the same as movaps.
//
// Differences from scalar nlerp():
//  - r's weight takes the sign bit of dot(l, r), so a dot of +0.0 keeps +t,
//    while nlerp() uses -t;
//  - 1/|u| is an rsqrtps estimate refined by one Newton-Raphson step, not 1/sqrtf.
void nlerp4(Q result[4], const Q l[4], const Q r[4], const float t_[4])
{
    __m128 signMask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000));
    __m128 l0 = _mm_loadu_ps(&l[0].x);
    __m128 l1 = _mm_loadu_ps(&l[1].x);
    __m128 l2 = _mm_loadu_ps(&l[2].x);
    __m128 l3 = _mm_loadu_ps(&l[3].x);
    __m128 r0 = _mm_loadu_ps(&r[0].x);
    __m128 r1 = _mm_loadu_ps(&r[1].x);
    __m128 r2 = _mm_loadu_ps(&r[2].x);
    __m128 r3 = _mm_loadu_ps(&r[3].x);
    __m128 t = _mm_loadu_ps(t_);
    // AoS -> SoA: afterwards l0 holds the four x components, l1 the y's, and so on.
    _MM_TRANSPOSE4_PS(l0, l1, l2, l3);
    _MM_TRANSPOSE4_PS(r0, r1, r2, r3);
    // ca[i] = dot(l[i], r[i])
    __m128 ca = _mm_add_ps(_mm_add_ps(_mm_mul_ps(l0, r0), _mm_mul_ps(l1, r1)), _mm_add_ps(_mm_mul_ps(l2, r2), _mm_mul_ps(l3, r3)));
    __m128 lt = _mm_sub_ps(_mm_set1_ps(1.f), t);
    // Flip the sign of t wherever ca's sign bit is set (branch-free hemisphere check).
    __m128 rt = _mm_xor_ps(t, _mm_and_ps(ca, signMask));
    // u = lt * l + rt * r
    __m128 u0 = _mm_add_ps(_mm_mul_ps(l0, lt), _mm_mul_ps(r0, rt));
    __m128 u1 = _mm_add_ps(_mm_mul_ps(l1, lt), _mm_mul_ps(r1, rt));
    __m128 u2 = _mm_add_ps(_mm_mul_ps(l2, lt), _mm_mul_ps(r2, rt));
    __m128 u3 = _mm_add_ps(_mm_mul_ps(l3, lt), _mm_mul_ps(r3, rt));
    // un = |u|^2
    __m128 un = _mm_add_ps(_mm_add_ps(_mm_mul_ps(u0, u0), _mm_mul_ps(u1, u1)), _mm_add_ps(_mm_mul_ps(u2, u2), _mm_mul_ps(u3, u3)));
    // 1/sqrt(un): low-precision rsqrtps estimate y0, refined once with
    // y1 = 0.5 * y0 * (3 - un * y0 * y0).
    __m128 us0 = _mm_rsqrt_ps(un);
    __m128 us1 = _mm_mul_ps(_mm_mul_ps(_mm_set1_ps(0.5f), us0), _mm_sub_ps(_mm_set1_ps(3.f), _mm_mul_ps(_mm_mul_ps(us0, us0), un)));
    __m128 n0 = _mm_mul_ps(u0, us1);
    __m128 n1 = _mm_mul_ps(u1, us1);
    __m128 n2 = _mm_mul_ps(u2, us1);
    __m128 n3 = _mm_mul_ps(u3, us1);
    // SoA -> AoS, then store one quaternion per register.
    _MM_TRANSPOSE4_PS(n0, n1, n2, n3);
    _mm_storeu_ps(&result[0].x, n0);
    _mm_storeu_ps(&result[1].x, n1);
    _mm_storeu_ps(&result[2].x, n2);
    _mm_storeu_ps(&result[3].x, n3);
}
// SSE2 version of fnlerp() for four independent quaternion pairs (i = 0..3).
//
// Per pair: d = |dot(l, r)|, k = 0.931872 - 1.25654 d + 0.331442 d^2,
// ot = t + t (t - 0.5) (t - 1) k, then nlerp at ot with r's weight sign-flipped
// when dot(l, r) is negative.
//
// Work is done in structure-of-arrays form (transpose in, transpose out).
// Loads and stores are unaligned because Q only guarantees 4-byte alignment.
// Products are grouped slightly differently from fnlerp(), and 1/|u| uses
// rsqrtps plus one Newton-Raphson step, so results may differ from the
// scalar version in the low bits.
void fnlerp4(Q result[4], const Q l[4], const Q r[4], const float t_[4])
{
    __m128 signMask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000));
    __m128 l0 = _mm_loadu_ps(&l[0].x);
    __m128 l1 = _mm_loadu_ps(&l[1].x);
    __m128 l2 = _mm_loadu_ps(&l[2].x);
    __m128 l3 = _mm_loadu_ps(&l[3].x);
    __m128 r0 = _mm_loadu_ps(&r[0].x);
    __m128 r1 = _mm_loadu_ps(&r[1].x);
    __m128 r2 = _mm_loadu_ps(&r[2].x);
    __m128 r3 = _mm_loadu_ps(&r[3].x);
    __m128 t = _mm_loadu_ps(t_);
    // AoS -> SoA: one register per component.
    _MM_TRANSPOSE4_PS(l0, l1, l2, l3);
    _MM_TRANSPOSE4_PS(r0, r1, r2, r3);
    __m128 ca = _mm_add_ps(_mm_add_ps(_mm_mul_ps(l0, r0), _mm_mul_ps(l1, r1)), _mm_add_ps(_mm_mul_ps(l2, r2), _mm_mul_ps(l3, r3)));
    // d = |ca| by clearing the sign bit.
    __m128 d = _mm_andnot_ps(signMask, ca);
    __m128 k = _mm_add_ps(_mm_set1_ps(0.931872f), _mm_mul_ps(d, _mm_add_ps(_mm_set1_ps(-1.25654f), _mm_mul_ps(_mm_set1_ps(0.331442f), d))));
    // ot = t + (t * (t - 0.5)) * ((t - 1) * k)
    __m128 ot = _mm_add_ps(t, _mm_mul_ps(_mm_mul_ps(t, _mm_sub_ps(t, _mm_set1_ps(0.5f))), _mm_mul_ps(_mm_sub_ps(t, _mm_set1_ps(1.f)), k)));
    __m128 lt = _mm_sub_ps(_mm_set1_ps(1.f), ot);
    // Flip the sign of ot wherever ca's sign bit is set.
    __m128 rt = _mm_xor_ps(ot, _mm_and_ps(ca, signMask));
    __m128 u0 = _mm_add_ps(_mm_mul_ps(l0, lt), _mm_mul_ps(r0, rt));
    __m128 u1 = _mm_add_ps(_mm_mul_ps(l1, lt), _mm_mul_ps(r1, rt));
    __m128 u2 = _mm_add_ps(_mm_mul_ps(l2, lt), _mm_mul_ps(r2, rt));
    __m128 u3 = _mm_add_ps(_mm_mul_ps(l3, lt), _mm_mul_ps(r3, rt));
    __m128 un = _mm_add_ps(_mm_add_ps(_mm_mul_ps(u0, u0), _mm_mul_ps(u1, u1)), _mm_add_ps(_mm_mul_ps(u2, u2), _mm_mul_ps(u3, u3)));
    // 1/sqrt(un): rsqrtps estimate plus one Newton-Raphson step.
    __m128 us0 = _mm_rsqrt_ps(un);
    __m128 us1 = _mm_mul_ps(_mm_mul_ps(_mm_set1_ps(0.5f), us0), _mm_sub_ps(_mm_set1_ps(3.f), _mm_mul_ps(_mm_mul_ps(us0, us0), un)));
    __m128 n0 = _mm_mul_ps(u0, us1);
    __m128 n1 = _mm_mul_ps(u1, us1);
    __m128 n2 = _mm_mul_ps(u2, us1);
    __m128 n3 = _mm_mul_ps(u3, us1);
    // SoA -> AoS and store.
    _MM_TRANSPOSE4_PS(n0, n1, n2, n3);
    _mm_storeu_ps(&result[0].x, n0);
    _mm_storeu_ps(&result[1].x, n1);
    _mm_storeu_ps(&result[2].x, n2);
    _mm_storeu_ps(&result[3].x, n3);
}
// SSE2 version of onlerp() for four independent quaternion pairs (i = 0..3).
//
// Per pair: d = |dot(l, r)|, th = t - 0.5,
//   A = 1.0904 - 3.2452 d + 3.55645 d^2 - 1.43519 d^3,
//   B = 0.848013 - 1.06021 d + 0.215638 d^2,
//   k = A * th^2 + B,  ot = t + t * th * (t - 1) * k,
// then nlerp at ot with r's weight sign-flipped when dot(l, r) is negative.
//
// Loads and stores are unaligned because Q only guarantees 4-byte alignment.
// Operation grouping and the rsqrtps + Newton-Raphson normalization mean the
// results may differ from scalar onlerp() in the low bits.
void onlerp4(Q result[4], const Q l[4], const Q r[4], const float t_[4])
{
    __m128 signMask = _mm_castsi128_ps(_mm_set1_epi32(0x80000000));
    __m128 l0 = _mm_loadu_ps(&l[0].x);
    __m128 l1 = _mm_loadu_ps(&l[1].x);
    __m128 l2 = _mm_loadu_ps(&l[2].x);
    __m128 l3 = _mm_loadu_ps(&l[3].x);
    __m128 r0 = _mm_loadu_ps(&r[0].x);
    __m128 r1 = _mm_loadu_ps(&r[1].x);
    __m128 r2 = _mm_loadu_ps(&r[2].x);
    __m128 r3 = _mm_loadu_ps(&r[3].x);
    __m128 t = _mm_loadu_ps(t_);
    // AoS -> SoA: one register per component.
    _MM_TRANSPOSE4_PS(l0, l1, l2, l3);
    _MM_TRANSPOSE4_PS(r0, r1, r2, r3);
    __m128 ca = _mm_add_ps(_mm_add_ps(_mm_mul_ps(l0, r0), _mm_mul_ps(l1, r1)), _mm_add_ps(_mm_mul_ps(l2, r2), _mm_mul_ps(l3, r3)));
    // d = |ca| by clearing the sign bit.
    __m128 d = _mm_andnot_ps(signMask, ca);
    __m128 th = _mm_sub_ps(t, _mm_set1_ps(0.5f));
    // A and B evaluated in Horner form.
    __m128 A = _mm_add_ps(_mm_set1_ps(1.0904f), _mm_mul_ps(d, _mm_add_ps(_mm_set1_ps(-3.2452f), _mm_mul_ps(d, _mm_add_ps(_mm_set1_ps(3.55645f), _mm_mul_ps(d, _mm_set1_ps(-1.43519f)))))));
    __m128 B = _mm_add_ps(_mm_set1_ps(0.848013f), _mm_mul_ps(d, _mm_add_ps(_mm_set1_ps(-1.06021f), _mm_mul_ps(d, _mm_set1_ps(0.215638f)))));
    __m128 k = _mm_add_ps(_mm_mul_ps(A, _mm_mul_ps(th, th)), B);
    __m128 ot = _mm_add_ps(t, _mm_mul_ps(_mm_mul_ps(t, th), _mm_mul_ps(_mm_sub_ps(t, _mm_set1_ps(1.f)), k)));
    __m128 lt = _mm_sub_ps(_mm_set1_ps(1.f), ot);
    // Flip the sign of ot wherever ca's sign bit is set.
    __m128 rt = _mm_xor_ps(ot, _mm_and_ps(ca, signMask));
    __m128 u0 = _mm_add_ps(_mm_mul_ps(l0, lt), _mm_mul_ps(r0, rt));
    __m128 u1 = _mm_add_ps(_mm_mul_ps(l1, lt), _mm_mul_ps(r1, rt));
    __m128 u2 = _mm_add_ps(_mm_mul_ps(l2, lt), _mm_mul_ps(r2, rt));
    __m128 u3 = _mm_add_ps(_mm_mul_ps(l3, lt), _mm_mul_ps(r3, rt));
    __m128 un = _mm_add_ps(_mm_add_ps(_mm_mul_ps(u0, u0), _mm_mul_ps(u1, u1)), _mm_add_ps(_mm_mul_ps(u2, u2), _mm_mul_ps(u3, u3)));
    // 1/sqrt(un): rsqrtps estimate plus one Newton-Raphson step.
    __m128 us0 = _mm_rsqrt_ps(un);
    __m128 us1 = _mm_mul_ps(_mm_mul_ps(_mm_set1_ps(0.5f), us0), _mm_sub_ps(_mm_set1_ps(3.f), _mm_mul_ps(_mm_mul_ps(us0, us0), un)));
    __m128 n0 = _mm_mul_ps(u0, us1);
    __m128 n1 = _mm_mul_ps(u1, us1);
    __m128 n2 = _mm_mul_ps(u2, us1);
    __m128 n3 = _mm_mul_ps(u3, us1);
    // SoA -> AoS and store.
    _MM_TRANSPOSE4_PS(n0, n1, n2, n3);
    _mm_storeu_ps(&result[0].x, n0);
    _mm_storeu_ps(&result[1].x, n1);
    _mm_storeu_ps(&result[2].x, n2);
    _mm_storeu_ps(&result[3].x, n3);
}
// Transposes four __m256 rows as two independent 4x4 blocks, one per 128-bit lane.
// Lane 0 of row0..row3 is transposed exactly like _MM_TRANSPOSE4_PS, and lane 1
// likewise; no data moves between lanes (AVX unpack/shuffle are lane-local).
//
// _mm256_shuffle_ps applies one 8-bit immediate to both lanes, so the selectors
// are lane-relative 4x4 indices. _MM_SHUFFLE(1, 0, 1, 0) keeps elements {0, 1}
// of each source and _MM_SHUFFLE(3, 2, 3, 2) keeps elements {2, 3}. Indices above
// 3 would overflow the 8-bit immediate.
#define _MM_TRANSPOSE8_LANE4_PS(row0, row1, row2, row3) \
    do { \
        __m256 tr8_lo01, tr8_hi01, tr8_lo23, tr8_hi23; \
        tr8_lo01 = _mm256_unpacklo_ps(row0, row1); \
        tr8_hi01 = _mm256_unpackhi_ps(row0, row1); \
        tr8_lo23 = _mm256_unpacklo_ps(row2, row3); \
        tr8_hi23 = _mm256_unpackhi_ps(row2, row3); \
        row0 = _mm256_shuffle_ps(tr8_lo01, tr8_lo23, _MM_SHUFFLE(1, 0, 1, 0)); \
        row1 = _mm256_shuffle_ps(tr8_lo01, tr8_lo23, _MM_SHUFFLE(3, 2, 3, 2)); \
        row2 = _mm256_shuffle_ps(tr8_hi01, tr8_hi23, _MM_SHUFFLE(1, 0, 1, 0)); \
        row3 = _mm256_shuffle_ps(tr8_hi01, tr8_hi23, _MM_SHUFFLE(3, 2, 3, 2)); \
    } while (0)
// AVX2 version of onlerp() for eight independent quaternion pairs (i = 0..7).
// The math per pair is identical to onlerp4(); see onlerp() for the formula.
//
// Each 256-bit load covers two consecutive quaternions. Loads and stores are
// unaligned: 256-bit aligned access needs 32-byte alignment, which neither Q
// (4-byte aligned) nor std::vector's default allocator guarantees.
// Requires AVX2 (_mm256_permutevar8x32_ps); build with -mavx2 or equivalent.
void onlerp8(Q result[8], const Q l[8], const Q r[8], const float t_[8])
{
    __m256 signMask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000));
    __m256 l0 = _mm256_loadu_ps(&l[0].x);
    __m256 l1 = _mm256_loadu_ps(&l[2].x);
    __m256 l2 = _mm256_loadu_ps(&l[4].x);
    __m256 l3 = _mm256_loadu_ps(&l[6].x);
    __m256 r0 = _mm256_loadu_ps(&r[0].x);
    __m256 r1 = _mm256_loadu_ps(&r[2].x);
    __m256 r2 = _mm256_loadu_ps(&r[4].x);
    __m256 r3 = _mm256_loadu_ps(&r[6].x);
    // The per-lane transpose leaves element slots in quaternion order
    //   q0 q2 q4 q6 | q1 q3 q5 q7
    // so t is permuted the same way (result[j] = t_[idx[j]]).
    __m256 tt = _mm256_loadu_ps(t_);
    __m256 t = _mm256_permutevar8x32_ps(tt, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7));
    _MM_TRANSPOSE8_LANE4_PS(l0, l1, l2, l3);
    _MM_TRANSPOSE8_LANE4_PS(r0, r1, r2, r3);
    __m256 ca = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(l0, r0), _mm256_mul_ps(l1, r1)), _mm256_add_ps(_mm256_mul_ps(l2, r2), _mm256_mul_ps(l3, r3)));
    // d = |ca| by clearing the sign bit.
    __m256 d = _mm256_andnot_ps(signMask, ca);
    __m256 th = _mm256_sub_ps(t, _mm256_set1_ps(0.5f));
    __m256 A = _mm256_add_ps(_mm256_set1_ps(1.0904f), _mm256_mul_ps(d, _mm256_add_ps(_mm256_set1_ps(-3.2452f), _mm256_mul_ps(d, _mm256_add_ps(_mm256_set1_ps(3.55645f), _mm256_mul_ps(d, _mm256_set1_ps(-1.43519f)))))));
    __m256 B = _mm256_add_ps(_mm256_set1_ps(0.848013f), _mm256_mul_ps(d, _mm256_add_ps(_mm256_set1_ps(-1.06021f), _mm256_mul_ps(d, _mm256_set1_ps(0.215638f)))));
    __m256 k = _mm256_add_ps(_mm256_mul_ps(A, _mm256_mul_ps(th, th)), B);
    __m256 ot = _mm256_add_ps(t, _mm256_mul_ps(_mm256_mul_ps(t, th), _mm256_mul_ps(_mm256_sub_ps(t, _mm256_set1_ps(1.f)), k)));
    __m256 lt = _mm256_sub_ps(_mm256_set1_ps(1.f), ot);
    // Flip the sign of ot wherever ca's sign bit is set.
    __m256 rt = _mm256_xor_ps(ot, _mm256_and_ps(ca, signMask));
    __m256 u0 = _mm256_add_ps(_mm256_mul_ps(l0, lt), _mm256_mul_ps(r0, rt));
    __m256 u1 = _mm256_add_ps(_mm256_mul_ps(l1, lt), _mm256_mul_ps(r1, rt));
    __m256 u2 = _mm256_add_ps(_mm256_mul_ps(l2, lt), _mm256_mul_ps(r2, rt));
    __m256 u3 = _mm256_add_ps(_mm256_mul_ps(l3, lt), _mm256_mul_ps(r3, rt));
    __m256 un = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(u0, u0), _mm256_mul_ps(u1, u1)), _mm256_add_ps(_mm256_mul_ps(u2, u2), _mm256_mul_ps(u3, u3)));
    // 1/sqrt(un): rsqrt estimate plus one Newton-Raphson step.
    __m256 us0 = _mm256_rsqrt_ps(un);
    __m256 us1 = _mm256_mul_ps(_mm256_mul_ps(_mm256_set1_ps(0.5f), us0), _mm256_sub_ps(_mm256_set1_ps(3.f), _mm256_mul_ps(_mm256_mul_ps(us0, us0), un)));
    __m256 n0 = _mm256_mul_ps(u0, us1);
    __m256 n1 = _mm256_mul_ps(u1, us1);
    __m256 n2 = _mm256_mul_ps(u2, us1);
    __m256 n3 = _mm256_mul_ps(u3, us1);
    // Transposing back restores two consecutive quaternions per register.
    _MM_TRANSPOSE8_LANE4_PS(n0, n1, n2, n3);
    _mm256_storeu_ps(&result[0].x, n0);
    _mm256_storeu_ps(&result[2].x, n1);
    _mm256_storeu_ps(&result[4].x, n2);
    _mm256_storeu_ps(&result[6].x, n3);
}
// Quaternion for a rotation by angle a (radians) about axis (x, y, z):
//   (x, y, z) * sin(a / 2), w = cos(a / 2).
// The axis is not normalized here. The result is a unit quaternion only if
// (x, y, z) has unit length (or sin(a / 2) == 0).
Q axisangle(float x, float y, float z, float a)
{
    const float half = a / 2;
    const float s = sinf(half);
    return { x * s, y * s, z * s, cosf(half) };
}
template <int N, typename F> struct lerparray
{
__attribute__((noinline))
static void run(Q* dest, const Q& l, const Q* r, const float* t, size_t size, F f)
{
Q ln[N];
for (int i = 0; i < N; ++i)
ln[i] = l;
for (size_t i = 0; i < size; i += N)
{
IACA_START
f(&dest[i], ln, &r[i], &t[i]);
IACA_END;
}
}
};
// Scalar specialization: dest[i] = f(l, r[i], t[i]) for every i in [0, size).
template <typename F> struct lerparray<1, F>
{
    __attribute__((noinline))
    static void run(Q* dest, const Q& l, const Q* r, const float* t, size_t size, F f)
    {
        size_t i = 0;
        while (i < size)
        {
            IACA_START
            dest[i] = f(l, r[i], t[i]);
            IACA_END;
            ++i;
        }
    }
};
template <int N, typename F>
void time(const char* name, F f)
{
std::vector<Q> rv;
std::vector<float> tv;
std::vector<float> av;
Q l = axisangle(1, 0, 0, 0);
for (double a = 0; a <= 3.1415926; a += 1e-3)
for (double t = 0; t <= 1; t += 1e-3)
{
rv.push_back(axisangle(1, 0, 0, a));
tv.push_back(t);
av.push_back(a * t);
}
// Round to 16 elements to match data between SIMD and non-SIMD versions
while (rv.size() % 16 != 0)
{
rv.push_back({0, 0, 0, 1});
tv.push_back(0);
av.push_back(0);
}
std::vector<Q> mv(rv.size());
lerparray<N, F>::run(mv.data(), l, rv.data(), tv.data(), rv.size(), f);
// Note: this measurement is inaccurate and requires precise clock() - won't work well on Windows.
// Don't trust the numbers produced by the timer too much.
clock_t start = clock();
lerparray<N, F>::run(mv.data(), l, rv.data(), tv.data(), rv.size(), f);
clock_t end = clock();
size_t maxi = 0;
double maxe = 0;
double sume = 0;
double nume = 0;
for (size_t i = 0; i < rv.size(); ++i)
{
Q m = mv[i];
// slerp is awesome; we need to fix it to not get NaN errors
if (m.w < 0) m.w = 0;
if (m.w > 1) m.w = 1;
double e = fabs(acos(m.w) * 2 - av[i]);
if (e > maxe)
{
maxe = e;
maxi = i;
}
sume += e;
nume += 1;
}
double avge = sume / nume;
printf("%s: %f us, %e max, %e avg\n", name, double(end - start) / CLOCKS_PER_SEC * 1e9 / rv.size(), maxe, avge);
}
// Benchmark shorthands. Each wraps `fun` in a captureless lambda (giving time<>
// a concrete closure type to instantiate lerparray with) and labels the output
// with the function's own name.
//   TIME(fun)          - scalar:  Q fun(Q, Q, float)
//   TIME4 / TIME8(fun) - batched: void fun(Q* dest, const Q* l, const Q* r, const float* t)
#define TIME(fun) time<1>(#fun, [](Q a, Q b, float s) { return fun(a, b, s); })
#define TIME_BATCH(n, fun) time<n>(#fun, [](Q* out, const Q* a, const Q* b, const float* s) { fun(out, a, b, s); })
#define TIME4(fun) TIME_BATCH(4, fun)
#define TIME8(fun) TIME_BATCH(8, fun)
// Runs every benchmark in turn. When IACA is defined, only the AVX2 kernel is
// instantiated, presumably so that a single IACA_START/IACA_END pair ends up in
// the binary for the analyzer.
int main()
{
#if defined(IACA)
    TIME8(onlerp8);
#else
    // Scalar reference implementations.
    TIME(slerp);
    TIME(nslerp);
    TIME(nlerp);
    // Scalar nlerp variants with slerp-matching corrections.
    TIME(fnlerp);
    TIME(onlerp);
    // SSE2, four quaternions per call.
    TIME4(nlerp4);
    TIME4(fnlerp4);
    TIME4(onlerp4);
    // AVX2, eight quaternions per call.
    TIME8(onlerp8);
#endif
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment