Skip to content

Instantly share code, notes, and snippets.

@jeromekelleher
Last active October 14, 2022 16:30
Show Gist options
  • Save jeromekelleher/33927b941f2d63317049aacce16ff63a to your computer and use it in GitHub Desktop.
Save jeromekelleher/33927b941f2d63317049aacce16ff63a to your computer and use it in GitHub Desktop.
Compute the branch genetic relatedness matrix (GRM) using numba
import sys
import tskit
import numpy as np
import numba
@numba.njit
def sv_tables_init(parent_array):
    """
    Preprocess a forest so that ``_sv_mrca`` can answer MRCA queries.

    :param parent_array: integer array in which entry ``u`` is the parent of
        node ``u``. It is shifted by +1 below, so a node without a parent
        must be stored as -1 in order to map onto ``LAMBDA`` (0).
    :return: the tuple ``(lambd, pi, tau, beta, alpha)`` of int32 arrays, each
        of length ``len(parent_array) + 1`` and indexed with 1-based node ids
        (index 0 is the null node).
    """
    # This is an implementation of Schieber and Vishkin's nearest common ancestor
    # algorithm from TAOCP volume 4A, pg.164-167 [K11]_. Preprocesses the
    # input tree into a sideways heap in O(n) time and processes queries for the
    # nearest common ancestor between an arbitrary pair of nodes in O(1) time.
    #
    # NB internally this assumes that tree uses 1-based addressing and 0 is a
    # special value. We would like to update this to use the 0-based indexing
    # natively and also use the built-in triply linked tree to save some time
    # and memory.
    n = 1 + parent_array.shape[0]
    oriented_forest = np.zeros(n, dtype=np.int32)
    # Convert to 1-based representation assumed here.
    oriented_forest[1:] = 1 + parent_array
    LAMBDA = 0
    # Triply-linked tree. FIXME we shouldn't need to build this as it's
    # available already in tskit
    child = np.zeros(n, dtype=np.int32)
    parent = np.zeros(n, dtype=np.int32)
    sib = np.zeros(n, dtype=np.int32)
    for u in range(n):
        # Push u onto the front of its parent's child list; roots are
        # chained together as the children of LAMBDA.
        v = oriented_forest[u]
        sib[u] = child[v]
        child[v] = u
        parent[u] = v
    lambd = np.zeros(n, dtype=np.int32)
    pi = np.zeros(n, dtype=np.int32)
    tau = np.zeros(n, dtype=np.int32)
    beta = np.zeros(n, dtype=np.int32)
    alpha = np.zeros(n, dtype=np.int32)
    p = child[LAMBDA]
    # From here on n is reused as the running preorder counter.
    n = 0
    # Seed so that the recurrence below yields lambd[k] = floor(log2(k)).
    lambd[0] = -1
    # First traversal: assign preorder ranks (pi) and fill beta and tau.
    while p != LAMBDA:
        # Descend along first children, numbering each node as it is reached.
        while True:
            n += 1
            pi[p] = n
            tau[n] = LAMBDA
            lambd[n] = 1 + lambd[n >> 1]
            if child[p] != LAMBDA:
                p = child[p]
            else:
                break
        # p is now a leaf; its beta value is its own preorder rank.
        beta[p] = n
        # Climb until a node with an unvisited sibling is found (or the
        # forest is exhausted), fixing up beta for each ancestor on the way.
        while True:
            tau[beta[p]] = parent[p]
            if sib[p] != LAMBDA:
                p = sib[p]
                break
            else:
                p = parent[p]
                if p != LAMBDA:
                    h = lambd[n & -pi[p]]
                    beta[p] = ((n >> h) | 1) << h
                else:
                    break
    # Begin the second traversal
    # lambd[0] is overwritten with the value for the last preorder rank;
    # _sv_mrca reads lambd[0] whenever its mask k evaluates to 0.
    lambd[0] = lambd[n]
    pi[LAMBDA] = 0
    beta[LAMBDA] = 0
    alpha[LAMBDA] = 0
    p = child[LAMBDA]
    while p != LAMBDA:
        # Descend: alpha[p] is the parent's alpha with the lowest set bit
        # of beta[p] OR-ed in.
        while True:
            a = alpha[parent[p]] | (beta[p] & -beta[p])
            alpha[p] = a
            if child[p] != LAMBDA:
                p = child[p]
            else:
                break
        # Climb to the next unvisited sibling, or stop once past the roots.
        while True:
            if sib[p] != LAMBDA:
                p = sib[p]
                break
            else:
                p = parent[p]
                if p == LAMBDA:
                    break
    return lambd, pi, tau, beta, alpha
@numba.njit
def _sv_mrca(x, y, lambd, pi, tau, beta, alpha):
    """
    Answer a single nearest-common-ancestor query on 1-based node ids.

    ``lambd``, ``pi``, ``tau``, ``beta`` and ``alpha`` are the tables returned
    by ``sv_tables_init``. The result is a 1-based node id.
    NOTE(review): a result of 0 presumably means the two nodes share no
    ancestor -- confirm against the reference algorithm.
    """
    beta_x = beta[x]
    beta_y = beta[y]
    # Always combine the larger beta with the negation of the smaller one.
    if beta_x > beta_y:
        h = lambd[beta_x & -beta_y]
    else:
        h = lambd[beta_y & -beta_x]
    shared = alpha[x] & alpha[y] & -(1 << h)
    h = lambd[shared & -shared]
    low_bits = (1 << h) - 1
    target = ((beta_x >> h) | 1) << h

    # Lift x until its beta value equals target (no-op if it already does).
    x_hat = x
    if target != beta_x:
        ell = lambd[alpha[x] & low_bits]
        x_hat = tau[((beta_x >> ell) | 1) << ell]

    # Same lifting step for y.
    y_hat = y
    if target != beta_y:
        ell = lambd[alpha[y] & low_bits]
        y_hat = tau[((beta_y >> ell) | 1) << ell]

    # Of the two candidates, the one with the smaller preorder rank wins.
    if pi[x_hat] <= pi[y_hat]:
        return x_hat
    return y_hat
@numba.njit
def sv_mrca(x, y, lambd, pi, tau, beta, alpha):
    """
    Return the MRCA of the 0-based node ids ``x`` and ``y``.

    The lookup tables use 1-based ids with 0 reserved as the null node, so
    both arguments are shifted up by one for the query and the answer is
    shifted back down before being returned.
    """
    one_based = _sv_mrca(x + 1, y + 1, lambd, pi, tau, beta, alpha)
    return one_based - 1
@numba.njit
def _B_matrix_sv(I, parent, time, root_time):
    """
    Build the symmetric matrix of summed shared times for one tree.

    For each pair of rows ``(j, k)`` of ``I``, sums ``root_time - time[mrca]``
    over every combination of a node from row ``j`` with a node from row
    ``k``, where ``mrca`` is looked up with ``sv_mrca`` on the tree described
    by ``parent``. Only the upper triangle is computed; it is mirrored into
    the lower triangle.
    """
    # Preprocess once so that each MRCA lookup below is a constant-time query.
    lambd, pi, tau, beta, alpha = sv_tables_init(parent)
    num_rows = I.shape[0]
    B = np.zeros((num_rows, num_rows))
    for row in range(num_rows):
        nodes_a = I[row]
        for col in range(row, num_rows):
            nodes_b = I[col]
            total = 0.0
            for u in nodes_a:
                for v in nodes_b:
                    total += root_time - time[
                        sv_mrca(u, v, lambd, pi, tau, beta, alpha)
                    ]
            B[row, col] = total
            B[col, row] = total
    return B
@numba.njit
def _normalise(B):
    """
    Return ``B`` with row mean and column-index row mean removed.

    Entry ``(i, j)`` of the result is
    ``B[i, j] - row_mean[i] - row_mean[j] + mean(B)``. The row means are used
    for both indices, which matches double-centring when ``B`` is symmetric.
    """
    K = np.zeros_like(B)
    size = K.shape[0]
    grand_mean = np.mean(B)
    # Numba doesn't support np.mean(a, axis=0), so the row means are
    # accumulated by hand.
    row_mean = np.zeros(size)
    for r in range(size):
        acc = 0.0
        for c in range(size):
            acc += B[r, c]
        row_mean[r] = acc / size
    for r in range(size):
        for c in range(size):
            K[r, c] = B[r, c] - row_mean[r] - row_mean[c] + grand_mean
    return K
def branch_genetic_relatedness_matrix(ts):
    """
    Compute the branch genetic relatedness matrix between individuals.

    For every tree in ``ts`` the pairwise matrix from ``_B_matrix_sv`` is
    centred with ``_normalise`` and added to the result weighted by the
    tree's span. The sum is not divided by the total sequence length.

    :param ts: the tree sequence to analyse; every individual must have
        exactly two nodes.
    :return: a ``(num_individuals, num_individuals)`` float array.
    :raises ValueError: if an individual does not have exactly two nodes.
    """
    num_individuals = ts.num_individuals
    individual_nodes = np.zeros((num_individuals, 2), dtype=np.int32)
    for ind in ts.individuals():
        # The node table is fixed at two columns. Without this check a
        # one-node individual would be silently broadcast to a duplicated
        # pair and skew every entry involving it.
        if len(ind.nodes) != 2:
            raise ValueError(
                f"individual {ind.id} has {len(ind.nodes)} nodes; "
                "exactly 2 nodes per individual are required"
            )
        individual_nodes[ind.id] = ind.nodes
    # Fetch the node times once rather than on every tree.
    nodes_time = ts.nodes_time
    K = np.zeros((num_individuals, num_individuals))
    for tree in ts.trees():
        # Skip trees in which every sample is its own root.
        if tree.num_roots == ts.num_samples:
            continue
        # NOTE(review): tree.root presumably requires the tree to have a
        # single root -- confirm the behaviour for partially coalesced trees.
        root_time = nodes_time[tree.root]
        B = _B_matrix_sv(
            individual_nodes, tree.parent_array, nodes_time, root_time
        )
        K += _normalise(B) * tree.span
    return K
if __name__ == "__main__":
    # Command-line entry point: load a tree sequence from the first argument
    # and write the relatedness matrix as text to the second.
    if len(sys.argv) == 3:
        ts = tskit.load(sys.argv[1])
        K = branch_genetic_relatedness_matrix(ts)
        # A cross-check of K against a reference genetic_relatedness_matrix
        # implementation (np.allclose) was left disabled in the original.
        np.savetxt(sys.argv[2], K)
    else:
        print(f"usage: {sys.argv[0]} file.trees relatedness.txt")
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment