Last active
November 7, 2019 01:17
-
-
Save brentp/11a497dd5961c8340ddf05b89fef85a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bitops

# Compiler/linker switches that enable the AVX-512 and POPCNT instruction
# sets used by the intrinsics bound in this file. The same switch string is
# handed to both the C compiler and the linker.
const avx512Flags = "-mavx512f -mavx512vl -mavx512bw -mpopcnt"
{.passC: avx512Flags.}
{.passL: avx512Flags.}

# `x86` is a user pragma shared by the intrinsic bindings: it marks a symbol
# as declared by the intrinsics header (so Nim emits no prototype of its own)
# and selects the header name according to the C compiler in use.
when defined(vcc):
  {.pragma: x86, noDecl, header: "<intrin.h>".}
else:
  {.pragma: x86, noDecl, header: "<x86intrin.h>".}
type
  veci8* {.importc: "__m512i", incompleteStruct, x86.} = object
    ## Opaque 512-bit integer vector (C type `__m512i`); used in this file
    ## as eight packed 64-bit words.
  veci4* {.importc: "__m256i", incompleteStruct, x86.} = object
    ## Opaque 256-bit integer vector (C type `__m256i`).

# The bindings below go through the `x86` pragma instead of naming
# <x86intrin.h> directly, so they pick up the same header as every other
# intrinsic in this file (including the `vcc` choice).

proc load*(mem_addr: pointer): veci8 {.importc: "_mm512_loadu_si512", x86.}
  ## Binds `_mm512_loadu_si512`: unaligned load of 64 bytes starting at
  ## `mem_addr`. The caller must guarantee that 64 readable bytes exist.

proc store*(mem_addr: pointer, src: veci8) {.importc: "_mm512_storeu_si512", x86.}
  ## Binds `_mm512_storeu_si512`: unaligned store of the 64 bytes of `src`
  ## to `mem_addr`. The caller must guarantee that 64 writable bytes exist.
{.emit: """
// Population count of a whole 512-bit vector, adapted from
// https://github.com/kimwalisch/libpopcnt
//
// Classic SWAR bit counting carried out on every byte lane, followed by a
// horizontal sum. Only intrinsic calls are used (no vector operators), so
// the code does not depend on compiler-specific operator support for
// __m512i.
static inline long long int popcnt512(__m512i v) {
  const __m512i m1 = _mm512_set1_epi8(0x55);
  const __m512i m2 = _mm512_set1_epi8(0x33);
  const __m512i m4 = _mm512_set1_epi8(0x0F);
  // counts per 2-bit group: v - ((v >> 1) & 0x55)
  const __m512i t1 = _mm512_sub_epi8(
      v, _mm512_and_si512(_mm512_srli_epi16(v, 1), m1));
  // counts per 4-bit group: (t1 & 0x33) + ((t1 >> 2) & 0x33)
  const __m512i t2 = _mm512_add_epi8(
      _mm512_and_si512(t1, m2),
      _mm512_and_si512(_mm512_srli_epi16(t1, 2), m2));
  // counts per byte: (t2 + (t2 >> 4)) & 0x0F
  const __m512i t3 = _mm512_and_si512(
      _mm512_add_epi8(t2, _mm512_srli_epi16(t2, 4)), m4);
  // sum the bytes of each 64-bit lane, then add the eight lanes together
  return _mm512_reduce_add_epi64(_mm512_sad_epu8(t3, _mm512_setzero_si512()));
}
""".}
proc toVec*(a: var seq[int64], i: SomeInteger = 0): veci8 {.inline.} =
  ## Loads the eight consecutive words `a[i] .. a[i+7]` into one 512-bit
  ## vector.
  ##
  ## Indexing `a[i]` only validates the first word, while the load reads 64
  ## bytes; the assertion makes sure the remaining seven words exist too, so
  ## a short tail is reported instead of silently reading past the buffer.
  ## Like bounds checks, the assertion disappears in `-d:danger` builds.
  assert i.int + 8 <= a.len, "toVec needs 8 int64 words starting at index i"
  result = load(a[i].addr.pointer)
proc toSeq*(src: veci8, a: var seq[int64], i: SomeInteger = 0) =
  ## Stores the eight 64-bit words of `src` into `a[i] .. a[i+7]`.
  ##
  ## Indexing `a[i]` only validates the first word, while the store writes 64
  ## bytes; the assertion makes sure the whole destination lies inside `a`,
  ## so the store cannot overwrite memory behind the buffer. Like bounds
  ## checks, the assertion disappears in `-d:danger` builds.
  assert i.int + 8 <= a.len, "toSeq needs room for 8 int64 words at index i"
  store(a[i].addr.pointer, src)
proc `and`*(a, b: veci8): veci8 {.x86, importc: "_mm512_and_si512".}
  ## Bitwise AND of two 512-bit vectors (binds `_mm512_and_si512`).

proc `or`*(a, b: veci8): veci8 {.x86, importc: "_mm512_or_si512".}
  ## Bitwise OR of two 512-bit vectors (binds `_mm512_or_si512`).

proc `+`*(a, b: veci8): veci8 {.x86, importc: "_mm512_add_epi64".}
  ## Lane-wise addition of the packed 64-bit integers in `a` and `b`
  ## (binds `_mm512_add_epi64`).
# Alternative kept for reference: a per-lane population count through the
# `_mm512_popcnt_epi64` intrinsic (returns a vector rather than a scalar).
#proc popcnt*(a: veci8): veci8 {.importc: "_mm512_popcnt_epi64", x86.}

proc popcnt*(a: veci8): int64 {.importc: "popcnt512", nodecl.}
  ## Number of set bits in all 512 bits of `a`.
  ##
  ## Binds the C helper `popcnt512` that this file emits. `nodecl` stops Nim
  ## from generating a second prototype for it: the helper is already fully
  ## defined (as `static inline`, returning `long long int`) in the emitted C
  ## code, and an extra prototype could disagree with that definition.

# Alternative kept for reference: scalar fallback built on `countSetBits`.
#proc popcnt*(a: veci8): int64 {.inline.} =
#  let b = cast[array[8, int64]](a)
#  return b[0].countSetBits + b[1].countSetBits + b[2].countSetBits + b[3].countSetBits +
#         b[4].countSetBits + b[5].countSetBits + b[6].countSetBits + b[7].countSetBits

proc sum*(a: veci8): int64 {.importc: "_mm512_reduce_add_epi64", x86.}
  ## Horizontal sum of the eight packed 64-bit integers in `a`
  ## (binds `_mm512_reduce_add_epi64`).
import random | |
proc random_seq(size: int): seq[int64] =
  ## Returns `size` pseudo-random words, each drawn from `0 .. int64.high`,
  ## so the sign bit of every word is always clear.
  ##
  ## Uses `rand`, the replacement for the deprecated `random` proc. Call
  ## `randomize()` beforehand to get a different sequence on every run.
  result = newSeq[int64](size)
  for word in result.mitems:
    word = rand(0..int64.high.int)
when isMainModule:
  # Benchmark: counts ibs0 / ibs2 / het / N bits for one pair of samples,
  # first with 512-bit vectors, then one 64-bit word at a time, prints the
  # time and totals of both, and checks that the totals agree.
  import times

  const
    wordsPerVec = 8                # int64 words covered by one 512-bit load
    unroll = 4                     # independent accumulator sets per iteration
    stride = wordsPerVec * unroll  # words consumed per inner-loop iteration

  randomize()
  let N = 330 # e.g. the ~20K sites in somalier
  let loops = 3_500_000 # e.g. on 2.5M pairs of samples
  # The vector loop always consumes `stride` words at a time. N is not a
  # multiple of `stride`, so every buffer is rounded up and the tail is
  # filled with zero words: zeros add nothing to any bit count, the loads
  # stay inside the buffers, and both implementations see the same data.
  let paddedLen = ((N + stride - 1) div stride) * stride

  proc paddedRandomSeq(used, total: int): seq[int64] =
    ## `used` random words followed by zero words up to a length of `total`.
    result = newSeq[int64](total)  # newSeq zero-initialises every element
    let words = random_seq(used)
    for k in 0 ..< used:
      result[k] = words[k]

  var A_hom_ref = paddedRandomSeq(N, paddedLen)
  var A_het = paddedRandomSeq(N, paddedLen)
  var A_hom_alt = paddedRandomSeq(N, paddedLen)
  var B_hom_ref = paddedRandomSeq(N, paddedLen)
  var B_het = paddedRandomSeq(N, paddedLen)
  var B_hom_alt = paddedRandomSeq(N, paddedLen)

  template tally(offset, ibs0Acc, ibs2Acc, hetsAcc, nAcc: untyped) =
    # One vector step: loads eight words of every genotype array starting at
    # `offset` and adds the resulting bit counts to the given accumulators.
    let
      refA = A_hom_ref.toVec(offset)
      hetA = A_het.toVec(offset)
      altA = A_hom_alt.toVec(offset)
      refB = B_hom_ref.toVec(offset)
      hetB = B_het.toVec(offset)
      altB = B_hom_alt.toVec(offset)
    ibs0Acc += ((refA and altB) or (refB and altA)).popcnt
    ibs2Acc += ((refA and refB) or (altB and altA) or (hetA and hetB)).popcnt
    hetsAcc += (hetA or hetB).popcnt
    nAcc += ((refA or hetA or altA) and (refB or hetB or altB)).popcnt

  # vectorized AVX512 implementation:
  # four separate accumulator sets keep the unrolled steps independent of
  # each other.
  var vibs0_0, vibs0_1, vibs0_2, vibs0_3 = 0'i64
  var vibs2_0, vibs2_1, vibs2_2, vibs2_3 = 0'i64
  var vnhets_0, vnhets_1, vnhets_2, vnhets_3 = 0'i64
  var vnN_0, vnN_1, vnN_2, vnN_3 = 0'i64

  var t0 = cpuTime()
  for t in 0 ..< loops:
    if t == 100:
      # restart the clock after a short warm-up
      t0 = cpuTime()
    for i in countup(0, paddedLen - stride, stride):
      tally(i, vibs0_0, vibs2_0, vnhets_0, vnN_0)
      tally(i + wordsPerVec, vibs0_1, vibs2_1, vnhets_1, vnN_1)
      tally(i + 2 * wordsPerVec, vibs0_2, vibs2_2, vnhets_2, vnN_2)
      tally(i + 3 * wordsPerVec, vibs0_3, vibs2_3, vnhets_3, vnN_3)
  let vecElapsed = cpuTime() - t0

  let
    vibs0 = vibs0_0 + vibs0_1 + vibs0_2 + vibs0_3
    vibs2 = vibs2_0 + vibs2_1 + vibs2_2 + vibs2_3
    vnhets = vnhets_0 + vnhets_1 + vnhets_2 + vnhets_3
    vnN = vnN_0 + vnN_1 + vnN_2 + vnN_3
  echo vecElapsed, " AVX512 got ibs0:", vibs0, " ibs2:", vibs2, " hets:", vnhets, " N:", vnN

  # "normal", 64 bits at-a-time implementation
  t0 = cpuTime()
  var ibs0, ibs2, nhets, nN = 0'i64
  for t in 0 ..< loops:
    if t == 100:
      # restart the clock after a short warm-up
      t0 = cpuTime()
    for i in 0 ..< N:
      let i0 = (A_hom_ref[i] and B_hom_alt[i]) or (A_hom_alt[i] and B_hom_ref[i])
      ibs0 += i0.countSetBits
      let i2 = (A_het[i] and B_het[i]) or (A_hom_alt[i] and B_hom_alt[i]) or
               (A_hom_ref[i] and B_hom_ref[i])
      ibs2 += i2.countSetBits
      nhets += (A_het[i] or B_het[i]).countSetBits
      nN += ((A_hom_ref[i] or A_het[i] or A_hom_alt[i]) and
             (B_hom_ref[i] or B_het[i] or B_hom_alt[i])).countSetBits
  echo cpuTime() - t0, " got ibs0:", ibs0, " ibs2:", ibs2, " hets:", nhets, " N:", nN

  # Both implementations now process exactly the same bits, so their totals
  # must be identical.
  doAssert ibs0 == vibs0, "ibs0 differs between scalar and AVX512 paths"
  doAssert ibs2 == vibs2, "ibs2 differs between scalar and AVX512 paths"
  doAssert nhets == vnhets, "het count differs between scalar and AVX512 paths"
  doAssert nN == vnN, "N differs between scalar and AVX512 paths"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment