Skip to content

Instantly share code, notes, and snippets.

@jeromekelleher
Created October 14, 2022 16:31
Show Gist options
  • Save jeromekelleher/653f333d65d8fcd88ffc8a108b54f55d to your computer and use it in GitHub Desktop.
Save jeromekelleher/653f333d65d8fcd88ffc8a108b54f55d to your computer and use it in GitHub Desktop.
import sys
import tskit
import numpy as np
import numba
@numba.njit
def _normalise(B):
    """
    Centre the matrix ``B`` using its row means and its overall mean.

    Each output entry is
    ``K[i, j] = B[i, j] - m[i] - m[j] + mean(B)``
    where ``m[i]`` is the mean of row ``i`` of ``B``. Both loops run over
    ``B.shape[0]``, i.e. ``B`` is treated as square.

    :param B: Two-dimensional array to centre (not modified).
    :return: New array with the shape and dtype of ``B``.
    """
    size = B.shape[0]
    grand_mean = np.mean(B)

    # Row means are accumulated by hand because numba doesn't support
    # np.mean(a, axis=0).
    row_mean = np.zeros(size)
    for r in range(size):
        total = 0.0
        for c in range(size):
            total += B[r, c]
        row_mean[r] = total / size

    K = np.zeros_like(B)
    for r in range(size):
        for c in range(size):
            K[r, c] = B[r, c] - row_mean[r] - row_mean[c] + grand_mean
    return K
@numba.njit
def _update_B_matrix(B, area, samples, nodes_individual):
    """
    Add ``area`` to ``B`` for every pair of nodes listed in ``samples``.

    Each node ``v`` is mapped to the index ``nodes_individual[v]``. The
    value ``area`` is added once to the diagonal cell of every listed node,
    and, for every unordered pair of distinct positions in ``samples``, once
    to each of the two mirrored off-diagonal cells. ``B`` is modified in
    place; nothing is returned.

    :param B: Two-dimensional accumulator, updated in place.
    :param area: Value added for each pair.
    :param samples: One-dimensional array of node IDs.
    :param nodes_individual: Array mapping a node ID to its row/column in B.
    """
    count = samples.shape[0]
    for a in range(count):
        row = nodes_individual[samples[a]]
        B[row, row] += area
        # Pair the current node with every node listed before it; this
        # visits each unordered pair exactly once.
        for b in range(a):
            col = nodes_individual[samples[b]]
            B[row, col] += area
            B[col, row] += area
def B_matrix_incremental(ts):
    """
    Accumulate the un-normalised branch relatedness matrix between the
    individuals of ``ts`` by sweeping over its edge differences.

    For every node ``u`` that currently has a parent, the branch above ``u``
    contributes ``branch_length * span`` to ``B[I, J]`` for each pair of
    individual-associated nodes below ``u`` (``I`` and ``J`` being the
    individuals of those nodes). ``span`` is the genomic distance over which
    both the branch and the set of nodes below it stayed unchanged.
    Contributions are flushed lazily: whenever a branch or the descendant
    set of a node is about to change, the area accumulated since
    ``last_update[u]`` is added to ``B`` first.

    :param ts: Object providing ``num_individuals``, ``num_nodes``,
        ``nodes_time``, ``nodes_individual``, ``individuals()`` and
        ``edge_diffs(include_terminal=True)`` (a tskit tree sequence in the
        script's own usage).
    :return: ``(num_individuals, num_individuals)`` float array ``B``.
    """
    N = ts.num_individuals
    B = np.zeros((N, N))
    time = ts.nodes_time
    # Fetch the node -> individual mapping once rather than on every flush.
    nodes_individual = ts.nodes_individual
    last_update = np.zeros(ts.num_nodes)
    parent = np.full(ts.num_nodes, -1, dtype=np.int32)
    # The descendant nodes of individuals at each node in the tree.
    # Note: we could probably use the SampleLists in the C code to
    # do this, but we'd have to check the sample <-> Individual mapping.
    D = [set() for _ in range(ts.num_nodes)]
    for ind in ts.individuals():
        for u in ind.nodes:
            D[u].add(u)

    def update_matrix(u, distance):
        # Flush the area of the branch above ``u`` accumulated over
        # ``distance`` units of genome into B.
        # Nothing to add for roots (no branch), zero-length spans, or nodes
        # with no individual-associated descendants. Skipping the empty case
        # also avoids building an empty array, which numpy would create as
        # float64 and which cannot be used as an index by the helper.
        if parent[u] == -1 or distance == 0 or not D[u]:
            return
        branch_length = time[parent[u]] - time[u]
        area = branch_length * distance
        # Explicit integer dtype so the helper always sees the same type.
        samples = np.fromiter(D[u], dtype=np.int64, count=len(D[u]))
        _update_B_matrix(B, area, samples, nodes_individual)

    for (left, _right), edges_out, edges_in in ts.edge_diffs(include_terminal=True):
        for edge in edges_out:
            child = edge.child
            # The branch above the child ends here: flush it, then detach.
            update_matrix(child, left - last_update[child])
            last_update[child] = left
            parent[child] = -1
            u = edge.parent
            while u != -1:
                # Flush each ancestor with its OLD descendant set before
                # removing the detached subtree from it.
                update_matrix(u, left - last_update[u])
                last_update[u] = left
                D[u] -= D[child]
                u = parent[u]
        for edge in edges_in:
            child = edge.child
            parent[child] = edge.parent
            # The branch above the child only starts at ``left``, so there is
            # nothing to flush for it; just record where it begins. Flushing
            # here with the new parent would charge a node that used to be a
            # root for a branch it did not have over the earlier interval.
            last_update[child] = left
            u = edge.parent
            while u != -1:
                # Flush each ancestor with its OLD descendant set first; the
                # newly attached subtree was not below ``u`` over the span
                # being flushed, so it must be added only afterwards.
                update_matrix(u, left - last_update[u])
                last_update[u] = left
                D[u] |= D[child]
                u = parent[u]
    return B
def branch_genetic_relatedness_matrix(ts):
    """
    Return the normalised branch relatedness matrix for ``ts``.

    This is the matrix built by ``B_matrix_incremental`` passed through
    ``_normalise``.
    """
    return _normalise(B_matrix_incremental(ts))
if __name__ == "__main__":
    # Command line: <script> file.trees relatedness.txt
    if len(sys.argv) != 3:
        print(f"usage: {sys.argv[0]} file.trees relatedness.txt")
        raise SystemExit(1)
    _, trees_path, output_path = sys.argv
    tree_sequence = tskit.load(trees_path)
    relatedness = branch_genetic_relatedness_matrix(tree_sequence)
    # NOTE: the author left a disabled cross-check here comparing the result
    # (via np.allclose) against genetic_relatedness_matrix(ts), a reference
    # function that is not defined in this file.
    np.savetxt(output_path, relatedness)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment