@jart
Last active May 22, 2024 22:07
glibc vectorized expf() versus my faster intrinsic vectorized expf() for x86-64
/* Function expf vectorized with AVX2.
Copyright (C) 2014-2024 Free Software Foundation, Inc.
This file is part of the GNU C Library.
The GNU C Library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
The GNU C Library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with the GNU C Library; if not, see
<https://www.gnu.org/licenses/>. */
#define __sInvLn2 0
#define __sShifter 64
#define __sLn2hi 128
#define __sLn2lo 192
#define __iBias 256
#define __sPC0 320
#define __sPC1 384
#define __sPC2 448
#define __sPC3 512
#define __sPC4 576
#define __sPC5 640
#define __iAbsMask 704
#define __iDomainRange 768
.macro float_vector offset value
.if .-__svml_sexp_data != \offset
.err
.endif
.rept 16
.long \value
.endr
.endm
.section .rodata, "a"
.align 64
/* Data table for vector implementations of function expf.
The table may contain polynomial, reduction, lookup coefficients and
other coefficients obtained through different methods of research and
experimental work. */
.globl __svml_sexp_data
__svml_sexp_data:
/* Range reduction coefficients:
* log(2) inverted */
float_vector __sInvLn2 0x3fb8aa3b
/* right shifter constant */
float_vector __sShifter 0x4b400000
/* log(2) high part */
float_vector __sLn2hi 0x3f317200
/* log(2) low part */
float_vector __sLn2lo 0x35bfbe8e
/* bias */
float_vector __iBias 0x0000007f
/* Polynomial coefficients:
* Here we approximate 2^x on [-0.5, 0.5] */
float_vector __sPC0 0x3f800000
float_vector __sPC1 0x3f7ffffe
float_vector __sPC2 0x3effff34
float_vector __sPC3 0x3e2aacac
float_vector __sPC4 0x3d2b8392
float_vector __sPC5 0x3c07d9fe
/* absolute value mask */
float_vector __iAbsMask 0x7fffffff
/* working domain range */
float_vector __iDomainRange 0x42aeac4f
.type __svml_sexp_data,@object
.size __svml_sexp_data,.-__svml_sexp_data
.previous
.section .text.avx2, "ax", @progbits
libmvec_expf_avx2:
/*
ALGORITHM DESCRIPTION:
Argument representation:
M = rint(X*2^k/ln2) = 2^k*N+j
X = M*ln2/2^k + r = N*ln2 + ln2*(j/2^k) + r
then -ln2/2^(k+1) < r < ln2/2^(k+1)
Alternatively:
M = trunc(X*2^k/ln2)
then 0 < r < ln2/2^k
Result calculation:
exp(X) = exp(N*ln2 + ln2*(j/2^k) + r)
= 2^N * 2^(j/2^k) * exp(r)
2^N is calculated by bit manipulation
2^(j/2^k) is computed from table lookup
exp(r) is approximated by polynomial
The table lookup is skipped if k = 0.
For low accuracy approximation, exp(r) ~ 1 or 1+r. */
pushq %rbp
movq %rsp, %rbp
andq $-64, %rsp
subq $448, %rsp
lea __svml_sexp_data(%rip), %rax
vmovaps %ymm0, %ymm2
vmovups __sInvLn2(%rax), %ymm7
vmovups __sShifter(%rax), %ymm4
vmovups __sLn2hi(%rax), %ymm3
vmovups __sPC5(%rax), %ymm1
/* m = x*2^k/ln2 + shifter */
vfmadd213ps %ymm4, %ymm2, %ymm7
/* n = m - shifter = rint(x*2^k/ln2) */
vsubps %ymm4, %ymm7, %ymm0
vpaddd __iBias(%rax), %ymm7, %ymm4
/* remove sign of x by "and" operation */
vandps __iAbsMask(%rax), %ymm2, %ymm5
/* compare against threshold */
vpcmpgtd __iDomainRange(%rax), %ymm5, %ymm6
/* r = x-n*ln2_hi/2^k */
vmovaps %ymm2, %ymm5
vfnmadd231ps %ymm0, %ymm3, %ymm5
/* r = r-n*ln2_lo/2^k = x - n*ln2/2^k */
vfnmadd132ps __sLn2lo(%rax), %ymm5, %ymm0
/* c5*r+c4 */
vfmadd213ps __sPC4(%rax), %ymm0, %ymm1
/* (c5*r+c4)*r+c3 */
vfmadd213ps __sPC3(%rax), %ymm0, %ymm1
/* ((c5*r+c4)*r+c3)*r+c2 */
vfmadd213ps __sPC2(%rax), %ymm0, %ymm1
/* (((c5*r+c4)*r+c3)*r+c2)*r+c1 */
vfmadd213ps __sPC1(%rax), %ymm0, %ymm1
/* exp(r) = ((((c5*r+c4)*r+c3)*r+c2)*r+c1)*r+c0 */
vfmadd213ps __sPC0(%rax), %ymm0, %ymm1
/* set mask for overflow/underflow */
vmovmskps %ymm6, %ecx
/* compute 2^N with "shift" */
vpslld $23, %ymm4, %ymm6
/* 2^N*exp(r) */
vmulps %ymm1, %ymm6, %ymm0
testl %ecx, %ecx
jne .LBL_1_3
.LBL_1_2:
movq %rbp, %rsp
popq %rbp
ret
.LBL_1_3:
vmovups %ymm2, 320(%rsp)
vmovups %ymm0, 384(%rsp)
je .LBL_1_2
xorb %dl, %dl
xorl %eax, %eax
vmovups %ymm8, 224(%rsp)
vmovups %ymm9, 192(%rsp)
vmovups %ymm10, 160(%rsp)
vmovups %ymm11, 128(%rsp)
vmovups %ymm12, 96(%rsp)
vmovups %ymm13, 64(%rsp)
vmovups %ymm14, 32(%rsp)
vmovups %ymm15, (%rsp)
movq %rsi, 264(%rsp)
movq %rdi, 256(%rsp)
movq %r12, 296(%rsp)
movb %dl, %r12b
movq %r13, 288(%rsp)
movl %ecx, %r13d
movq %r14, 280(%rsp)
movl %eax, %r14d
movq %r15, 272(%rsp)
.LBL_1_6:
btl %r14d, %r13d
jc .LBL_1_12
.LBL_1_7:
lea 1(%r14), %esi
btl %esi, %r13d
jc .LBL_1_10
.LBL_1_8:
incb %r12b
addl $2, %r14d
cmpb $16, %r12b
jb .LBL_1_6
vmovups 224(%rsp), %ymm8
vmovups 192(%rsp), %ymm9
vmovups 160(%rsp), %ymm10
vmovups 128(%rsp), %ymm11
vmovups 96(%rsp), %ymm12
vmovups 64(%rsp), %ymm13
vmovups 32(%rsp), %ymm14
vmovups (%rsp), %ymm15
vmovups 384(%rsp), %ymm0
movq 264(%rsp), %rsi
movq 256(%rsp), %rdi
movq 296(%rsp), %r12
movq 288(%rsp), %r13
movq 280(%rsp), %r14
movq 272(%rsp), %r15
jmp .LBL_1_2
.LBL_1_10:
movzbl %r12b, %r15d
vmovss 324(%rsp,%r15,8), %xmm0
vzeroupper
call expf
vmovss %xmm0, 388(%rsp,%r15,8)
jmp .LBL_1_8
.LBL_1_12:
movzbl %r12b, %r15d
vmovss 320(%rsp,%r15,8), %xmm0
vzeroupper
call expf
vmovss %xmm0, 384(%rsp,%r15,8)
jmp .LBL_1_7
.size libmvec_expf_avx2,.-libmvec_expf_avx2
.type libmvec_expf_avx2,@function
.globl libmvec_expf_avx2
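For reference, the vector path above boils down to the following scalar sketch (my illustration, not part of glibc; it reuses the table constants, ignores FMA fusing, and omits the fallback that calls scalar expf on overflow/underflow lanes):

static float
punf (unsigned i) /* hypothetical helper: reinterpret bits as float */
{
  union { unsigned i; float f; } u = { i };
  return u.f;
}

static float
expf_scalar_sketch (float x)
{
  union { float f; unsigned i; } m;
  /* m = x/ln2 + 1.5*2^23: rint(x/ln2) lands in the low mantissa bits */
  m.f = x * punf (0x3fb8aa3b) + punf (0x4b400000);
  float n = m.f - punf (0x4b400000);     /* n = rint(x/ln2) */
  float r = x - n * punf (0x3f317200);   /* subtract n*ln2 (high part) */
  r -= n * punf (0x35bfbe8e);            /* subtract n*ln2 (low part) */
  float p = punf (0x3c07d9fe);           /* c5 */
  p = p * r + punf (0x3d2b8392);         /* c4 */
  p = p * r + punf (0x3e2aacac);         /* c3 */
  p = p * r + punf (0x3effff34);         /* c2 */
  p = p * r + punf (0x3f7ffffe);         /* c1 */
  p = p * r + punf (0x3f800000);         /* c0: p ~= exp(r) */
  unsigned scale = (m.i + 0x7f) << 23;   /* bit pattern of 2^n */
  return punf (scale) * p;               /* exp(x) = 2^n * exp(r) */
}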
// faster vectorized expf() for x86-64
// based on arm limited optimized routine
// written by justine tunney for llamafile
#include <assert.h>
#include <immintrin.h>
#include <math.h>
#include <stdckdint.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
namespace
{
const char debruijn[64] = {
0, 47, 1, 56, 48, 27, 2, 60, 57, 49, 41, 37, 28, 16, 3, 61,
54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4, 62,
46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45,
25, 39, 14, 33, 19, 30, 9, 24, 13, 18, 8, 12, 7, 6, 5, 63,
};
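// index of the highest set bit of x (bit scan reverse),
// via bit smearing and a de Bruijn multiply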
int
bsrll (long long x)
{
x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;
x |= x >> 32;
return debruijn[(x * 0x03f79d71b4cb0a89ull) >> 58];
}
unsigned
pun (float f)
{
union
{
float f;
unsigned i;
} u = { f };
return u.i;
}
float
pun (unsigned i)
{
union
{
unsigned i;
float f;
} u = { i };
return u.f;
}
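// absolute difference between the bit patterns of a and b,
// i.e. their distance in ULPs when both have the same sign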
long long
ulp (float a, float b)
{
long long ai = pun (a);
long long bi = pun (b);
long long di = ai - bi;
if (di < 0)
di = -di;
return di;
}
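// number of bits needed to express the ULP error (0 means identical bits)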
int
bad (float a, float b)
{
long long d = ulp (a, b);
return d ? bsrll (d) + 1 : 0;
}
float
float01 (unsigned x)
{ // (0,1)
return 1.f / 8388608 * ((x >> 9) + .5f);
}
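// numba() below calls lemur64(), which is not defined in this file; the
// original build presumably links one in (cosmopolitan libc ships a PRNG
// by that name). A minimal stand-in under that assumption:
unsigned long long
lemur64 (void)
{
  // 128-bit Lehmer multiplicative congruential generator
  static unsigned __int128 s = 2131259787901769494;
  s *= 15750249268501108917ull;
  return s >> 64;
}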
float
numba (void)
{ // (-1,1)
return float01 (lemur64 ()) * 2 - 1;
}
struct timespec
now (void)
{
struct timespec ts;
clock_gettime (CLOCK_REALTIME, &ts);
return ts;
}
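// timespec subtraction: returns a - b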
struct timespec
tub (struct timespec a, struct timespec b)
{
a.tv_sec -= b.tv_sec;
if (a.tv_nsec < b.tv_nsec)
{
a.tv_nsec += 1000000000;
a.tv_sec--;
}
a.tv_nsec -= b.tv_nsec;
return a;
}
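// converts a timespec to nanoseconds, saturating on overflow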
long long
tano (struct timespec x)
{
long long ns;
if (!ckd_mul (&ns, x.tv_sec, 1000000000ul) && !ckd_add (&ns, ns, x.tv_nsec))
return ns;
else if (x.tv_sec < 0)
return INT64_MIN;
else
return INT64_MAX;
}
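// fused multiply-add helpers; without FMA they fall back to a separate
// multiply and add, which costs one extra rounding per operation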
#ifdef __FMA__
#define MADD128(x, y, z) _mm_fmadd_ps (x, y, z)
#define NMADD128(x, y, z) _mm_fnmadd_ps (x, y, z)
#else
#define MADD128(x, y, z) _mm_add_ps (_mm_mul_ps (x, y), z)
#define NMADD128(x, y, z) _mm_sub_ps (z, _mm_mul_ps (x, y))
#endif
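// The v_expf_* kernels below share one algorithm (per the header comment,
// based on Arm Limited's optimized-routines expf): z = x/ln2 + 0x1.8p23f
// parks rint(x/ln2) in the low mantissa bits of z, n = z - 0x1.8p23f
// recovers it as a float, and b = x - n*ln2 (ln2 split into 0x1.62e4p-1f +
// 0x1.7f7d1cp-20f) is the reduced argument. k = 2^n comes from shifting z's
// integer bits into the exponent field, a degree-5 polynomial j(b)
// approximates expm1(b), and the common-case result is k + k*j. Lanes with
// |n| > 126 instead scale in two steps (s1, s2) so subnormal and
// near-overflow results stay correct, and lanes with |n| > 192 saturate to
// 0 or +inf via s1*s1.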
#ifdef __SSE2__
__m128
v_expf_sse2 (__m128 x)
{
const __m128 r = _mm_set1_ps (0x1.8p23f);
const __m128 z = MADD128 (x, _mm_set1_ps (0x1.715476p+0f), r);
const __m128 n = _mm_sub_ps (z, r);
const __m128 b = NMADD128 (n, _mm_set1_ps (0x1.7f7d1cp-20f),
NMADD128 (n, _mm_set1_ps (0x1.62e4p-1f), x));
const __m128i e = _mm_slli_epi32 (_mm_castps_si128 (z), 23);
const __m128 k = _mm_castsi128_ps (
_mm_add_epi32 (e, _mm_castps_si128 (_mm_set1_ps (1))));
const __m128i c = _mm_castps_si128 (
_mm_cmpgt_ps (_mm_andnot_ps (_mm_set1_ps (-0.f), n), _mm_set1_ps (126)));
const __m128 u = _mm_mul_ps (b, b);
const __m128 j = MADD128 (MADD128 (MADD128 (_mm_set1_ps (0x1.0e4020p-7f), b,
_mm_set1_ps (0x1.573e2ep-5f)),
u,
MADD128 (_mm_set1_ps (0x1.555e66p-3f), b,
_mm_set1_ps (0x1.fffdb6p-2f))),
u, _mm_mul_ps (_mm_set1_ps (0x1.ffffecp-1f), b));
if (!_mm_movemask_epi8 (c))
return MADD128 (j, k, k);
const __m128i g
= _mm_and_si128 (_mm_castps_si128 (_mm_cmple_ps (n, _mm_setzero_ps ())),
_mm_set1_epi32 (0x82000000u));
const __m128 s1
= _mm_castsi128_ps (_mm_add_epi32 (g, _mm_set1_epi32 (0x7f000000u)));
const __m128 s2 = _mm_castsi128_ps (_mm_sub_epi32 (e, g));
const __m128i d = _mm_castps_si128 (
_mm_cmpgt_ps (_mm_andnot_ps (_mm_set1_ps (-0.f), n), _mm_set1_ps (192)));
return _mm_or_ps (
_mm_and_ps (_mm_castsi128_ps (d), _mm_mul_ps (s1, s1)),
_mm_andnot_ps (
_mm_castsi128_ps (d),
_mm_or_ps (
_mm_and_ps (_mm_castsi128_ps (c),
_mm_mul_ps (MADD128 (s2, j, s2), s1)),
_mm_andnot_ps (_mm_castsi128_ps (c), MADD128 (k, j, k)))));
}
#endif // __SSE2__
#ifdef __AVX2__
__m256
v_expf_avx2 (__m256 x)
{
const __m256 r = _mm256_set1_ps (0x1.8p23f);
const __m256 z = _mm256_fmadd_ps (x, _mm256_set1_ps (0x1.715476p+0f), r);
const __m256 n = _mm256_sub_ps (z, r);
const __m256 b = _mm256_fnmadd_ps (
n, _mm256_set1_ps (0x1.7f7d1cp-20f),
_mm256_fnmadd_ps (n, _mm256_set1_ps (0x1.62e4p-1f), x));
const __m256i e = _mm256_slli_epi32 (_mm256_castps_si256 (z), 23);
const __m256 k = _mm256_castsi256_ps (
_mm256_add_epi32 (e, _mm256_castps_si256 (_mm256_set1_ps (1))));
const __m256i c = _mm256_castps_si256 (
_mm256_cmp_ps (_mm256_andnot_ps (_mm256_set1_ps (-0.f), n),
_mm256_set1_ps (126), _CMP_GT_OQ));
const __m256 u = _mm256_mul_ps (b, b);
const __m256 j = _mm256_fmadd_ps (
_mm256_fmadd_ps (_mm256_fmadd_ps (_mm256_set1_ps (0x1.0e4020p-7f), b,
_mm256_set1_ps (0x1.573e2ep-5f)),
u,
_mm256_fmadd_ps (_mm256_set1_ps (0x1.555e66p-3f), b,
_mm256_set1_ps (0x1.fffdb6p-2f))),
u, _mm256_mul_ps (_mm256_set1_ps (0x1.ffffecp-1f), b));
if (!_mm256_movemask_ps (_mm256_castsi256_ps (c)))
return _mm256_fmadd_ps (j, k, k);
const __m256i g
= _mm256_and_si256 (_mm256_castps_si256 (_mm256_cmp_ps (
n, _mm256_setzero_ps (), _CMP_LE_OQ)),
_mm256_set1_epi32 (0x82000000u));
const __m256 s1 = _mm256_castsi256_ps (
_mm256_add_epi32 (g, _mm256_set1_epi32 (0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps (_mm256_sub_epi32 (e, g));
const __m256i d = _mm256_castps_si256 (
_mm256_cmp_ps (_mm256_andnot_ps (_mm256_set1_ps (-0.f), n),
_mm256_set1_ps (192), _CMP_GT_OQ));
return _mm256_or_ps (
_mm256_and_ps (_mm256_castsi256_ps (d), _mm256_mul_ps (s1, s1)),
_mm256_andnot_ps (
_mm256_castsi256_ps (d),
_mm256_or_ps (
_mm256_and_ps (_mm256_castsi256_ps (c),
_mm256_mul_ps (_mm256_fmadd_ps (s2, j, s2), s1)),
_mm256_andnot_ps (_mm256_castsi256_ps (c),
_mm256_fmadd_ps (k, j, k)))));
}
#endif // __AVX2__
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512
v_expf_avx512 (__m512 x)
{
const __m512 r = _mm512_set1_ps (0x1.8p23f);
const __m512 z = _mm512_fmadd_ps (x, _mm512_set1_ps (0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps (z, r);
const __m512 b = _mm512_fnmadd_ps (
n, _mm512_set1_ps (0x1.7f7d1cp-20f),
_mm512_fnmadd_ps (n, _mm512_set1_ps (0x1.62e4p-1f), x));
const __m512i e = _mm512_slli_epi32 (_mm512_castps_si512 (z), 23);
const __m512 k = _mm512_castsi512_ps (
_mm512_add_epi32 (e, _mm512_castps_si512 (_mm512_set1_ps (1))));
const __mmask16 c = _mm512_cmp_ps_mask (_mm512_abs_ps (n),
_mm512_set1_ps (126), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps (b, b);
const __m512 j = _mm512_fmadd_ps (
_mm512_fmadd_ps (_mm512_fmadd_ps (_mm512_set1_ps (0x1.0e4020p-7f), b,
_mm512_set1_ps (0x1.573e2ep-5f)),
u,
_mm512_fmadd_ps (_mm512_set1_ps (0x1.555e66p-3f), b,
_mm512_set1_ps (0x1.fffdb6p-2f))),
u, _mm512_mul_ps (_mm512_set1_ps (0x1.ffffecp-1f), b));
if (_mm512_kortestz (c, c))
return _mm512_fmadd_ps (j, k, k);
const __m512i g
= _mm512_and_si512 (_mm512_movm_epi32 (_mm512_cmp_ps_mask (
n, _mm512_setzero_ps (), _CMP_LE_OQ)),
_mm512_set1_epi32 (0x82000000u));
const __m512 s1 = _mm512_castsi512_ps (
_mm512_add_epi32 (g, _mm512_set1_epi32 (0x7f000000u)));
const __m512 s2 = _mm512_castsi512_ps (_mm512_sub_epi32 (e, g));
const __mmask16 d = _mm512_cmp_ps_mask (_mm512_abs_ps (n),
_mm512_set1_ps (192), _CMP_GT_OQ);
return _mm512_mask_blend_ps (
d,
_mm512_mask_blend_ps (c, _mm512_fmadd_ps (k, j, k),
_mm512_mul_ps (_mm512_fmadd_ps (s2, j, s2), s1)),
_mm512_mul_ps (s1, s1));
}
#endif
} // namespace
#define N 512
_Alignas (64) float A[N];
_Alignas (64) float B[N];
void
run_expf (void)
{
for (int i = 0; i < N; ++i)
B[i] = expf (A[i]);
}
void
run_v_expf_sse2 (void)
{
for (int i = 0; i < N; i += 4)
_mm_storeu_ps (B + i, v_expf_sse2 (_mm_loadu_ps (A + i)));
}
#ifdef __AVX2__
void
run_v_expf_avx2 (void)
{
for (int i = 0; i < N; i += 8)
_mm256_storeu_ps (B + i, v_expf_avx2 (_mm256_loadu_ps (A + i)));
}
#endif
#ifdef __AVX2__
extern "C" __m256 libmvec_expf_avx2 (__m256);
void
run_libmvec_expf_avx2 (void)
{
for (int i = 0; i < N; i += 8)
_mm256_storeu_ps (B + i, libmvec_expf_avx2 (_mm256_loadu_ps (A + i)));
}
#endif
#ifdef __AVX512F__
void
run_v_expf_avx512 (void)
{
for (int i = 0; i < N; i += 16)
_mm512_storeu_ps (B + i, v_expf_avx512 (_mm512_loadu_ps (A + i)));
}
#endif
void
nothing (void)
{
}
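// an indirect call through a mutable function pointer acts as an
// optimization barrier between benchmark iterations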
void (*barrier) (void) = nothing;
#define BENCH(ITERATIONS, WORK_PER_RUN, CODE) \
do \
{ \
struct timespec start = now (); \
for (int i = 0; i < ITERATIONS; ++i) \
{ \
barrier (); \
CODE; \
} \
long long work = WORK_PER_RUN * ITERATIONS; \
double nanos = (tano (tub (now (), start)) + work - 1) / (double)work; \
printf ("%10g ns %2dx %s\n", nanos, ITERATIONS, #CODE); \
} \
while (0)
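// BENCH reports the average nanoseconds per element, where each of the
// ITERATIONS runs of CODE processes WORK_PER_RUN elements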
int
main (int argc, char *argv[])
{
printf ("\n");
int i = 0;
A[i++] = +0.;
A[i++] = -0.;
A[i++] = +NAN;
A[i++] = -NAN;
A[i++] = +INFINITY;
A[i++] = -INFINITY;
A[i++] = 87;
A[i++] = 88;
A[i++] = 88.7229f;
A[i++] = 89;
A[i++] = -87;
A[i++] = -90;
A[i++] = -95;
A[i++] = -100;
A[i++] = -104;
for (; i < N; ++i)
A[i] = numba ();
BENCH (2000, N, run_expf ());
BENCH (2000, N, run_v_expf_sse2 ());
#ifdef __AVX2__
BENCH (2000, N, run_libmvec_expf_avx2 ());
BENCH (2000, N, run_v_expf_avx2 ());
#endif
#ifdef __AVX512F__
BENCH (2000, N, run_v_expf_avx512 ());
#endif
printf ("\n");
printf ("//%12s %12s %12s %5s\n", "input", "exp", "v", "bad");
printf ("//%12s %12s %12s %5s\n", "=====", "===", "=========", "===");
for (int i = 0; i < 23; ++i)
printf ("//%12g %12g %12g %5d\n", A[i], expf (A[i]), B[i],
bad (exp (A[i]), B[i]));
#define MAX_ERROR_ULP 2
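// exhaustively test every 32-bit float bit pattern against scalar expf,
// failing if any vector result differs by more than MAX_ERROR_ULP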
#pragma omp parallel for
for (long i = 0; i < 4294967296; ++i)
{
union
{
unsigned i;
float f;
} a, b, u = { (unsigned) i };
a.f = expf (u.f);
b.f = v_expf_sse2 (__m128{ u.f })[0];
#ifdef __AVX2__
union
{
unsigned i;
float f;
} c;
c.f = v_expf_avx2 (__m256{ u.f })[0];
#endif
#ifdef __AVX512F__
union
{
unsigned i;
float f;
} d;
d.f = v_expf_avx512 (__m512{ u.f })[0];
#endif
long ai = a.i;
long bi = b.i;
#ifdef __AVX2__
long ci = c.i;
#endif
#ifdef __AVX512F__
long di = d.i;
#endif
long e = bi - ai;
if (e < 0)
e = -e;
if (e > MAX_ERROR_ULP)
exit (66);
#ifdef __AVX2__
long f = ci - ai;
if (f < 0)
f = -f;
if (f > MAX_ERROR_ULP)
exit (77);
#endif
#ifdef __AVX512F__
long g = di - ai;
if (g < 0)
g = -g;
if (g > MAX_ERROR_ULP)
exit (88);
#endif
}
printf ("%ld numbers tested successfully\n", 4294967296);
}