Created July 14, 2014 20:58
-
-
Save jwintersinger/d5cd37660dd878462823 to your computer and use it in GitHub Desktop.
Multiple methods for finding best set of mutually compatible BLAST hits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
''' | |
For each query listed in a BLAST XML results file, determine how much of the | |
query sequence is covered by the "best" hit. The best hit is deemed to be the | |
one in which the summed bitscores of its compatible HSPs is highest. Mutually | |
compatible HSPs are deemed to be ones that neither overlap nor "cross" each | |
other. | |
Usage: | |
cat blast_results.xml | find-best-hit.py | |
Output: | |
Values in range [0, 1], with one per line. The nth such value represents the | |
fraction of query n covered by the mutually compatible HSPs composing its | |
best hit. | |
''' | |
import sys | |
import xml.etree.ElementTree as ET | |
from collections import defaultdict, namedtuple | |
import json | |
from functools import lru_cache | |
from enum import Enum | |
import unittest | |
import os | |
# Which branch of SmartHspSet's recurrence produced the optimum for a
# subproblem: `without_self` means the best combo for an index tuple excludes
# the tuple's last HSP, `with_self` means it includes it.
Path = Enum('Path', [('without_self', 1), ('with_self', 2)])
class BaseHspSet(object):
    '''Common base for the HSP-combination solvers.

    Keeps its own list of the given HSPs ordered by ascending query_start
    (equal query_starts keep their input order); the caller's sequence is
    not modified.
    '''

    def __init__(self, hsps):
        ordered = list(hsps)
        ordered.sort(key=lambda hsp: hsp.query_start)
        self._hsps = ordered
class DumbHspSet(BaseHspSet):
    '''Brute-force solver: enumerate every mutually compatible subset of HSPs.

    Two HSPs are compatible when, in query_start order, the later one starts
    strictly after the earlier one ends on both the query and the subject.
    Enumeration visits up to 2^n subsets for n HSPs, so this is only useful as
    a reference implementation for cross-checking SmartHspSet on small inputs.
    '''

    def find_best_combo(self):
        '''Return (score, combo) for the highest-scoring compatible combo.

        `score` is the summed bitscore of `combo`, a list of HSPs in
        query_start order. On ties the first combo enumerated wins. When no
        combo has a positive score (e.g. there are no HSPs) the result is
        (0, []), the same shape SmartHspSet returns for that input.
        '''
        highest_score = 0
        # Start from the empty combo rather than None so callers can always
        # iterate the returned combo.
        best_combo = []
        for valid_combo in self._compute_hsp_combos(self._hsps):
            score = sum(hsp.bitscore for hsp in valid_combo)
            if score > highest_score:
                highest_score = score
                best_combo = valid_combo
        return (highest_score, best_combo)

    def _compute_hsp_combos(self, L):
        '''Return an iterator over every mutually compatible subset of `L`.'''
        # _combo_r only compares a candidate with the last HSP already chosen,
        # which is valid only when candidates arrive in query_start order.
        sorted_hsps = sorted(L, key=lambda a: a.query_start)
        return self._combo_r([], sorted_hsps)

    def _combo_r(self, base, remaining):
        '''Yield every compatible extension of `base` using HSPs from `remaining`.

        `base` is a compatible combo and `remaining` is in query_start order.
        For each candidate, subsets that include it are yielded before subsets
        that skip it. Yielding (instead of building one big list) keeps only
        O(n) combos alive at a time.
        '''
        if not remaining:
            yield base
            return
        first = remaining[0]
        rest = remaining[1:]
        if base:
            base_query_end = base[-1].query_end
            base_subject_end = base[-1].subject_end
        else:
            base_query_end = 0
            base_subject_end = 0
        # Comparing against base[-1] alone suffices: every accepted HSP starts
        # after the previous one ends, so (assuming start <= end for each HSP)
        # base[-1] has the largest query and subject ends in `base`.
        if first.query_start > base_query_end and first.subject_start > base_subject_end:
            yield from self._combo_r(base + [first], rest)
        yield from self._combo_r(base, rest)
class SmartHspSet(BaseHspSet):
    '''Find the highest-scoring set of mutually compatible HSPs.

    Two HSPs are compatible when, taking the one with the smaller query_start
    as "first", the other starts strictly after the first ends on both the
    query and the subject, i.e. they neither overlap nor cross. Subproblems
    are tuples of indices into self._hsps (ordered by query_start) and are
    solved by memoized recursion.

    Memoization uses per-instance dicts rather than functools.lru_cache on the
    methods. A method-level lru_cache keeps every instance alive for the life
    of the process. Its cache hits would also skip repopulating _combo_soln
    when find_best_combo() is called again on the same instance.
    '''

    def __init__(self, hsps):
        super().__init__(hsps)
        self._reset_memo()

    def _reset_memo(self):
        '''Discard all memoized results.'''
        # indices tuple -> best achievable summed bitscore for that subproblem.
        self._best_values = {}
        # indices tuple -> Path recording whether its optimum uses its last HSP.
        self._combo_soln = {}
        # (target, indices) -> tuple of indices compatible with `target`.
        self._compatible_cache = {}

    # Segment HSPs into smaller, independent sets. No performance benefit appears
    # in practice -- using segment takes almost exactly as long as not using it
    # -- and so don't bother. (find_best_combo() does not call it.)
    def _segment(self):
        '''
        Return partition of self._hsps, such that no element (i.e., subset of
        self._hsps) contains an HSP that overlaps another HSP from a different
        element. This is useful because the problem of finding the highest-scoring
        subset of HSPs can then be solved independently for each subset forming the
        partition without impairing the answer's optimality.
        '''
        segments = []
        i = 0
        n = len(self._hsps)
        while i < n:
            candidate = self._hsps[i]
            max_query_end = candidate.query_end
            max_subject_end = candidate.subject_end
            conflicting_query_end = candidate.query_end
            conflicting_subject_end = candidate.subject_end
            last_conflicting = i
            for j in range(i + 1, n):
                other = self._hsps[j]
                max_query_end = max(max_query_end, other.query_end)
                max_subject_end = max(max_subject_end, other.subject_end)
                if other.query_start <= conflicting_query_end \
                        or other.subject_start <= conflicting_subject_end:
                    conflicting_query_end = max_query_end
                    conflicting_subject_end = max_subject_end
                    last_conflicting = j
            next_compatible = last_conflicting + 1
            segments.append(tuple(range(i, next_compatible)))
            i = next_compatible
        return segments

    def _find_compatible(self, target, indices):
        '''Return the indices in the tuple `indices` (excluding `target`) whose
        HSPs are compatible with HSP `target`, in their original order.

        The returned HSPs are guaranteed compatible with `target` only, not
        necessarily with one another. Results are memoized per instance.
        '''
        key = (target, indices)
        cached = self._compatible_cache.get(key)
        if cached is not None:
            return cached
        target_hsp = self._hsps[target]
        compatible = []
        for index in indices:
            # Don't define target as overlapping itself.
            if index == target:
                continue
            test_hsp = self._hsps[index]
            if target_hsp.query_start <= test_hsp.query_start:
                first, second = target_hsp, test_hsp
            else:
                first, second = test_hsp, target_hsp
            # Overlap on either axis, or `second` lying before `first` on the
            # subject (a "cross"), both trip this check.
            overlap = (second.query_start <= first.query_end
                       or second.subject_start <= first.subject_end)
            if not overlap:
                compatible.append(index)
        result = tuple(compatible)
        self._compatible_cache[key] = result
        return result

    def _reconstruct_soln(self, indices):
        '''Return the tuple of chosen indices for the subproblem `indices`.'''
        return self._reconstruct_soln_r(indices, [])

    def _reconstruct_soln_r(self, indices, soln):
        '''Walk the Path decisions recorded by _find_best_combo_r for
        `indices` and return the chosen indices in ascending order.

        Must run after _find_best_combo_r has solved `indices`. `soln` is
        unused and kept only for signature compatibility. The walk is
        iterative, so it does not recurse once per HSP.
        '''
        indices = tuple(indices)
        chosen = []
        while len(indices) > 1:
            if self._combo_soln[indices] == Path.without_self:
                indices = indices[:-1]
            else:
                last_index = indices[-1]
                chosen.append(last_index)
                indices = self._find_compatible(last_index, indices)
        # A subproblem of zero or one HSP always takes all of its HSPs.
        chosen.extend(indices)
        # Indices were collected from largest to smallest.
        return tuple(reversed(chosen))

    def find_best_combo(self):
        '''Return (score, combo) for the best mutually compatible HSP combo.

        `score` is the summed bitscore of `combo`, a list of HSPs in
        query_start order; with no HSPs the result is (0, []). On ties the
        combo that excludes the later HSP is preferred. Safe to call more
        than once on the same instance.
        '''
        indices = tuple(range(len(self._hsps)))
        self._reset_memo()
        score = self._find_best_combo_r(indices)
        soln = [self._hsps[i] for i in self._reconstruct_soln(indices)]
        return (score, soln)

    def _find_best_combo_r(self, indices):
        '''Return the best summed bitscore achievable using only the HSPs in
        the index tuple `indices`, recording each decision in _combo_soln.

        NOTE(review): recursion depth grows with len(indices) (via the
        indices[:-1] chain), so very large HSP groups may hit Python's
        recursion limit.
        '''
        memo = self._best_values
        if indices in memo:
            return memo[indices]
        # len == 0 occurs when the last HSP is compatible with nothing before
        # it: its `compatible` tuple is empty.
        if len(indices) == 0:
            best = 0
        elif len(indices) == 1:
            self._combo_soln[indices] = Path.with_self
            best = self._hsps[indices[0]].bitscore
        else:
            last_index = indices[-1]
            # `compatible` keeps sort order, since indices[-1] comes after all
            # other elements of indices.
            compatible = self._find_compatible(last_index, indices)
            best_without_self = self._find_best_combo_r(indices[:-1])
            best_with_self = (self._find_best_combo_r(compatible)
                              + self._hsps[last_index].bitscore)
            if best_without_self >= best_with_self:
                best = best_without_self
                self._combo_soln[indices] = Path.without_self
            else:
                best = best_with_self
                self._combo_soln[indices] = Path.with_self
        memo[indices] = best
        return best
# Record types built by BlastParser: a Query owns a list of Hits, and each
# Hit owns its HSPs grouped by strand pair.
Query = namedtuple('Query', 'query_def query_length hits')
Hit = namedtuple('Hit', 'subject_length subject_def hsps')
Hsp = namedtuple(
    'Hsp',
    'query_start query_end query_frame '
    'subject_start subject_end subject_frame '
    'alignment_length bitscore',
)
class BlastParser(object):
    '''Parse BLAST XML results (BlastOutput/Iteration/Hit/Hsp elements) into
    Query records.'''

    def _determine_strand(self, frame):
        '''Map a frame number to '+' (positive), '-' (negative) or '.' (zero).'''
        if frame > 0:
            return '+'
        elif frame < 0:
            return '-'
        else:
            return '.'

    def parse(self, blast_xml):
        '''Parse `blast_xml` (a filename or file object, as accepted by
        ET.parse) and return a list of Query records in document order.

        Every Iteration element yields exactly one Query. This includes
        iterations with no Iteration_hits element, which get an empty `hits`
        list, so the nth Query always corresponds to the nth query searched.
        '''
        # NOTE: xml.etree.ElementTree is not hardened against maliciously
        # constructed XML; only parse BLAST output from a trusted source.
        root = ET.parse(blast_xml).getroot()
        queries = []
        for iteration_elem in root.find('BlastOutput_iterations').iter('Iteration'):
            query = Query(
                query_length=int(iteration_elem.find('Iteration_query-len').text),
                query_def=iteration_elem.find('Iteration_query-def').text,
                hits=[],
            )
            # Append before looking at hits so that hitless queries are kept
            # and downstream output stays aligned with the input queries.
            queries.append(query)
            hit_elems = iteration_elem.find('Iteration_hits')
            if hit_elems is None:
                continue
            for hit_elem in hit_elems.iter('Hit'):
                query.hits.append(self._parse_hit(hit_elem, query.query_length))
        return queries

    def _parse_hit(self, hit_elem, query_length):
        '''Build a Hit whose `hsps` maps a strand-pair key (query strand then
        subject strand, e.g. '+-') to the list of HSPs on those strands.'''
        hit = Hit(
            subject_length=int(hit_elem.find('Hit_len').text),
            subject_def=hit_elem.find('Hit_def').text,
            hsps=defaultdict(list),
        )
        for hsp_elem in hit_elem.find('Hit_hsps').iter('Hsp'):
            hsp = self._normalize_minus_strand(
                self._parse_hsp(hsp_elem), query_length, hit.subject_length)
            key = self._determine_strand(hsp.query_frame) + \
                self._determine_strand(hsp.subject_frame)
            hit.hsps[key].append(hsp)
        return hit

    def _parse_hsp(self, hsp_elem):
        '''Build an Hsp from an Hsp element, with coordinates as reported.'''
        return Hsp(
            query_frame=int(hsp_elem.find('Hsp_query-frame').text),
            subject_frame=int(hsp_elem.find('Hsp_hit-frame').text),
            alignment_length=int(hsp_elem.find('Hsp_align-len').text),
            query_start=int(hsp_elem.find('Hsp_query-from').text),
            query_end=int(hsp_elem.find('Hsp_query-to').text),
            subject_start=int(hsp_elem.find('Hsp_hit-from').text),
            subject_end=int(hsp_elem.find('Hsp_hit-to').text),
            bitscore=float(hsp_elem.find('Hsp_bit-score').text),
        )

    def _normalize_minus_strand(self, hsp, query_length, subject_length):
        '''Return `hsp` with negative-frame coordinates re-expressed so that
        start <= end.

        For each of query/subject whose frame is negative, the coordinates
        are first put in ascending order and then mapped through
        x -> length - x + 1 (swapping the roles of start and end). This
        presumably puts them on the reverse-complement strand so that the
        overlap/crossing checks treat every strand pair alike.
        '''
        lengths = {'query': query_length, 'subject': subject_length}
        for seq_type in ('query', 'subject'):
            if getattr(hsp, '%s_frame' % seq_type) >= 0:
                continue
            start_key = '%s_start' % seq_type
            end_key = '%s_end' % seq_type
            start, end = sorted((getattr(hsp, start_key), getattr(hsp, end_key)))
            length = lengths[seq_type]
            hsp = hsp._replace(**{
                start_key: length - end + 1,
                end_key: length - start + 1,
            })
        return hsp
class Scorer(object):
    '''Base class for per-query coverage scorers.

    Subclasses supply _calc_query_proportion(hsps, query_len), which returns
    the fraction of the query covered by one strand-pair group of HSPs.
    '''

    def print_query_scores(self, queries):
        '''Print one line per query: the best proportion across all of its
        hits and strand-pair groups, or 0 if it has none.'''
        for query in queries:
            proportions = [
                self._calc_query_proportion(group, query.query_length)
                for hit in query.hits
                for group in hit.hsps.values()
            ]
            # The 0 seed yields "0" for queries without HSPs. max() keeps the
            # first of several equal values, so the seed also wins against 0.0.
            print(max([0] + proportions))
class CompatibleHspScorer(Scorer):
    '''Score a strand-pair group by how much of the query its best mutually
    compatible HSP combination (as chosen by SmartHspSet) covers.'''

    def _calc_query_proportion(self, hsps, query_len):
        '''Return the summed query span of the best combo over query_len.'''
        _score, best_combo = SmartHspSet(hsps).find_best_combo()
        # Compatible HSPs never overlap on the query, so adding their
        # inclusive spans gives the covered length directly.
        covered = 0
        for hsp in best_combo:
            covered += hsp.query_end - hsp.query_start + 1
        return covered / query_len
class HspUnionScorer(Scorer):
    '''Score a strand-pair group by the fraction of the query covered by the
    union of all of its HSPs' query intervals (compatibility is ignored).'''

    def _calc_query_proportion(self, hsps, query_len):
        '''Return the size of the union of the HSPs' query intervals divided
        by query_len.'''
        intervals = [(hsp.query_start, hsp.query_end) for hsp in hsps]
        interval_union = self._compute_interval_union(intervals)
        size_sum = self._compute_interval_size_sum(interval_union)
        return size_sum / query_len

    def _compute_interval_union(self, intervals):
        '''
        Compute union of overlapping intervals. Assume both endpoints are inclusive.

        Returns a list of disjoint (start, end) tuples sorted by start.
        Intervals that overlap, share an endpoint, or are immediately adjacent
        (e.g. (1, 15) and (16, 20)) are merged, because together they cover a
        contiguous run of integer positions. An empty input yields [].
        '''
        merged = []
        for start, end in sorted(intervals):
            # `start` is never before the previous interval's start, so it
            # only needs comparing with the current merged interval's end.
            if merged and start <= merged[-1][1] + 1:
                last_start, last_end = merged[-1]
                merged[-1] = (last_start, max(last_end, end))
            else:
                merged.append((start, end))
        return merged

    def _compute_interval_size_sum(self, intervals):
        '''Return the total number of positions covered by inclusive intervals.'''
        return sum(end - start + 1 for start, end in intervals)
class TestSmartHspSet(unittest.TestCase):
    '''Regression tests for SmartHspSet.find_best_combo.'''

    @staticmethod
    def _forward_hsp(query_start, query_end, subject_start, subject_end, bitscore):
        '''Build a plus/plus-strand Hsp with a placeholder alignment length.'''
        return Hsp(
            query_start=query_start,
            query_end=query_end,
            subject_start=subject_start,
            subject_end=subject_end,
            bitscore=bitscore,
            subject_frame=1,
            query_frame=1,
            alignment_length=5,
        )

    def _assert_best(self, hsps, expected_score, expected_soln):
        '''Solve `hsps` and check the returned score and combo.'''
        score, soln = SmartHspSet(hsps).find_best_combo()
        self.assertAlmostEqual(expected_score, score)
        self.assertEqual(expected_soln, soln)

    def test_simple(self):
        hsps = [
            self._forward_hsp(119, 249, 1157, 1287, 243.031),
            self._forward_hsp(244, 345, 845, 946, 189.479),
            self._forward_hsp(17, 118, 1901, 2002, 189.479),
        ]
        self._assert_best(hsps, 243.031, [hsps[0]])

    def test_simple_two(self):
        hsps = [
            self._forward_hsp(346, 854, 9427, 9935, 470.169),
            self._forward_hsp(1, 339, 231024, 231362, 178.399),
        ]
        self._assert_best(hsps, 470.169, [hsps[0]])

    def _test_against_blast_xml(self, xml_filename, strand_pair, hsp_len, expected_score, expected_soln):
        '''Parse test-cases/<xml_filename> (next to this file), expecting one
        query with one hit and a single strand-pair group, then solve it.'''
        fixture_dir = os.path.join(os.path.dirname(__file__), 'test-cases')
        queries = BlastParser().parse(os.path.join(fixture_dir, xml_filename))
        self.assertEqual(1, len(queries))
        hits = queries[0].hits
        self.assertEqual(1, len(hits))
        groups = hits[0].hsps
        self.assertEqual(1, len(groups.keys()))
        hsps = groups[strand_pair]
        self.assertEqual(hsp_len, len(hsps))
        self._assert_best(hsps, expected_score, expected_soln)

    def test_reverse_complement(self):
        expected = Hsp(
            query_start=5, query_end=8, query_frame=1,
            subject_start=22, subject_end=22, subject_frame=-1,
            alignment_length=275, bitscore=505.562,
        )
        self._test_against_blast_xml(
            xml_filename='reverse_complement.blast.xml',
            strand_pair='+-',
            hsp_len=2,
            expected_score=505.562,
            expected_soln=[expected],
        )

    def test_problematic(self):
        expected = [
            Hsp(
                query_start=1631, query_end=1949, query_frame=1,
                subject_start=1, subject_end=319, subject_frame=1,
                alignment_length=275, bitscore=573.582,
            ),
            Hsp(
                query_start=2687, query_end=3406, query_frame=1,
                subject_start=373, subject_end=1075, subject_frame=1,
                alignment_length=275, bitscore=1125.73,
            ),
        ]
        self._test_against_blast_xml(
            xml_filename='problematic.blast.xml',
            strand_pair='++',
            hsp_len=4,
            expected_score=1699.312,
            expected_soln=expected,
        )
class TestHspUnionScorer(unittest.TestCase):
    '''Tests for HspUnionScorer's interval-union and size helpers.'''

    def test_basic(self):
        # Each case: (input intervals, expected union, expected covered size).
        # All endpoints are inclusive, so (a, b) covers b - a + 1 positions.
        cases = (
            # Nested and overlapping intervals collapse; disjoint ones stay apart.
            (
                [(1, 16), (7, 15), (5, 10), (13, 16), (19, 21), (19, 20)],
                [(1, 16), (19, 21)],
                19
            ),
            # Intervals sharing an endpoint are merged.
            (
                [(1, 15), (15, 20)],
                [(1, 20)],
                20
            ),
            # Immediately adjacent intervals (no position between 15 and 16)
            # are merged.
            (
                [(1, 15), (16, 20)],
                [(1, 20)],
                20
            ),
            # Ensure partially overlapping intervals are handled properly.
            (
                [(1, 15), (14, 20)],
                [(1, 20)],
                20
            ),
        )
        scorer = HspUnionScorer()
        for intervals, expected_union, expected_size_sum in cases:
            # subTest reports each failing case separately instead of
            # stopping at the first one.
            with self.subTest(intervals=intervals):
                result_union = scorer._compute_interval_union(intervals)
                result_size_sum = scorer._compute_interval_size_sum(result_union)
                self.assertEqual(expected_union, result_union)
                self.assertEqual(expected_size_sum, result_size_sum)
def test_smart_vs_dumb_algorithm(queries):
    '''Cross-check SmartHspSet against the brute-force DumbHspSet.

    For every strand-pair HSP group of every hit in `queries` with at most 20
    HSPs, compare the best scores the two solvers find (to 7 decimal places).
    On a mismatch, pretty-print the group to stderr and raise Exception.
    '''
    decimal_places = 7
    for query in queries:
        for hit in query.hits:
            for segregated_hsps in hit.hsps.values():
                # Dumb algorithm takes exponential time (given 2^n combinations),
                # so avoid computing more than 2^20 combos.
                if len(segregated_hsps) > 20:
                    continue
                smart_hsp_set = SmartHspSet(segregated_hsps)
                dumb_hsp_set = DumbHspSet(segregated_hsps)
                # Use the SmartHspSet built above (the previous code referenced
                # an undefined `hsp_set` and raised NameError here).
                smart_score, _smart_soln = smart_hsp_set.find_best_combo()
                dumb_score, _dumb_soln = dumb_hsp_set.find_best_combo()
                if round(dumb_score - smart_score, decimal_places) != 0:
                    from pprint import pprint
                    pprint(segregated_hsps, stream=sys.stderr)
                    raise Exception('Scores unequal (smart=%s, dumb=%s)' % (smart_score, dumb_score))
def main():
    '''Script entry point; pick the mode by editing `run_type` below.

    normal             -- read BLAST XML from stdin and print, per query, the
                          coverage of its best compatible HSP combination.
    compare_dp_vs_dumb -- read BLAST XML from stdin and cross-check
                          SmartHspSet against DumbHspSet.
    test_cases         -- run this file's unit tests.
    '''
    class RunType(Enum):
        normal = 1
        compare_dp_vs_dumb = 2
        test_cases = 3

    run_type = RunType.normal

    def score_stdin():
        queries = BlastParser().parse(sys.stdin)
        CompatibleHspScorer().print_query_scores(queries)

    def compare_stdin():
        queries = BlastParser().parse(sys.stdin)
        test_smart_vs_dumb_algorithm(queries)

    dispatch = {
        RunType.normal: score_stdin,
        RunType.compare_dp_vs_dumb: compare_stdin,
        RunType.test_cases: unittest.main,
    }
    dispatch[run_type]()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.