import math
import uuid
from typing import List

import hail as hl
from hail import MatrixTable, Table
from hail.utils import info
from hail.utils.java import Env
from hail.experimental.vcf_combiner.vcf_combiner import CombinerConfig, calculate_even_genome_partitioning
from hail.ir import MatrixKeyRowsBy

def combine_array_schemas(ts):
    # Flatten the per-input entry arrays into a single '__entries' array. If an
    # input has no row at this locus, substitute one missing entry per sample.
    ts = ts.annotate(__entries=hl.range(hl.len(ts.data))
                     .flatmap(lambda i:
                              hl.if_else(
                                  hl.is_missing(ts.data[i]),
                                  hl.range(hl.len(ts.g[i].__cols))
                                    .map(lambda j: hl.missing(ts.data[i].__entries.dtype.element_type)),
                                  ts.data[i].__entries))).drop('data')
    # Concatenate the column (sample) metadata of all inputs.
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))

def localize(mt):
    # Convert a MatrixTable to its localized Table representation; Tables pass through.
    if isinstance(mt, MatrixTable):
        return mt._localize_entries('__entries', '__cols')
    return mt


def unlocalize(mt):
    # Convert a localized Table back to a MatrixTable with columns keyed by 's';
    # MatrixTables pass through.
    if isinstance(mt, Table):
        return mt._unlocalize_entries('__entries', '__cols', ['s'])
    return mt

def combine_array_intermediates(mts):
    """Merges GVCFs and/or sparse matrix tables.

    Parameters
    ----------
    mts : :obj:`List[Union[Table, MatrixTable]]`
        The matrix tables (or localized versions) to combine.

    Returns
    -------
    :class:`.MatrixTable`

    Notes
    -----
    All of the input tables/matrix tables must have the same partitioning. This
    module provides no method of repartitioning data.
    """
    ts = hl.Table.multi_way_zip_join([localize(mt) for mt in mts], 'data', 'g')
    combined = combine_array_schemas(ts)
    return unlocalize(combined)
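
# A minimal usage sketch for combine_array_intermediates. The paths below are
# hypothetical, and both inputs are assumed to already share identical
# partitioning, since this module provides no way to repartition:
#
#     batch1 = hl.read_matrix_table('gs://my-bucket/batch1.mt')  # hypothetical path
#     batch2 = hl.read_matrix_table('gs://my-bucket/batch2.mt')  # hypothetical path
#     combined = combine_array_intermediates([batch1, batch2])
#     combined.describe()
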
def transform_vcf_input(mt):
    # Re-key an imported VCF by (locus, alleles) -- the data is already sorted --
    # and copy the row-level INFO field into each entry as 'vcf_info'.
    mt = MatrixTable(MatrixKeyRowsBy(mt._mir, ['locus', 'alleles'], is_sorted=True))
    return mt.annotate_entries(vcf_info=mt.info)
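
# A hedged sketch of feeding a VCF into transform_vcf_input; the path and
# reference genome here are assumptions for illustration only:
#
#     raw = hl.import_vcf('gs://my-bucket/array-calls.vcf.bgz',  # hypothetical path
#                         reference_genome='GRCh38')             # assumed build
#     mt = transform_vcf_input(raw)
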
def merge_many_tables(input_paths,
                      out_file: str,
                      tmp_path: str,
                      *,
                      branch_factor: int = 64,
                      batch_size: int = 64,
                      target_records: int = 16_000,
                      overwrite: bool = False,
                      reference_genome: str = 'default'):
    """Run the Hail genotype-array combiner, performing a hierarchical merge of
    the matrix tables at `input_paths` to produce a single combined matrix
    table at `out_file`. `branch_factor` bounds the number of inputs combined
    in one merge, `batch_size` bounds the number of merges per job, and
    `target_records` is the target number of rows per partition.
    """
    tmp_path += f'/combiner-temporary/{uuid.uuid4()}/'
    size = 50_000_000  # bp interval size on import
    intervals = calculate_even_genome_partitioning(reference_genome, size)
    info(f"Using {len(intervals)} intervals with default whole-genome size"
         f" {size} as partitioning for GVCF import")
    config = CombinerConfig(branch_factor=branch_factor,
                            batch_size=batch_size,
                            target_records=target_records)
    plan = config.plan(len(input_paths))

    files_to_merge = input_paths
    n_phases = len(plan.phases)
    total_ops = len(files_to_merge) * n_phases
    total_work_done = 0
    for phase_i, phase in enumerate(plan.phases):
        phase_i += 1  # used for info messages, 1-indexed for readability
        n_jobs = len(phase.jobs)
        merge_str = 'matrix tables'
        job_str = hl.utils.misc.plural('job', n_jobs)
        info(f"Starting phase {phase_i}/{n_phases}, merging {len(files_to_merge)} {merge_str} in {n_jobs} {job_str}.")

        # Derive a shared partitioning for this phase from the first input, so
        # every input can be read with identical partition boundaries.
        file0 = hl.read_matrix_table(files_to_merge[0])
        n_partitions = max(1, math.ceil(file0.count_rows() / config.target_records))
        intervals, intervals_dtype = file0._calculate_new_partitions(n_partitions)
        new_files_to_merge = []
        for job_i, job in enumerate(phase.jobs):
            job_i += 1  # used for info messages, 1-indexed for readability
            n_merges = len(job.merges)
            merge_str = hl.utils.misc.plural('file', n_merges)
            pct_total = 100 * job.input_total_size / total_ops
            info(f"Starting phase {phase_i}/{n_phases}, job {job_i}/{n_jobs} to create {n_merges} merged"
                 f" {merge_str}, corresponding to ~{pct_total:.1f}% of total I/O.")

            merge_mts: List[MatrixTable] = []
            for merge in job.merges:
                inputs = [files_to_merge[i] for i in merge.inputs]
                mts = Env.spark_backend("table_combiner").read_multiple_matrix_tables(
                    inputs, intervals, intervals_dtype)
                merge_mts.append(combine_array_intermediates(mts))
            if phase_i == n_phases:  # final merge!
                assert n_jobs == 1
                assert len(merge_mts) == 1
                [final_mt] = merge_mts
                final_mt.write(out_file, overwrite=overwrite)
                new_files_to_merge = [out_file]
                info(f"Finished phase {phase_i}/{n_phases}, job {job_i}/{n_jobs}, 100% of total I/O finished.")
                break

            # Write this job's intermediates to temporary storage for the next phase.
            tmp = f'{tmp_path}_phase{phase_i}_job{job_i}/'
            hl.experimental.write_matrix_tables(merge_mts, tmp, overwrite=True)
            pad = len(str(len(merge_mts)))
            new_files_to_merge.extend(tmp + str(n).zfill(pad) + '.mt' for n in range(len(merge_mts)))
            total_work_done += job.input_total_size
            info(f"Finished phase {phase_i}/{n_phases}, job {job_i}/{n_jobs},"
                 f" {100 * total_work_done / total_ops:.1f}% of total I/O finished.")

        info(f"Finished phase {phase_i}/{n_phases}.")
        files_to_merge = new_files_to_merge

    assert files_to_merge == [out_file]
    info("Finished!")

inputs = []  # paths to the input matrix tables to merge
merge_many_tables(inputs, 'combined.mt', '/tmp', branch_factor=100, batch_size=25, overwrite=True)
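
# Once the merge finishes, the combined output written above can be read back:
#
#     result = hl.read_matrix_table('combined.mt')
#     result.describe()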