Skip to content

Instantly share code, notes, and snippets.

@shibatch
Created October 18, 2020 13:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shibatch/42f82c885f62bc2c063d6500844f937f to your computer and use it in GitHub Desktop.
Save shibatch/42f82c885f62bc2c063d6500844f937f to your computer and use it in GitHub Desktop.
Half-precision (bfloat16) function prototype
#include <stdio.h>
#include <stdint.h>
#include <math.h>
#include <x86intrin.h>
/*
 * Pack the low 16 bits of every 32-bit element of two vectors into a single
 * vector of 16 x 16-bit values.  Result elements 0-7 come from `l` (in order),
 * elements 8-15 from `h`.  Used to narrow the int32 table entries fetched by
 * lookup() back to 16-bit bfloat16 bit patterns.
 */
__m256i ftoh256h(__m256i l, __m256i h) {
    const __m256i low16 = _mm256_set1_epi32(0xffff);
    /* After masking every element is in [0, 0xffff], so the unsigned
     * saturation of packus never triggers: it is a plain narrowing. */
    __m256i narrow_l = _mm256_and_si256(l, low16);
    __m256i narrow_h = _mm256_and_si256(h, low16);
    /* packus works per 128-bit lane, yielding qwords [l0-3, h0-3, l4-7, h4-7]. */
    __m256i interleaved = _mm256_packus_epi32(narrow_l, narrow_h);
    /* Reorder qwords to [l0-3, l4-7, h0-3, h4-7]. */
    return _mm256_permute4x64_epi64(interleaved, _MM_SHUFFLE(3, 1, 2, 0));
}
/*
 * Round each float bit pattern in `bits` to the nearest bfloat16 (ties to
 * even), leaving the result in the upper 16 bits of each 32-bit element.
 * NaNs get their quiet bit forced on so that dropping the low mantissa bits
 * cannot turn a NaN into an infinity.
 */
static inline __m256i bf16_round_nearest_even256(__m256i bits) {
    const __m256i abs_mask = _mm256_set1_epi32(0x7fffffff);
    const __m256i inf_bits = _mm256_set1_epi32(0x7f800000);
    /* Bias = 0x7fff + (LSB of the bfloat16 result) gives round-half-to-even. */
    __m256i lsb = _mm256_and_si256(_mm256_srli_epi32(bits, 16),
                                   _mm256_set1_epi32(1));
    __m256i bias = _mm256_add_epi32(lsb, _mm256_set1_epi32(0x7fff));
    __m256i rounded = _mm256_add_epi32(bits, bias);
    /* |x| > +inf (as signed int, both operands non-negative) <=> x is NaN. */
    __m256i is_nan = _mm256_cmpgt_epi32(_mm256_and_si256(bits, abs_mask), inf_bits);
    __m256i quiet_nan = _mm256_or_si256(bits, _mm256_set1_epi32(0x00400000));
    return _mm256_blendv_epi8(rounded, quiet_nan, is_nan);
}

/*
 * Convert two vectors of 8 floats each (passed as raw bits) into one vector
 * of 16 bfloat16 values: elements 0-7 from `l`, 8-15 from `h`.
 * Values are rounded to nearest-even; previously they were truncated, which
 * biased every result toward zero.
 */
__m256i ftoh256l(__m256i l, __m256i h) {
    /* Keep only the upper 16 bits (the bfloat16) of each rounded float. */
    __m256i bf_l = _mm256_srli_epi32(bf16_round_nearest_even256(l), 16);
    __m256i bf_h = _mm256_srli_epi32(bf16_round_nearest_even256(h), 16);
    /* Values are in [0, 0xffff], so packus is a plain narrowing; it yields
     * qwords [l0-3, h0-3, l4-7, h4-7], reordered to [l0-3, l4-7, h0-3, h4-7]. */
    return _mm256_permute4x64_epi64(_mm256_packus_epi32(bf_l, bf_h),
                                    _MM_SHUFFLE(3, 1, 2, 0));
}
/*
 * Zero-extend 16-bit elements 0-7 of `a` to eight 32-bit integers.
 * lookup() uses the result as gather indices (one per bfloat16 bit pattern).
 * Replaces a permute/shuffle/and/blend sequence with the single AVX2
 * zero-extension instruction (vpmovzxwd); the result is identical.
 */
__m256i htof256l0(__m256i a) {
    return _mm256_cvtepu16_epi32(_mm256_castsi256_si128(a));
}
/*
 * Zero-extend 16-bit elements 8-15 of `b` to eight 32-bit integers.
 * Counterpart of htof256l0() for the upper 128 bits; lookup() uses the
 * result as gather indices.  Equivalent to the original hand-rolled
 * permute/shuffle/and/blend sequence.
 */
__m256i htof256l1(__m256i b) {
    return _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b, 1));
}
/*
 * Print `mes`, then the 16 bfloat16 values held in `a` (widened to float and
 * formatted with %g), then a newline.
 */
void show256(const char *mes, __m256i a) {
    /* Reinterpret the vector once, outside the loop. */
    union {
        __m256i v;
        uint16_t bits[16];
    } lanes = { .v = a };

    printf("%s ", mes);
    for (int i = 0; i < 16; i++) {
        /* A bfloat16 is the upper half of a float; widening by shifting does
         * not depend on the byte order of a 16-bit array inside a union. */
        union {
            uint32_t u;
            float f;
        } widened = { .u = (uint32_t)lanes.bits[i] << 16 };
        printf("%g ", widened.f);
    }
    printf("\n");
}
//
double theFunc(double x) { return 0.5 * x * (1 + erf(x / sqrt(2))); }
static int32_t table[0x10000];
void init() {
for(int i=0;i<0x10000;i++) {
union {
int16_t i[2];
float f;
} cnvf;
cnvf.i[0] = 0;
cnvf.i[1] = i;
cnvf.f = theFunc(cnvf.f);
table[i] = cnvf.i[1];
}
}
__m256i lookup(__m256i a) {
return ftoh256h(_mm256_i32gather_epi32(table, htof256l0(a), 4),
_mm256_i32gather_epi32(table, htof256l1(a), 4));
}
/*
 * Demo: convert 16 floats to bfloat16, run them through the table, print the
 * looked-up values ("val") next to the double-precision reference ("cmp").
 */
int main(void) {
    init();
    float args[] = {
        0.11f, 0.22f, 0.33f, 0.44f, 0.55f, 0.66f, 0.77f, 0.88f,
        1.11f, 1.22f, 1.33f, 1.44f, 1.55f, 1.66f, 1.77f, 1.88f
    };
    const size_t n_args = sizeof args / sizeof args[0];

    __m256 fl0 = _mm256_loadu_ps(&args[0]);
    __m256 fh0 = _mm256_loadu_ps(&args[8]);
    /* Reinterpret the float bits as integers with the intrinsic rather than a
     * GNU vector cast. */
    __m256i v = lookup(ftoh256l(_mm256_castps_si256(fl0), _mm256_castps_si256(fh0)));
    show256("val", v);

    printf("cmp ");
    for (size_t i = 0; i < n_args; i++) {
        printf("%g ", theFunc(args[i]));
    }
    printf("\n");   /* the original left the last output line unterminated */
    return 0;
}
@shibatch
Copy link
Author

ftoh256l converts two vectors of 8 floats each into one vector of 16 bfloat16 values. It does not round yet; it simply truncates.
lookup performs the table lookup. It effectively calculates 0.5 * x * (1 + erf(x / sqrt(2))) in bfloat16 format.
show256 prints the bfloat16 values stored in the given vector.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment