Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
#!/usr/bin/env python
import os
import argparse
import subprocess
import csv
import sys
def parse_args(argv=None):
    """Parse the command-line arguments for this script.

    :param argv: list of argument strings to parse (without the program name).
        When ``None`` (the default), argparse falls back to ``sys.argv[1:]``,
        which is what the original zero-argument call did.
    :returns: ``argparse.Namespace`` with ``input_csv`` (str) and
        ``top_n_hits`` (str, or ``None`` when the flag is omitted).
    """
    parser = argparse.ArgumentParser(description='run cft bin/blast.py for seed clusters from partis yaml files specified in an input csv. write results to specified outdir.')
    parser.add_argument(
        'input_csv', type=str,
        help='csv that should have the following columns: queries,partis_output,partis_glfo_dir,locus,blast_output,seed_uid')
    parser.add_argument(
        # Kept as a string: the value is only forwarded to blast.py on its
        # command line and interpolated into the expected output filename.
        '--top-n-hits', type=str, default=None,
        help='output tsv with this many of the top hits for each query.')
    return parser.parse_args(argv)
def run_blast(input_dict, args):
    """Extract one seed's cluster to FASTA, then BLAST the queries against it.

    Runs two external commands, both with paths relative to the current
    working directory: ``partis/bin/extract-fasta.py`` writes the cluster for
    ``input_dict['seed_uid']`` to ``<blast_output>/<seed_uid>/db_seqs.fa``,
    then ``python bin/blast.py`` searches ``input_dict['queries']`` against it,
    writing results into the same directory.

    :param input_dict: one row of the input csv. Reads ``blast_output``,
        ``seed_uid``, ``partis_output`` and ``queries``; when
        ``partis_output`` ends in ``.csv`` it also reads ``partis_glfo_dir``
        and ``locus``.
    :param args: parsed command-line arguments; only ``top_n_hits`` is read.
    :returns: path of the top-hits tsv that blast.py was asked to write, or
        ``None`` when ``args.top_n_hits`` is ``None`` (no such file is requested).
    :raises subprocess.CalledProcessError: if either command exits non-zero.
    """
    outdir = os.path.join(input_dict['blast_output'], input_dict['seed_uid'])
    # db_seqs.fa is written inside outdir, so make sure it exists first
    # (isdir check instead of exist_ok keeps this Python 2 compatible).
    if not os.path.isdir(outdir):
        os.makedirs(outdir)
    db_seqs_file = os.path.join(outdir, 'db_seqs.fa')

    # 1. extract the seed's cluster to fasta
    extract_fasta_args = ['partis/bin/extract-fasta.py',
                          '--input-file', input_dict['partis_output'],
                          '--fasta-output-file', db_seqs_file,
                          '--seed-unique-id', input_dict['seed_uid']]
    if os.path.splitext(input_dict['partis_output'])[1] == '.csv':
        # NOTE(review): presumably partis .csv output lacks the germline info
        # that yaml output carries, hence the extra flags -- confirm.
        extract_fasta_args += ['--glfo-dir', input_dict['partis_glfo_dir'],
                               '--locus', input_dict['locus']]
    print(' '.join(extract_fasta_args))
    # check_call: a failed extraction must stop here rather than BLAST a
    # missing or stale fasta and fail confusingly later.
    subprocess.check_call(extract_fasta_args)

    # 2. run blast
    blast_args = ['python', 'bin/blast.py',
                  db_seqs_file,
                  input_dict['queries'],
                  '--outdir', outdir,
                  '--write-query-alignments',
                  '--results-basename', input_dict['seed_uid']]
    if args.top_n_hits is not None:
        blast_args += ['--top-n-hits', args.top_n_hits]
    print(' '.join(blast_args))
    subprocess.check_call(blast_args)

    if args.top_n_hits is None:
        # No top-hits table was requested, so there is no file to point at
        # (previously this returned a nonexistent "...top_None_hits.tsv").
        return None
    return os.path.join(outdir, input_dict['seed_uid'] + '.blastn.top_{}_hits.tsv'.format(args.top_n_hits))
def tsv_dict_lines(tsvfname, info_to_add=None):
    """Read a tab-separated file with a header line into a list of row dicts.

    :param tsvfname: path to the tsv file.
    :param info_to_add: optional mapping of constant columns merged into every
        row. If the file has a column of the same name, the file's value wins
        (same precedence as the original ``items()`` concatenation).
    :returns: ``(rows, fieldnames)``. ``fieldnames`` lists the
        ``info_to_add`` keys first, then the file's header columns not already
        listed; an empty file contributes no header columns.
    """
    # None sentinel instead of a shared mutable {} default.
    if info_to_add is None:
        info_to_add = {}
    with open(tsvfname) as tsvfile:
        reader = csv.DictReader(tsvfile, delimiter='\t')
        rows = []
        for line in reader:
            # copy + update works on Python 2 and 3 (dict_items cannot be
            # concatenated with + on Python 3) and tolerates DictReader's
            # None restkey for over-long rows, which dict(**line) would not.
            merged = dict(info_to_add)
            merged.update(line)
            rows.append(merged)
        # fieldnames is None when the file is empty.
        file_fields = reader.fieldnames or []
    fieldnames = list(info_to_add) + [f for f in file_fields if f not in info_to_add]
    return rows, fieldnames
def main():
    """Run BLAST for every row of the input csv and merge the top-hit tables.

    For each row, ``run_blast`` is invoked. When ``--top-n-hits`` is given,
    each per-seed top-hits tsv is read back with a ``seed_uid`` column
    prepended, and all rows are written to
    ``<blast_output>/<basename of blast_output>.tsv``.

    NOTE(review): ``blast_output`` is taken from the last csv row, as in the
    original, so this assumes every row shares the same ``blast_output``.
    """
    args = parse_args()
    master_output_rows = []
    fieldnames = []
    project_dir = None
    with open(args.input_csv) as csvfile:
        for row in csv.DictReader(csvfile):
            top_n_hits_tsv = run_blast(row, args)
            project_dir = row['blast_output']
            if args.top_n_hits is None:
                # blast.py was not asked for a top-hits table; nothing to merge.
                continue
            output_rows, tsv_fieldnames = tsv_dict_lines(top_n_hits_tsv, info_to_add={'seed_uid': row['seed_uid']})
            master_output_rows += output_rows
            # Ordered union of columns, so DictWriter cannot reject a row whose
            # file had a column the last file lacked.
            for name in tsv_fieldnames:
                if name not in fieldnames:
                    fieldnames.append(name)

    if project_dir is None:
        # Empty input csv: previously a NameError on fieldnames/project_dir.
        print('no rows found in {}; nothing to do'.format(args.input_csv))
        return
    if args.top_n_hits is None:
        print('--top-n-hits not given; not writing a merged top-hits tsv')
        return

    # normpath strips a trailing separator, which would otherwise give a file
    # literally named ".tsv".
    master_name = os.path.basename(os.path.normpath(project_dir)) + '.tsv'
    with open(os.path.join(project_dir, master_name), 'w') as master_tsv:
        writer = csv.DictWriter(master_tsv, fieldnames=fieldnames, delimiter='\t')
        writer.writeheader()
        writer.writerows(master_output_rows)
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment