Skip to content

Instantly share code, notes, and snippets.

@danking
Created December 9, 2020 19:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danking/f619d0931658e3514e6adf701b6df0eb to your computer and use it in GitHub Desktop.
Save danking/f619d0931658e3514e6adf701b6df0eb to your computer and use it in GitHub Desktop.
import hail as hl
import numpy as np
def tsqr(mt: hl.MatrixTable, field: str, *, block_size: int = 1024):
    """Compute a tall-skinny QR (TSQR) factorization of an entry field.

    The entries ``mt[field]`` are grouped into row blocks with
    ``hl.experimental.mt_to_table_of_ndarray`` and each block is factored
    with ``hl.nd.qr``.  The per-block R factors are stacked in block order
    and factored once more on the driver with ``np.linalg.qr``.

    If the matrix has ``k`` columns and every block has at least ``k`` rows,
    the rows of the full Q factor belonging to block ``b`` are
    ``q_b @ q_twiddle[b*k:(b+1)*k, :]``, and ``r_twiddle`` is the R factor
    of the whole matrix.

    Args:
        mt: Matrix table holding the numeric entry field to factor.
        field: Name of the entry field in ``mt``.
        block_size: Passed through to ``mt_to_table_of_ndarray``
            (presumably the number of rows per block — confirm in Hail docs).

    Returns:
        A tuple ``(Qs, q_twiddle, r_twiddle)``: ``Qs`` is a Hail table with
        fields ``partition_index`` (block index) and ``q`` (that block's
        local Q factor); ``q_twiddle`` and ``r_twiddle`` are NumPy arrays
        from factoring the stacked R blocks.

    NOTE(review): ``Qs`` is returned as a lazy table, so consuming it
    presumably re-runs the per-block QR; checkpoint it if that is costly.
    """
    blocks = hl.experimental.mt_to_table_of_ndarray(mt[field], block_size=block_size)
    # The block index is what ties each block's Q to its rows of q_twiddle.
    blocks = blocks.add_index('partition_index')
    # hl.nd.qr (default "reduced" mode) returns the tuple (Q, R).
    blocks = blocks.annotate(q_and_r=hl.nd.qr(blocks.ndarray))
    blocks = blocks.annotate(q=blocks.q_and_r[0], r=blocks.q_and_r[1])
    qs = blocks.select('partition_index', 'q')
    rs = blocks.select('partition_index', 'r')
    # hl.agg.collect does not guarantee element order, so collect
    # (index, R) pairs and sort by block index before stacking; otherwise
    # q_twiddle's row blocks would not line up with partition_index.
    ordered_rs = hl.sorted(
        hl.agg.collect(hl.struct(idx=rs.partition_index, r=rs.r)),
        key=lambda entry: entry.idx,
    ).map(lambda entry: entry.r)
    stacked_r = rs.aggregate(hl.nd.vstack(ordered_rs))
    q_twiddle, r_twiddle = np.linalg.qr(stacked_r)
    return qs, q_twiddle, r_twiddle
def local_af(scores: hl.Table, scores_field: str, k: int,
             mt: hl.MatrixTable, gt_field: str, *, block_size: int = 1024):
    """Project each variant's alt-allele counts onto the span of ``k`` score vectors.

    ``scores[scores_field]`` is treated as one row of a (rows x k) matrix
    (presumably per-sample PCA scores — confirm against callers). An
    orthonormal basis Q for its column space is computed with ``tsqr``.
    For every row (variant) of ``mt``, the vector of ``n_alt_alleles()``
    across columns is projected onto that space as ``Q @ (Q.T @ g)``.

    Args:
        scores: Table whose ``scores_field`` is an array of length ``k``.
        scores_field: Name of the array field in ``scores``.
        k: Number of score components; must not exceed ``block_size``.
        mt: Matrix table with a call entry field ``gt_field``.
        gt_field: Name of the call entry field in ``mt``.
        block_size: Block size forwarded to ``tsqr``.

    Returns:
        A matrix table with ``mt``'s row and column keys. Its entries carry
        ``x`` (alt-allele count) and ``local_af`` (the projected value). Only
        the row key fields survive; other row fields of ``mt`` are dropped.

    Raises:
        ValueError: if ``k > block_size``.

    NOTE(review): rows of ``scores`` are matched to columns of ``mt``
    purely by position — nothing here joins them by key. This assumes both
    are in the same order and have the same count; TODO confirm.
    NOTE(review): ``local_af`` is on the allele-count scale (0-2), not
    halved to a frequency; confirm this is intended. Missing calls
    presumably make ``hl.nd.array`` fail; TODO confirm input is filtered.
    """
    if k > block_size:
        raise ValueError(f'k ({k}) must not exceed block_size ({block_size})')
    n_cols = mt.count_cols()

    # Reshape scores into a matrix table: one column per score component
    # (keyed by its string index), with the score in entry field ``x``.
    scores = scores.annotate(**{
        scores_field: scores[scores_field].map(lambda score: hl.struct(x=score))})
    scores = scores.annotate_globals(
        cols=hl.range(k).map(lambda i: hl.struct(i=hl.str(i))))
    scores = scores._unlocalize_entries(scores_field, 'cols', ['i'])

    q_blocks, q_twiddle, _ = tsqr(scores, 'x', block_size=block_size)

    # Each block's local Q times its k-row slice of q_twiddle gives that
    # block's rows of the global Q.
    block_idx = q_blocks.partition_index
    q_twiddle_slice = hl.literal(q_twiddle)[block_idx * k:(block_idx + 1) * k, :]
    q_blocks = q_blocks.annotate(q_final=q_blocks.q @ q_twiddle_slice)
    # hl.agg.collect does not guarantee order; sort blocks by index so the
    # rows of Q come back in the original row order.
    ordered_q = hl.sorted(
        hl.agg.collect(hl.struct(idx=q_blocks.partition_index, q=q_blocks.q_final)),
        key=lambda entry: entry.idx,
    ).map(lambda entry: entry.q)
    q = q_blocks.aggregate(hl.nd.vstack(ordered_q), _localize=False)

    mt = mt.annotate_entries(x=mt[gt_field].n_alt_alleles())
    col_key = list(mt.col_key)
    mt = mt.localize_entries('entries', 'cols')
    # Orthogonal projection of the per-row genotype vector onto span(Q).
    mt = mt.annotate(local_af_nd=q @ (q.T @ hl.nd.array(mt.entries.x)))
    local_afs = hl.range(n_cols).map(lambda i: mt.local_af_nd[i])
    mt = mt.select(
        entries=hl.zip(mt.entries, local_afs).map(
            lambda pair: pair[0].annotate(local_af=pair[1])))
    return mt._unlocalize_entries('entries', 'cols', col_key)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment