Multiple methods for finding best set of mutually compatible BLAST hits
#!/usr/bin/env python3 | |
''' | |
For each query listed in a BLAST XML results file, determine how much of the
query sequence is covered by its best hit. For each hit (and separately for
each query/subject strand pair), the best set of HSPs is the set of mutually
compatible HSPs whose summed bitscore is highest; the value reported for the
query is the largest coverage achieved by any such set. Mutually compatible
HSPs are ones that neither overlap nor "cross" each other.
Usage: | |
cat blast_results.xml | find-best-hit.py | |
Output: | |
Values in range [0, 1], with one per line. The nth such value represents the | |
fraction of query n covered by the mutually compatible HSPs composing its | |
best hit. | |
''' | |
import sys | |
import xml.etree.ElementTree as ET | |
from collections import defaultdict, namedtuple | |
import json | |
from functools import lru_cache | |
from enum import Enum | |
import unittest | |
import os | |
class Path(Enum):
    '''
    Which branch of SmartHspSet's recurrence produced the best score recorded
    for a given tuple of HSP indices.
    '''
    # Best combination for the indices leaves out the last index.
    without_self = 1
    # Best combination for the indices includes the last index.
    with_self = 2
class BaseHspSet(object):
    '''
    Common base for the HSP-set solvers: keeps its own list of the supplied
    HSPs, ordered by ascending query start coordinate.
    '''
    def __init__(self, hsps):
        ordered = list(hsps)
        # list.sort is stable, so HSPs sharing a query_start keep their input
        # order, exactly as sorted() would.
        ordered.sort(key=lambda hsp: hsp.query_start)
        self._hsps = ordered
class DumbHspSet(BaseHspSet):
    '''
    Brute-force solver: enumerates every mutually compatible subset of the
    HSPs and keeps the one with the highest summed bitscore.

    Runtime is exponential in the number of HSPs, so this is only suitable for
    small inputs (e.g. cross-checking SmartHspSet).
    '''
    def find_best_combo(self):
        '''
        Return ``(highest_score, best_combo)``. ``best_combo`` is the list of
        mutually compatible HSPs (in query_start order) whose bitscores sum to
        ``highest_score``. If no combination scores above 0 (e.g. there are no
        HSPs), returns ``(0, [])``.
        '''
        highest_score = 0
        # The empty combination always exists and scores 0. Starting from it
        # (rather than None) guarantees callers always receive a list.
        best_combo = []
        for valid_combo in self._compute_hsp_combos(self._hsps):
            score = sum(hsp.bitscore for hsp in valid_combo)
            # Strict '>' keeps the first combination found among equal scores.
            if score > highest_score:
                highest_score = score
                best_combo = valid_combo
        return (highest_score, best_combo)

    def _compute_hsp_combos(self, L):
        '''
        Yield every mutually compatible subset of the HSPs in ``L``, each as a
        list in query_start order. Subsets are produced lazily so the (up to
        2^n) combinations are never all held in memory at once.
        '''
        sorted_hsps = sorted(L, key=lambda a: a.query_start)
        return self._combo_r([], sorted_hsps)

    def _combo_r(self, base, remaining):
        '''
        Yield ``base`` extended by each compatible subset of ``remaining``.

        ``base`` is a compatible combination built so far. ``remaining`` holds
        the not-yet-considered HSPs, sorted by query_start. The next HSP may be
        appended only if it starts strictly after the last HSP of ``base`` ends
        on both the query and the subject. Each HSP in ``base`` ends before the
        next one starts, so the last one has the largest end coordinates.
        Combinations including the next HSP are yielded before those
        excluding it.
        '''
        if not remaining:
            yield base
            return
        first = remaining[0]
        rest = remaining[1:]
        if base:
            base_query_end = base[-1].query_end
            base_subject_end = base[-1].subject_end
        else:
            # Coordinates are presumably 1-based, so 0 precedes every HSP.
            base_query_end = 0
            base_subject_end = 0
        if first.query_start > base_query_end and first.subject_start > base_subject_end:
            yield from self._combo_r(base + [first], rest)
        yield from self._combo_r(base, rest)
class SmartHspSet(BaseHspSet):
    '''
    Dynamic-programming solver for the highest-scoring set of mutually
    compatible HSPs.

    Two HSPs are compatible when the one with the smaller query_start ends
    strictly before the other starts, on both the query and the subject; that
    is, they neither overlap nor cross. For a tuple of HSP indices (ascending,
    hence in query_start order) the best score is the larger of:
      * the best score of the tuple without its last index, and
      * the last HSP's bitscore plus the best score of those indices in the
        tuple that are compatible with it.
    '''

    def __init__(self, hsps):
        super().__init__(hsps)
        # Memo tables live on the instance rather than in functools.lru_cache
        # decorators on the methods, for two reasons:
        #   * A method-level lru_cache keys on self, so it keeps every instance
        #     (and all its cached tuples) alive for the whole run.
        #   * It is not cleared when find_best_combo() resets _combo_soln. A
        #     second call would then reuse cached scores without re-recording
        #     their decisions, and reconstruction would fail with a KeyError.
        #
        # (target, indices) -> indices compatible with target. This depends
        # only on self._hsps, so it stays valid for the instance's lifetime.
        self._compatible_memo = {}
        # indices -> best score (only for len(indices) >= 2). Reset on every
        # find_best_combo() call.
        self._best_values = {}
        # indices -> Path taken to reach that best score. Reset on every
        # find_best_combo() call.
        self._combo_soln = {}

    # Segment HSPs into smaller, independent sets. No performance benefit appears
    # in practice -- using segment takes almost exactly as long as not using it
    # -- and so don't bother. (Nothing in this file calls _segment.)
    def _segment(self):
        '''
        Return partition of self._hsps, such that no element (i.e., subset of
        self._hsps) contains an HSP that overlaps another HSP from a different
        element. This is useful because the problem of finding the highest-scoring
        subset of HSPs can then be solved independently for each subset forming the
        partition without impairing the answer's optimality.
        '''
        segments = []
        i = 0
        n = len(self._hsps)
        while i < n:
            candidate = self._hsps[i]
            max_query_end = candidate.query_end
            max_subject_end = candidate.subject_end
            conflicting_query_end = candidate.query_end
            conflicting_subject_end = candidate.subject_end
            last_conflicting = i
            for j in range(i + 1, n):
                other = self._hsps[j]
                max_query_end = max(max_query_end, other.query_end)
                max_subject_end = max(max_subject_end, other.subject_end)
                if other.query_start <= conflicting_query_end \
                        or other.subject_start <= conflicting_subject_end:
                    conflicting_query_end = max_query_end
                    conflicting_subject_end = max_subject_end
                    last_conflicting = j
            next_compatible = last_conflicting + 1
            segments.append(tuple(range(i, next_compatible)))
            i = next_compatible
        return segments

    def _find_compatible(self, target, indices):
        '''
        Return the tuple of indices in ``indices`` (order preserved) whose HSPs
        neither overlap nor cross the HSP at ``target``. These HSPs are only
        guaranteed to be compatible with ``target``, not with each other.
        ``indices`` must be a tuple (it is used as a memo key).
        '''
        key = (target, indices)
        cached = self._compatible_memo.get(key)
        if cached is not None:
            return cached
        compatible = []
        target_hsp = self._hsps[target]
        for index in indices:
            # Don't define target as overlapping itself.
            if index == target:
                continue
            test_hsp = self._hsps[index]
            # Order the pair by query_start; compatibility then requires the
            # second to start strictly after the first ends on both sequences.
            if target_hsp.query_start <= test_hsp.query_start:
                first, second = target_hsp, test_hsp
            else:
                first, second = test_hsp, target_hsp
            overlap = second.query_start <= first.query_end
            overlap = overlap or second.subject_start <= first.subject_end
            if not overlap:
                compatible.append(index)
        result = tuple(compatible)
        self._compatible_memo[key] = result
        return result

    def _reconstruct_soln(self, indices):
        '''Return the indices of the best combination for ``indices``, ascending.'''
        return self._reconstruct_soln_r(indices, [])

    def _reconstruct_soln_r(self, indices, soln):
        '''
        Walk the decisions recorded in ``_combo_soln`` by _find_best_combo_r
        and return the chosen indices as a tuple.
        '''
        if len(indices) <= 1:
            # A single HSP is always taken (see _find_best_combo_r).
            return tuple(indices)
        indices = tuple(indices)
        if self._combo_soln[indices] == Path.without_self:
            return self._reconstruct_soln_r(indices[:-1], soln)
        else:
            last_index = indices[-1]
            compatible = self._find_compatible(last_index, indices)
            soln = self._reconstruct_soln_r(compatible, soln)
            return soln + (last_index,)

    def find_best_combo(self):
        '''
        Return ``(score, soln)``: the highest summed bitscore over mutually
        compatible HSPs, and the list of those HSPs in query_start order.
        Safe to call repeatedly on the same instance.
        '''
        indices = tuple(range(len(self._hsps)))
        # Reset both per-call tables together, so the memoised scores and the
        # recorded decisions can never disagree.
        self._best_values = {}
        self._combo_soln = {}
        score = self._find_best_combo_r(indices)
        soln = [self._hsps[i] for i in self._reconstruct_soln(indices)]
        return (score, soln)

    def _find_best_combo_r(self, indices):
        '''
        Return the best summed bitscore over compatible subsets of ``indices``
        (an ascending tuple), recording the branch taken in ``_combo_soln``.

        NOTE: recursion depth grows with len(indices), since each without-self
        step drops one index. Very large HSP groups may therefore exceed
        Python's recursion limit.
        '''
        # Must check for len == 0, which will occur if we're looking at only two
        # hsps that are incompatible with each other. When looking at the
        # second, the call self._find_best_combo_r(compatible) will recurse on an
        # empty tuple (since "compatible" doesn't include the self, and it won't
        # include the preceding (incompatible) hsp).
        if len(indices) == 0:
            return 0
        if len(indices) == 1:
            self._combo_soln[indices] = Path.with_self
            return self._hsps[indices[0]].bitscore
        if indices in self._best_values:
            return self._best_values[indices]
        last_index = indices[-1]
        without_self_indices = indices[:-1]
        # This order retains sort order, since indices[-1] will be after all other
        # elements in indices.
        compatible = self._find_compatible(last_index, indices)
        best_without_self = self._find_best_combo_r(without_self_indices)
        best_with_self = self._find_best_combo_r(compatible) + self._hsps[last_index].bitscore
        # Ties favour leaving the last HSP out.
        if best_without_self >= best_with_self:
            self._best_values[indices] = best_without_self
            self._combo_soln[indices] = Path.without_self
        else:
            self._best_values[indices] = best_with_self
            self._combo_soln[indices] = Path.with_self
        return self._best_values[indices]
# Immutable records built by BlastParser.

# One BLAST iteration: the query and the hits found for it.
Query = namedtuple('Query', 'query_def query_length hits')

# One subject sequence hit; BlastParser fills ``hsps`` as a mapping from a
# strand-pair key to the list of HSPs on that strand pair.
Hit = namedtuple('Hit', 'subject_length subject_def hsps')

# One high-scoring segment pair between query and subject.
Hsp = namedtuple(
    'Hsp',
    'query_start query_end query_frame '
    'subject_start subject_end subject_frame '
    'alignment_length bitscore',
)
class BlastParser(object):
    '''
    Parse BLAST XML output into a list of Query records.

    Each Query carries its Hit records. Each Hit's ``hsps`` maps a strand-pair
    key (query strand followed by subject strand, e.g. ``'++'`` or ``'+-'``)
    to the HSPs on that strand pair. For HSPs on a negative frame, the
    affected coordinates are rewritten so that start <= end, counted from the
    other end of the sequence.
    '''
    def _determine_strand(self, frame):
        '''Map a BLAST frame number to '+' (> 0), '-' (< 0) or '.' (0).'''
        if frame > 0:
            return '+'
        elif frame < 0:
            return '-'
        else:
            return '.'

    def _normalize_minus_strand(self, hsp, seq_type, length):
        '''
        Return ``hsp`` with its ``<seq_type>_start``/``<seq_type>_end`` pair
        rewritten for a negative frame. The pair is first ordered so that
        start <= end, then mapped to ``(length - end + 1, length - start + 1)``.
        ``seq_type`` is 'query' or 'subject'.
        '''
        start_key = '%s_start' % seq_type
        end_key = '%s_end' % seq_type
        start = getattr(hsp, start_key)
        end = getattr(hsp, end_key)
        # Minus-strand coordinates may arrive with start > end.
        if start > end:
            start, end = end, start
        return hsp._replace(**{
            start_key: length - end + 1,
            end_key: length - start + 1,
        })

    def parse(self, blast_xml):
        '''
        Parse ``blast_xml`` (anything accepted by xml.etree.ElementTree.parse,
        e.g. a filename or file object). Return a list of Query, one per
        ``Iteration`` element, in document order.

        Queries without an ``Iteration_hits`` element are still returned (with
        an empty ``hits`` list), so the nth returned Query always corresponds
        to the nth query in the file.
        '''
        root = ET.parse(blast_xml).getroot()
        queries = []
        for iteration_elem in root.find('BlastOutput_iterations').iter('Iteration'):
            query = Query(
                query_length=int(iteration_elem.find('Iteration_query-len').text),
                query_def=iteration_elem.find('Iteration_query-def').text,
                hits=[],
            )
            # Record the query before inspecting its hits. Skipping queries
            # with no hits would shift every later query's output line.
            queries.append(query)
            hit_elems = iteration_elem.find('Iteration_hits')
            if hit_elems is None:
                continue
            for hit_elem in hit_elems.iter('Hit'):
                hit = Hit(
                    subject_length=int(hit_elem.find('Hit_len').text),
                    subject_def=hit_elem.find('Hit_def').text,
                    hsps=defaultdict(list),
                )
                for hsp_elem in hit_elem.find('Hit_hsps').iter('Hsp'):
                    hsp = Hsp(
                        query_frame=int(hsp_elem.find('Hsp_query-frame').text),
                        subject_frame=int(hsp_elem.find('Hsp_hit-frame').text),
                        alignment_length=int(hsp_elem.find('Hsp_align-len').text),
                        query_start=int(hsp_elem.find('Hsp_query-from').text),
                        query_end=int(hsp_elem.find('Hsp_query-to').text),
                        subject_start=int(hsp_elem.find('Hsp_hit-from').text),
                        subject_end=int(hsp_elem.find('Hsp_hit-to').text),
                        bitscore=float(hsp_elem.find('Hsp_bit-score').text),
                    )
                    # Ensure start <= end on minus-strand sequences, whose raw
                    # coordinates may be reversed.
                    if hsp.query_frame < 0:
                        hsp = self._normalize_minus_strand(hsp, 'query', query.query_length)
                    if hsp.subject_frame < 0:
                        hsp = self._normalize_minus_strand(hsp, 'subject', hit.subject_length)
                    key = self._determine_strand(hsp.query_frame) + \
                        self._determine_strand(hsp.subject_frame)
                    hit.hsps[key].append(hsp)
                query.hits.append(hit)
        return queries
class Scorer(object):
    '''
    Base class for coverage scorers. Subclasses implement
    ``_calc_query_proportion(hsps, query_len)``. This class prints, for each
    query, the best proportion over all of its hits' strand-pair HSP groups.
    '''
    def print_query_scores(self, queries):
        '''
        Print one line per query: the largest value returned by
        ``_calc_query_proportion`` over the query's HSP groups, or 0 when there
        are no groups or none exceeds 0.
        '''
        for query in queries:
            proportions = (
                self._calc_query_proportion(hsp_group, query.query_length)
                for hit in query.hits
                for hsp_group in hit.hsps.values()
            )
            # max() replaces its running value only on a strictly greater item,
            # so seeding it with 0 matches "start at 0, update on >".
            print(max([0, *proportions]))
class CompatibleHspScorer(Scorer):
    '''
    Scores an HSP group by the fraction of the query covered by its
    highest-scoring set of mutually compatible HSPs, as chosen by SmartHspSet.
    '''
    def _calc_query_proportion(self, hsps, query_len):
        '''Return the summed query span of the chosen HSPs divided by query_len.'''
        # DumbHspSet is the brute-force (exponential-time) alternative, used
        # only for cross-checking.
        _, chosen = SmartHspSet(hsps).find_best_combo()
        covered = sum(hsp.query_end - hsp.query_start + 1 for hsp in chosen)
        return covered / query_len
class HspUnionScorer(Scorer):
    '''
    Scores an HSP group by the fraction of the query covered by the union of
    the HSPs' query intervals, whether or not the HSPs are compatible.
    '''
    def _calc_query_proportion(self, hsps, query_len):
        '''Return the number of query positions covered by ``hsps`` divided by query_len.'''
        intervals = [(hsp.query_start, hsp.query_end) for hsp in hsps]
        interval_union = self._compute_interval_union(intervals)
        return self._compute_interval_size_sum(interval_union) / query_len

    def _compute_interval_union(self, intervals):
        '''
        Compute the union of integer intervals whose endpoints are both
        inclusive. Return a sorted list of disjoint ``(start, end)`` tuples.

        Intervals that overlap, share an endpoint, or are immediately adjacent
        (e.g. (1, 15) and (16, 20)) are merged, since together they cover a
        contiguous run of positions. An empty input yields an empty list.
        '''
        merged = []
        for start, end in sorted(intervals):
            # With inclusive ends, a start of last_end + 1 is still contiguous.
            if merged and start <= merged[-1][1] + 1:
                last_start, last_end = merged[-1]
                merged[-1] = (last_start, max(last_end, end))
            else:
                merged.append((start, end))
        return merged

    def _compute_interval_size_sum(self, intervals):
        '''Return the total number of positions in inclusive ``(start, end)`` intervals.'''
        return sum(end - start + 1 for start, end in intervals)
class TestSmartHspSet(unittest.TestCase):
    '''Regression tests for SmartHspSet.find_best_combo.'''

    @staticmethod
    def _forward_hsp(query_start, query_end, subject_start, subject_end, bitscore):
        '''Build a '++' Hsp with the placeholder alignment length these tests use.'''
        return Hsp(
            query_start=query_start,
            query_end=query_end,
            subject_start=subject_start,
            subject_end=subject_end,
            bitscore=bitscore,
            subject_frame=1,
            query_frame=1,
            alignment_length=5,
        )

    def _assert_best(self, hsps, expected_score, expected_soln):
        '''Run SmartHspSet on ``hsps`` and check both the score and the chosen HSPs.'''
        score, soln = SmartHspSet(hsps).find_best_combo()
        self.assertAlmostEqual(expected_score, score)
        self.assertEqual(expected_soln, soln)

    def test_simple(self):
        # hsps[0] and hsps[1] overlap on the query, and hsps[2] crosses both
        # (earlier on the query, later on the subject), so only one survives.
        hsps = [
            self._forward_hsp(119, 249, 1157, 1287, 243.031),
            self._forward_hsp(244, 345, 845, 946, 189.479),
            self._forward_hsp(17, 118, 1901, 2002, 189.479),
        ]
        self._assert_best(hsps, 243.031, [hsps[0]])

    def test_simple_two(self):
        # The two HSPs cross, so the higher-scoring one is kept alone.
        hsps = [
            self._forward_hsp(346, 854, 9427, 9935, 470.169),
            self._forward_hsp(1, 339, 231024, 231362, 178.399),
        ]
        self._assert_best(hsps, 470.169, [hsps[0]])

    def _test_against_blast_xml(self, xml_filename, strand_pair, hsp_len, expected_score, expected_soln):
        '''
        Parse test-cases/<xml_filename>, which must contain exactly one query
        with one hit on a single strand pair, then check the best combination
        of that strand pair's HSPs.
        '''
        cases_dir = os.path.join(os.path.dirname(__file__), 'test-cases')
        queries = BlastParser().parse(os.path.join(cases_dir, xml_filename))
        self.assertEqual(1, len(queries))
        only_query = queries[0]
        self.assertEqual(1, len(only_query.hits))
        only_hit = only_query.hits[0]
        self.assertEqual(1, len(only_hit.hsps.keys()))
        group = only_hit.hsps[strand_pair]
        self.assertEqual(hsp_len, len(group))
        self._assert_best(group, expected_score, expected_soln)

    def test_reverse_complement(self):
        # NOTE(review): the expected coordinates span only a few bases although
        # alignment_length is 275 -- confirm against reverse_complement.blast.xml.
        expected = Hsp(
            query_start=5,
            query_end=8,
            query_frame=1,
            subject_start=22,
            subject_end=22,
            subject_frame=-1,
            alignment_length=275,
            bitscore=505.562,
        )
        self._test_against_blast_xml(
            xml_filename='reverse_complement.blast.xml',
            strand_pair='+-',
            hsp_len=2,
            expected_score=505.562,
            expected_soln=[expected],
        )

    def test_problematic(self):
        upstream = Hsp(
            query_start=1631,
            query_end=1949,
            query_frame=1,
            subject_start=1,
            subject_end=319,
            subject_frame=1,
            alignment_length=275,
            bitscore=573.582,
        )
        downstream = Hsp(
            query_start=2687,
            query_end=3406,
            query_frame=1,
            subject_start=373,
            subject_end=1075,
            subject_frame=1,
            alignment_length=275,
            bitscore=1125.73,
        )
        self._test_against_blast_xml(
            xml_filename='problematic.blast.xml',
            strand_pair='++',
            hsp_len=4,
            expected_score=1699.312,
            expected_soln=[upstream, downstream],
        )
class TestHspUnionScorer(unittest.TestCase):
    '''Tests for HspUnionScorer's interval-union helpers (inclusive endpoints).'''

    def test_basic(self):
        scorer = HspUnionScorer()
        # Each case: (input intervals, expected union, expected covered size).
        cases = [
            # Nested and overlapping intervals collapse; (19, 20) lies inside (19, 21).
            (
                [(1, 16), (7, 15), (5, 10), (13, 16), (19, 21), (19, 20)],
                [(1, 16), (19, 21)],
                19,
            ),
            # Intervals sharing an endpoint merge into one run.
            ([(1, 15), (15, 20)], [(1, 20)], 20),
            # Adjacent inclusive intervals (no gap between 15 and 16) merge.
            ([(1, 15), (16, 20)], [(1, 20)], 20),
            # Overlapping intervals merge.
            ([(1, 15), (14, 20)], [(1, 20)], 20),
        ]
        for intervals, want_union, want_size in cases:
            got_union = scorer._compute_interval_union(intervals)
            got_size = scorer._compute_interval_size_sum(got_union)
            self.assertEqual(want_union, got_union)
            self.assertEqual(want_size, got_size)
def test_smart_vs_dumb_algorithm(queries, decimal_places=7, max_hsps=20):
    '''
    Cross-check SmartHspSet against the brute-force DumbHspSet.

    For every strand-pair HSP group in ``queries`` with at most ``max_hsps``
    HSPs, compare the best scores from both solvers. The limit exists because
    the brute-force search is exponential: 2^n combinations. If the scores
    differ after rounding to ``decimal_places``, the offending HSPs are
    printed to stderr and an Exception is raised.

    Despite its name this is a helper driven from main(), not a test-runner
    test: it requires a ``queries`` argument.
    '''
    for query in queries:
        for hit in query.hits:
            for segregated_hsps in hit.hsps.values():
                if len(segregated_hsps) > max_hsps:
                    continue
                smart_score, _ = SmartHspSet(segregated_hsps).find_best_combo()
                dumb_score, _ = DumbHspSet(segregated_hsps).find_best_combo()
                if round(dumb_score - smart_score, decimal_places) != 0:
                    from pprint import pprint
                    pprint(segregated_hsps, stream=sys.stderr)
                    raise Exception('Scores unequal (smart=%s, dumb=%s)' % (smart_score, dumb_score))


# Keep pytest from collecting this helper as a test (it would fail looking
# for a 'queries' fixture).
test_smart_vs_dumb_algorithm.__test__ = False
def main(argv=None):
    '''
    Command-line entry point.

    ``--mode`` selects what to run:
      normal              read BLAST XML from stdin and print one coverage
                          value per query (default);
      compare_dp_vs_dumb  read BLAST XML from stdin and cross-check
                          SmartHspSet against DumbHspSet;
      test_cases          run this file's unittest test cases; any extra
                          arguments are passed on to unittest.
    ``--scorer`` selects the coverage measure used in normal mode:
    ``compatible`` (CompatibleHspScorer, default) or ``union``
    (HspUnionScorer).

    argv: arguments excluding the program name; defaults to sys.argv[1:].
    '''
    import argparse

    class RunType(Enum):
        normal = 1
        compare_dp_vs_dumb = 2
        test_cases = 3

    scorers = {
        'compatible': CompatibleHspScorer,
        'union': HspUnionScorer,
    }

    parser = argparse.ArgumentParser(
        description='For each query in BLAST XML read from stdin, report the '
                    'fraction of the query covered by its best hit.')
    parser.add_argument(
        '--mode', choices=[run.name for run in RunType],
        default=RunType.normal.name,
        help='what to run (default: %(default)s)')
    parser.add_argument(
        '--scorer', choices=sorted(scorers), default='compatible',
        help='coverage measure for normal mode (default: %(default)s)')
    args, remaining = parser.parse_known_args(argv)
    run_type = RunType[args.mode]
    # Only the unittest mode accepts pass-through arguments.
    if remaining and run_type != RunType.test_cases:
        parser.error('unrecognized arguments: %s' % ' '.join(remaining))

    if run_type == RunType.normal:
        queries = BlastParser().parse(sys.stdin)
        scorers[args.scorer]().print_query_scores(queries)
    elif run_type == RunType.compare_dp_vs_dumb:
        queries = BlastParser().parse(sys.stdin)
        test_smart_vs_dumb_algorithm(queries)
    elif run_type == RunType.test_cases:
        # Give unittest only the arguments argparse did not consume, so it
        # does not choke on --mode.
        unittest.main(argv=[sys.argv[0]] + remaining)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment