Skip to content

Instantly share code, notes, and snippets.

@brentp
Last active November 7, 2019 01:17
Show Gist options
  • Save brentp/11a497dd5961c8340ddf05b89fef85a1 to your computer and use it in GitHub Desktop.
Save brentp/11a497dd5961c8340ddf05b89fef85a1 to your computer and use it in GitHub Desktop.
import bitops
# Build flags: AVX512 Foundation/VL/BW plus POPCNT, passed to both the compile
# and link steps (GCC/Clang-style options).
const avx512Flags = "-mavx512f -mavx512vl -mavx512bw -mpopcnt"
{.passC: avx512Flags.}
{.passL: avx512Flags.}

# `x86`: shorthand pragma for binding SIMD intrinsics. `noDecl` stops Nim from
# emitting its own C prototypes; `header` pulls in the toolchain's intrinsics
# header (<intrin.h> under vcc, <x86intrin.h> otherwise).
when not defined(vcc):
  {.pragma: x86, noDecl, header: "<x86intrin.h>".}
else:
  {.pragma: x86, noDecl, header: "<intrin.h>".}
# Opaque handles for the C SIMD register types. They are declared
# `incompleteStruct` so Nim does not rely on their layout; they are only passed
# to and from intrinsics. They use the `x86` pragma so the correct intrinsics
# header is chosen per toolchain (previously <x86intrin.h> was hard-coded,
# which breaks the vcc branch).
type
  veci8* {.importc: "__m512i", incompleteStruct, x86.} = object # 8 x int64
  veci4* {.importc: "__m256i", incompleteStruct, x86.} = object # 4 x int64

# 512-bit load/store through `_mm512_loadu_si512` / `_mm512_storeu_si512`
# (the unaligned variants), reading/writing 64 bytes at `mem_addr`.
proc load*(mem_addr: pointer): veci8 {.importc: "_mm512_loadu_si512", x86.}
proc store*(mem_addr: pointer, src: veci8) {.importc: "_mm512_storeu_si512", x86.}
# C helper: total number of set bits in a 512-bit vector (algorithm from
# libpopcnt). It counts bits per 2-bit, 4-bit, then 8-bit group, then sums the
# byte counts. Uses the `_mm512_and_si512` intrinsic instead of `&` on
# `__m512i`: `&` only works through GCC/Clang vector extensions and fails to
# compile with MSVC, which this file otherwise tries to support (vcc branch).
{.emit: """
// https://github.com/kimwalisch/libpopcnt
static inline long long int popcnt512(__m512i v) {
  const __m512i m1 = _mm512_set1_epi8(0x55);
  const __m512i m2 = _mm512_set1_epi8(0x33);
  const __m512i m4 = _mm512_set1_epi8(0x0F);
  __m512i t1 = _mm512_sub_epi8(v, _mm512_and_si512(_mm512_srli_epi16(v, 1), m1));
  __m512i t2 = _mm512_add_epi8(_mm512_and_si512(t1, m2),
                               _mm512_and_si512(_mm512_srli_epi16(t1, 2), m2));
  __m512i t3 = _mm512_and_si512(_mm512_add_epi8(t2, _mm512_srli_epi16(t2, 4)), m4);
  return _mm512_reduce_add_epi64(_mm512_sad_epu8(t3, _mm512_setzero_si512()));
}
""".}
proc toVec*(a: var seq[int64], i: SomeInteger = 0): veci8 {.inline.} =
  ## Loads the 8 consecutive values `a[i] .. a[i+7]` into a `veci8`
  ## through the unaligned `load` binding.
  # Indexing `a[i]` only bounds-checks the first lane, but the load reads
  # 8 words. Check the whole span so a short tail cannot read past the end
  # of the seq's buffer.
  assert i >= 0 and i.int + 8 <= a.len, "toVec: need 8 int64s starting at index i"
  result = load(a[i].addr.pointer)
proc toSeq*(src: veci8, a: var seq[int64], i: SomeInteger = 0) {.inline.} =
  ## Stores the 8 lanes of `src` into `a[i] .. a[i+7]` through the
  ## unaligned `store` binding.
  # Indexing `a[i]` only bounds-checks the first lane. An unchecked 8-word
  # store past the end would silently corrupt adjacent memory, so verify
  # that the whole destination span fits.
  assert i >= 0 and i.int + 8 <= a.len, "toSeq: need room for 8 int64s starting at index i"
  a[i].addr.pointer.store(src)
# Operators on `veci8` mapped one-to-one onto the `_mm512_*` intrinsics named
# below: bitwise AND and OR over all 512 bits, and lane-wise addition of the
# 8 int64 lanes.
proc `and`*(a, b: veci8): veci8 {.importc: "_mm512_and_si512", x86.}
proc `or`*(a, b: veci8): veci8 {.importc: "_mm512_or_si512", x86.}
proc `+`*(a, b: veci8): veci8 {.importc: "_mm512_add_epi64", x86.}
# Alternative (requires AVX512-VPOPCNTDQ): per-lane popcounts returned as a
# vector rather than a single total.
#proc popcnt*(a: veci8): veci8 {.importc: "_mm512_popcnt_epi64", x86.}

# Total set bits across all 512 bits of `a`, computed by the `popcnt512` C
# helper emitted earlier in this file. `noDecl` is required: that helper is
# already defined (as `static inline`) by the emit, so Nim must not generate
# its own non-static prototype, which would conflict with it.
proc popcnt*(a: veci8): int64 {.importc: "popcnt512", noDecl.}

# Portable scalar fallback, kept for reference/benchmarking:
#proc popcnt*(a: veci8): int64 {.inline.} =
#  let b = cast[array[8, int64]](a)
#  return b[0].countSetBits + b[1].countSetBits + b[2].countSetBits + b[3].countSetBits +
#    b[4].countSetBits + b[5].countSetBits + b[6].countSetBits + b[7].countSetBits

# Sum of the 8 int64 lanes of `a`.
proc sum*(a: veci8): int64 {.importc: "_mm512_reduce_add_epi64", x86.}
import random
proc random_seq(size: int): seq[int64] =
  ## Returns `size` pseudo-random int64 values in `0 .. int64.high`, drawn
  ## from std/random's module-level RNG (seed it with `randomize()`).
  # `rand` replaces the deprecated (later removed) `random(slice)`. Using an
  # int64 slice directly avoids `int64.high.int`, which overflows where `int`
  # is 32 bits.
  result = newSeq[int64](size)
  for x in result.mitems:
    x = rand(0'i64 .. int64.high)
when isMainModule:
  # Benchmark: compare the AVX512 kernel against a scalar 64-bit-at-a-time
  # loop on the same random bitsets, then check that all four counts agree.
  import times

  randomize()

  # Each bitset is N int64 words: 312 * 64 = 19968 bits, e.g. the ~20K sites
  # in somalier. The AVX512 loop consumes 8 words per step, so N must be a
  # multiple of 8.
  const N = 312
  const loops = 2_500_000 # e.g. on 2.5M pairs of samples
  static: doAssert N mod 8 == 0, "N must be a multiple of 8 (one veci8)"

  var
    A_hom_ref = random_seq(N)
    A_het = random_seq(N)
    A_hom_alt = random_seq(N)
    B_hom_ref = random_seq(N)
    B_het = random_seq(N)
    B_hom_alt = random_seq(N)

  # vectorized AVX512 implementation:
  var t0 = cpuTime()
  var
    vibs0 = 0'i64
    vibs2 = 0'i64
    vnhets = 0'i64
    vnN = 0'i64
  for t in 0..<loops:
    # countup is inclusive. The last start index must be N - 8 so the final
    # 8-word load covers words N-8 .. N-1; stopping at N read past the end.
    for i in countup(0, N - 8, 8):
      let
        A_hr = A_hom_ref.toVec(i)
        A_heti = A_het.toVec(i)
        A_ha = A_hom_alt.toVec(i)
        B_hr = B_hom_ref.toVec(i)
        B_heti = B_het.toVec(i)
        B_ha = B_hom_alt.toVec(i)
      vibs0 += ((A_hr and B_ha) or (B_hr and A_ha)).popcnt
      vibs2 += ((A_hr and B_hr) or (B_ha and A_ha) or (A_heti and B_heti)).popcnt
      vnhets += (A_heti or B_heti).popcnt
      vnN += ((A_hr or A_heti or A_ha) and (B_hr or B_heti or B_ha)).popcnt
  echo cpuTime() - t0, " AVX512 got ibs0:", vibs0, " ibs2:", vibs2, " hets:", vnhets, " N:", vnN

  # "normal", 64 bits at-a-time implementation
  t0 = cpuTime()
  var
    ibs0 = 0'i64
    ibs2 = 0'i64
    nhets = 0'i64
    nN = 0'i64
  for t in 0..<loops:
    for i in 0..<N:
      let i0 = (A_hom_ref[i] and B_hom_alt[i]) or (A_hom_alt[i] and B_hom_ref[i])
      ibs0 += i0.countSetBits
      let i2 = (A_het[i] and B_het[i]) or (A_hom_alt[i] and B_hom_alt[i]) or (A_hom_ref[i] and B_hom_ref[i])
      ibs2 += i2.countSetBits
      nhets += (A_het[i] or B_het[i]).countSetBits
      nN += ((A_hom_ref[i] or A_het[i] or A_hom_alt[i]) and (B_hom_ref[i] or B_het[i] or B_hom_alt[i])).countSetBits
  echo cpuTime() - t0, " got ibs0:", ibs0, " ibs2:", ibs2, " hets:", nhets, " N:", nN

  # Both implementations must agree on every statistic, not just ibs0/ibs2.
  doAssert ibs2 == vibs2
  doAssert ibs0 == vibs0
  doAssert nhets == vnhets
  doAssert nN == vnN
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment