Skip to content

Instantly share code, notes, and snippets.

@bosborne
Last active January 18, 2022 17:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bosborne/7b0db23e68204ef437af77db85b98df8 to your computer and use it in GitHub Desktop.
Save bosborne/7b0db23e68204ef437af77db85b98df8 to your computer and use it in GitHub Desktop.
Use `mash` k-mer analysis to create a decreased-redundancy protein sequence file
#!/usr/bin/env python3
import argparse
import os
import subprocess
import sys
from shutil import which
import tempfile
from collections import defaultdict
from Bio import SeqIO
from pathlib import Path
from sklearn.cluster import DBSCAN
'''
Run the `mash` application to do all-against-all protein sequence comparisons using k-mers, then
run DBSCAN from scikit-learn on the resulting distance data to identify clusters of closely related
sequences. Redundant members of each cluster are removed to create a "decreased redundancy" ("dr")
file. If the input file is in Swissprot format, the removed sequences are the ones with the fewest
GO terms and the output file is also in Swissprot format. If the input file is in fasta format,
related sequences are removed arbitrarily and the output file is in fasta format.
Example using a file from Uniprot:
> time python3 preprocessing-scripts/make_dr_seqs.py -c 16 -s data/uniprot_sprot_viruses.dat
real 6m15.980s
Input file 'data/uniprot_sprot_viruses.dat' has 17039 sequences
Output file 'uniprot_sprot_viruses-dr.dat' has 10572 sequences
2,137 clusters found and 6,467 sequences removed (38% of the sequences; 62% were kept), using an EC2 r5a.4xlarge (16 cores).
'''
def parse_args(argv=None):
    """Parse the command-line options.

    argv: list of argument strings; None means sys.argv[1:] (argparse default).
    Returns the argparse.Namespace with seqfile, informat, threshold, output,
    cores and verbose.

    Parsing happens here rather than at import time so that importing this
    module (e.g. from tests) does not consume or reject the importer's argv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--seqfile", required=True,
                        help="Path to input sequence file")
    parser.add_argument("--informat", default="swiss",
                        help="Input sequence file format")
    parser.add_argument("-t", "--threshold", default=0.1,
                        type=float, help="Distance threshold")
    parser.add_argument("-o", "--output", help="Output sequence file name")
    # Kept as str: the value is placed directly into the `mash dist` argv.
    parser.add_argument("-c", "--cores", default='2', type=str,
                        help="Number of cores for 'mash dist'")
    parser.add_argument("-v", "--verbose", action="store_true", help="Verbose")
    return parser.parse_args(argv)


def main(argv=None):
    """Command-line entry point: parse options, check for mash, build the dr file."""
    args = parse_args(argv)
    builder = MakeDRSeqs(args.seqfile, args.informat, args.threshold,
                         args.verbose, args.output, args.cores)
    builder.check_for_mash()
    builder.make_dr()
class MakeDRSeqs:
    """Build a decreased-redundancy ("dr") copy of a sequence file.

    Pipeline (see make_dr): run `mash dist` all-against-all on the input,
    load the reported distances into a square distance matrix, cluster it
    with DBSCAN, choose the redundant members of every cluster and write all
    remaining records to the output file in the input's own format.
    """

    def __init__(self, seqfile, informat, threshold,
                 verbose, output, cores) -> None:
        # seqfile: path to the input sequence file
        # informat: Biopython SeqIO format name of seqfile ('swiss', 'fasta', ...)
        # threshold: distance passed to `mash dist -d` and used as DBSCAN eps
        # verbose: print progress messages when true
        # output: output path; falsy means "derive it from seqfile"
        # cores: thread count passed to `mash dist -p`
        self.seqfile = seqfile
        self.informat = informat
        self.threshold = threshold
        self.verbose = verbose
        self.output = output
        self.cores = cores
        # SeqIO index of seqfile, created on first use by _load_seqs()
        self.seqs = None

    def make_dr(self):
        """Run the whole pipeline and write the decreased-redundancy file."""
        mash_out = self.run_mash()
        mash_dict = self.read_file(mash_out)
        mat, ids = self.make_dist_matrix(mash_dict)
        clusters = self.find_dbscan_clusters(mat, ids)
        seqs_to_remove = self.get_seqs_to_remove(clusters)
        self.write_seqs(seqs_to_remove)

    def run_mash(self):
        """Run `mash dist` with the input file as both query and reference.

        Tab-delimited data from `mash dist`; the first two columns are the
        sequence IDs and the third is the distance used by read_file:
            Q91G40 O55709 1 1 0/269
            Q6GZX4 Q6GZX4 0 0 248/248
            Q6GZX6 Q6GZX4 0.00135228 0 245/251
        0 is no distance (all kmers in common), 1 is no kmers in common.

        Returns the path of a temporary file holding mash's stdout. The file
        is created with delete=False, so it is left on disk.
        Exits the program (sys.exit) if mash cannot be run or fails.
        """
        tmpfasta = None
        try:
            # mash is given fasta: convert other formats into a temporary file
            # that stays open (and therefore on disk) until mash has finished.
            if self.informat != 'fasta':
                tmpfasta = tempfile.NamedTemporaryFile()
                if self.verbose:
                    print("Creating fasta version of '{}'".format(self.seqfile))
                SeqIO.convert(self.seqfile, self.informat, tmpfasta.name, "fasta")
                fastafile = tmpfasta.name
            else:
                fastafile = self.seqfile
            # The same file acts as both query and reference. str() guards
            # against callers passing cores/threshold as numbers.
            cmd = ['mash', 'dist', '-i', '-a', '-p', str(self.cores), '-d',
                   str(self.threshold), fastafile, fastafile]
            with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmpout:
                if self.verbose:
                    print("Running 'mash dist': {}".format(cmd))
                try:
                    # check=True: without it a failing mash never raises and
                    # its (partial) output would be silently used.
                    subprocess.run(cmd, stdout=tmpout, check=True)
                except (subprocess.CalledProcessError, OSError) as exception:
                    print("Error: {}".format(exception))
                    sys.exit("Error running 'mash dist' on {}".format(fastafile))
        finally:
            if tmpfasta is not None:
                tmpfasta.close()
        if self.verbose:
            print("Completed 'mash dist', output is: {}".format(tmpout.name))
        return tmpout.name

    def read_file(self, mash_out):
        """Parse `mash dist` output into a dict of dicts of distances.

        Returns {first_id: {second_id: distance, ...}, ...}. Lines with fewer
        than three tab-separated fields (e.g. blank lines) are skipped.
        """
        mash_sorted = defaultdict(dict)
        with open(mash_out, "r") as f:
            for line in f:
                fields = line.rstrip("\n").split("\t")
                if len(fields) < 3:
                    continue
                mash_sorted[fields[0]][fields[1]] = float(fields[2])
        return mash_sorted

    def make_dist_matrix(self, mash_dict):
        """Build a square distance matrix (list of rows) from mash distances.

        Example of 4 points in 2 clusters (a,b and c,d):
                a    b    c    d
            a [0,   0.1, 1,   1],
            b [0.1, 0,   1,   1],
            c [1,   1,   0,   0.1],
            d [1,   1,   0.1, 0]

        Returns (mat, ids) where ids maps each sequence ID to its row/column
        index, e.g. {'CATH_HUMAN': 0, 'CYS1_DICDI': 1, ...}. Pairs mash did not
        report get distance 1 (no kmers in common); the diagonal is 0.
        """
        if self.verbose:
            print("Making distance matrix")
        # Index IDs from both columns so an ID that only ever appears as the
        # second column cannot raise KeyError below.
        all_ids = set(mash_dict)
        for targets in mash_dict.values():
            all_ids.update(targets)
        ids = {seqid: pos for pos, seqid in enumerate(sorted(all_ids))}
        size = len(ids)
        # Dense matrix prefilled with 1, since most pairs are not reported.
        mat = [[1.0] * size for _ in range(size)]
        for pos in range(size):
            mat[pos][pos] = 0.0
        for seq1, targets in mash_dict.items():
            row = ids[seq1]
            for seq2, dist in targets.items():
                col = ids[seq2]
                # Fill both halves so the matrix is symmetric even when only
                # one direction of a pair was reported.
                mat[row][col] = dist
                mat[col][row] = dist
        if self.verbose:
            print("Completed distance matrix")
        return mat, ids

    def find_dbscan_clusters(self, mat, ids):
        """Find clusters with DBSCAN on the precomputed distance matrix.

        Example using the square distance matrix from make_dist_matrix:
            DBSCAN(eps=0.1, min_samples=2, metric='precomputed').fit_predict(
                [[0,0.1,1,1],[0.1,0,1,1],[1,1,0,0.1],[1,1,0.1,0]])
            -> array([0, 0, 1, 1])

        Returns {cluster_label: [seqid, ...]}, excluding unclustered points.
        Exits the program (sys.exit) when no cluster is found.
        """
        # eps follows --threshold (whose default, 0.1, was the old fixed eps),
        # matching the maximum distance mash was asked to report.
        clust = DBSCAN(eps=self.threshold, min_samples=2, metric='precomputed')
        if self.verbose:
            print("Running DBSCAN on matrix with {} sequences".format(len(mat)))
        predictions = clust.fit_predict(mat)
        if self.verbose:
            print("Completed DBSCAN")
        clusters = self._group_clusters(predictions, ids)
        if len(clusters) == 0:
            if self.verbose:
                print("No clusters found")
            sys.exit()
        if self.verbose:
            # Descriptions come from the SeqIO index, so make sure it exists.
            self._load_seqs()
            for label, members in clusters.items():
                print("Cluster {0} ({1}): {2}".format(
                    label, len(members), members))
                for seqid in members:
                    print(self.seqs[seqid].description)
        return clusters

    def _group_clusters(self, predictions, ids):
        """Group sequence IDs by DBSCAN label, skipping noise points (-1)."""
        # Invert ids once instead of searching it for every point.
        by_index = {pos: seqid for seqid, pos in ids.items()}
        clusters = defaultdict(list)
        for pos, label in enumerate(predictions):
            if label != -1:
                clusters[label].append(by_index[pos])
        return clusters

    def find_seqid(self, count, ids):
        """Return the sequence ID whose matrix index is `count`, or None."""
        return next((seqid for seqid, num in ids.items() if num == count), None)

    def _load_seqs(self):
        """Index the input file with SeqIO once and cache it in self.seqs."""
        if getattr(self, 'seqs', None) is None:
            if self.verbose:
                print("Reading {} with SeqIO".format(self.seqfile))
            self.seqs = SeqIO.index(self.seqfile, self.informat)
        return self.seqs

    def get_seqs_to_remove(self, clusters):
        """Choose the redundant sequences to drop: all but one per cluster.

        Swissprot input keeps the member with the most GO terms (the earliest
        one wins a tie) and removes the rest. Other formats keep the last
        member of each cluster. Returns a list of sequence IDs.
        """
        self._load_seqs()
        seqs_to_remove = []
        for cluster in clusters.values():
            if self.informat == 'swiss':
                keep = max(cluster,
                           key=lambda seqid: self.get_num_terms(self.seqs[seqid]))
                seqs_to_remove.extend(s for s in cluster if s != keep)
            else:
                # Choose arbitrary sequences to remove
                seqs_to_remove.extend(cluster[:-1])
        return seqs_to_remove

    def write_seqs(self, seqs_to_remove):
        """Write every input record not in seqs_to_remove, in the input format.

        The output path is self.output when it is set (-o). Otherwise it is
        derived from the input name in the current directory, e.g.
        "data/viruses.dat" -> "viruses-dr.dat" for Swissprot and
        "viruses-dr.fa" for any other format.
        """
        self._load_seqs()
        if self.verbose:
            print("Input file '{0}' has {1} sequences".format(
                self.seqfile, len(self.seqs)))
        if not self.output:
            suffix = '-dr.dat' if self.informat == 'swiss' else '-dr.fa'
            self.output = Path(self.seqfile).stem + suffix
        skip = set(seqs_to_remove)
        n = 0
        # Have to use get_raw() since Biopython cannot write 'swiss' format;
        # it returns the record's original bytes, so write in binary mode.
        with open(self.output, 'wb') as out:
            for seqid in self.seqs:
                if seqid not in skip:
                    n += 1
                    out.write(self.seqs.get_raw(seqid))
        if self.verbose:
            print("Output file '{0}' has {1} sequences".format(self.output, n))

    def get_num_terms(self, seq):
        """Count the cross-references of `seq` that start with 'GO:'.

        Example seq.dbxrefs:
            ['EMBL:AY548484', 'RefSeq:YP_031579.1', 'GO:GO:0046782', 'Pfam:PF04947']
        """
        return len([t for t in seq.dbxrefs if t.startswith('GO:')])

    def check_for_mash(self):
        """Exit the program unless `mash` is found on PATH."""
        if which('mash') is None:
            sys.exit("'mash' is not installed or not in PATH")
# Script entry point: only runs when executed directly, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment