Skip to content

Instantly share code, notes, and snippets.

@jwintersinger
Created July 14, 2014 20:58
Show Gist options
  • Save jwintersinger/d5cd37660dd878462823 to your computer and use it in GitHub Desktop.
Save jwintersinger/d5cd37660dd878462823 to your computer and use it in GitHub Desktop.
Multiple methods for finding best set of mutually compatible BLAST hits
#!/usr/bin/env python3
'''
For each query listed in a BLAST XML results file, determine how much of the
query sequence is covered by the "best" hit. The best hit is deemed to be the
one in which the summed bitscores of its compatible HSPs is highest. Mutually
compatible HSPs are deemed to be ones that neither overlap nor "cross" each
other.
Usage:
cat blast_results.xml | find-best-hit.py
Output:
Values in range [0, 1], with one per line. The nth such value represents the
fraction of query n covered by the mutually compatible HSPs composing its
best hit.
'''
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict, namedtuple
import json
from functools import lru_cache
from enum import Enum
import unittest
import os
class Path(Enum):
    '''
    Records which branch SmartHspSet's dynamic program chose for an index
    tuple when computing its best score.
    '''
    # The best score for the tuple is obtained by dropping its last HSP.
    without_self = 1
    # The best score for the tuple includes its last HSP.
    with_self = 2
class BaseHspSet(object):
    '''
    Common base for the HSP-set solvers. Keeps its own list of the given HSPs,
    ordered by ascending query_start (ties keep their input order).
    '''
    def __init__(self, hsps):
        def by_query_start(hsp):
            return hsp.query_start
        self._hsps = sorted(hsps, key=by_query_start)
class DumbHspSet(BaseHspSet):
    '''
    Brute-force reference solver. Enumerates every subset of HSPs that is
    mutually compatible and returns the highest-scoring one. In a compatible
    subset, each HSP starts after the previous one ends on both the query and
    the subject. Runtime is exponential in the number of HSPs, so this is only
    meant for cross-checking SmartHspSet on small inputs.
    '''
    def find_best_combo(self):
        '''
        Return (highest_score, best_combo).

        best_combo is the list of HSPs (ascending query_start) in the
        highest-scoring compatible subset. If no subset scores above 0 (e.g.
        there are no HSPs), (0, []) is returned, matching SmartHspSet rather
        than returning None for the combo.
        '''
        highest_score = 0
        best_combo = []
        for valid_combo in self._compute_hsp_combos(self._hsps):
            score = sum(hsp.bitscore for hsp in valid_combo)
            # Strict '>' keeps the first-enumerated combo among equal scores.
            if score > highest_score:
                highest_score = score
                best_combo = valid_combo
        return (highest_score, best_combo)

    def _compute_hsp_combos(self, L):
        '''Lazily yield every mutually compatible subset of `L` as a list.'''
        sorted_hsps = sorted(L, key=lambda hsp: hsp.query_start)
        return self._combo_r([], sorted_hsps)

    def _combo_r(self, base, remaining):
        '''
        Yield `base` extended by every compatible selection from `remaining`.

        `remaining` must be sorted by query_start. Each HSP added to `base`
        starts after the previous one ends, so (given start <= end) the last
        HSP of `base` carries the largest query_end and subject_end. Only that
        HSP needs checking. Combos that include `remaining[0]` are yielded
        before those that skip it, the same order the original list-building
        version produced.
        '''
        if not remaining:
            yield base
            return
        first, rest = remaining[0], remaining[1:]
        if base:
            base_query_end = base[-1].query_end
            base_subject_end = base[-1].subject_end
        else:
            # Coordinates are 1-based, so any real HSP starts after 0.
            base_query_end = base_subject_end = 0
        if first.query_start > base_query_end and first.subject_start > base_subject_end:
            yield from self._combo_r(base + [first], rest)
        yield from self._combo_r(base, rest)
class SmartHspSet(BaseHspSet):
    '''
    Dynamic-programming solver for the highest-scoring set of mutually
    compatible HSPs (no two overlap or cross on the query or the subject).

    Sub-problems are tuples of indices into self._hsps, which BaseHspSet keeps
    sorted by query_start. For a tuple, the best score either excludes its
    last HSP (recurse on the tuple without it) or includes it (recurse on the
    HSPs compatible with it and add its bitscore). Results are memoized in
    per-instance dictionaries. A class-level functools.lru_cache on the methods
    would keep every instance alive indefinitely, and would make a second
    find_best_combo() call skip repopulating the branch table.
    '''
    def __init__(self, hsps):
        super().__init__(hsps)
        self._reset_memos()

    def _reset_memos(self):
        '''Clear every per-instance memo table.'''
        # indices tuple -> best score (only tuples of length >= 2 are stored).
        self._best_values = {}
        # indices tuple -> Path branch that produced the best score.
        self._combo_soln = {}
        # (target, indices) -> tuple returned by _find_compatible.
        self._compatible_memo = {}

    # Segment HSPs into smaller, independent sets. No performance benefit appears
    # in practice -- using segment takes almost exactly as long as not using it
    # -- and so don't bother.
    def _segment(self):
        '''
        Return partition of self._hsps, such that no element (i.e., subset of
        self._hsps) contains an HSP that overlaps another HSP from a different
        element. This is useful because the problem of finding the highest-scoring
        subset of HSPs can then be solved independently for each subset forming the
        partition without impairing the answer's optimality.
        '''
        segments = []
        i = 0
        n = len(self._hsps)
        while i < n:
            candidate = self._hsps[i]
            max_query_end = candidate.query_end
            max_subject_end = candidate.subject_end
            conflicting_query_end = candidate.query_end
            conflicting_subject_end = candidate.subject_end
            last_conflicting = i
            for j in range(i + 1, n):
                other = self._hsps[j]
                max_query_end = max(max_query_end, other.query_end)
                max_subject_end = max(max_subject_end, other.subject_end)
                if other.query_start <= conflicting_query_end \
                        or other.subject_start <= conflicting_subject_end:
                    conflicting_query_end = max_query_end
                    conflicting_subject_end = max_subject_end
                    last_conflicting = j
            next_compatible = last_conflicting + 1
            segments.append(tuple(range(i, next_compatible)))
            i = next_compatible
        return segments

    def _find_compatible(self, target, indices):
        '''
        Return the indices in `indices` (a tuple, order preserved) whose HSPs
        neither overlap nor cross the HSP at `target`; `target` itself is
        excluded. The returned HSPs are compatible with `target` but not
        necessarily with each other.
        '''
        key = (target, indices)
        memo = self._compatible_memo
        if key in memo:
            return memo[key]
        compatible = []
        target_hsp = self._hsps[target]
        for index in indices:
            # Don't define target as overlapping itself.
            if index == target:
                continue
            test_hsp = self._hsps[index]
            # Order the pair by query_start. They conflict if the later one
            # starts before the earlier one ends on the query (overlap) or on
            # the subject (overlap or crossing).
            if target_hsp.query_start <= test_hsp.query_start:
                first, second = target_hsp, test_hsp
            else:
                first, second = test_hsp, target_hsp
            overlap = (second.query_start <= first.query_end
                       or second.subject_start <= first.subject_end)
            if not overlap:
                compatible.append(index)
        result = tuple(compatible)
        memo[key] = result
        return result

    def _reconstruct_soln(self, indices):
        '''
        Follow the branches recorded in self._combo_soln (filled in by
        _find_best_combo_r for `indices` and its sub-problems). Return the
        chosen HSP indices as a tuple in ascending order. Iterative, so long
        HSP lists cannot hit the recursion limit here.
        '''
        indices = tuple(indices)
        chosen_desc = []
        while len(indices) > 1:
            if self._combo_soln[indices] == Path.without_self:
                indices = indices[:-1]
            else:
                last_index = indices[-1]
                chosen_desc.append(last_index)
                indices = self._find_compatible(last_index, indices)
        # A single remaining index is always chosen: _find_best_combo_r scores
        # a one-element tuple as that HSP's bitscore (Path.with_self).
        return indices + tuple(reversed(chosen_desc))

    def find_best_combo(self):
        '''
        Return (score, soln). score is the highest achievable sum of bitscores
        over a mutually compatible subset. soln is the list of HSPs achieving
        it, in ascending query_start order. Memo tables are rebuilt on every
        call, so calling this repeatedly is safe.
        '''
        indices = tuple(range(len(self._hsps)))
        self._reset_memos()
        score = self._find_best_combo_r(indices)
        soln = [self._hsps[i] for i in self._reconstruct_soln(indices)]
        return (score, soln)

    def _find_best_combo_r(self, indices):
        '''
        Return the best achievable score using only the HSPs in `indices` (a
        tuple in ascending order). Also record the chosen branch in
        self._combo_soln.

        NOTE(review): the "without self" branch recurses on indices[:-1], so
        recursion depth grows with the number of HSPs in the group.
        '''
        # An empty tuple arises when the last HSP is compatible with nothing
        # else in `indices` (e.g. two mutually incompatible HSPs).
        if len(indices) == 0:
            return 0
        if len(indices) == 1:
            self._combo_soln[indices] = Path.with_self
            return self._hsps[indices[0]].bitscore
        if indices in self._best_values:
            return self._best_values[indices]
        last_index = indices[-1]
        # Compatible indices keep the ascending order of `indices`.
        compatible = self._find_compatible(last_index, indices)
        best_without_self = self._find_best_combo_r(indices[:-1])
        best_with_self = (self._find_best_combo_r(compatible)
                          + self._hsps[last_index].bitscore)
        # Ties prefer excluding the last HSP.
        if best_without_self >= best_with_self:
            self._best_values[indices] = best_without_self
            self._combo_soln[indices] = Path.without_self
        else:
            self._best_values[indices] = best_with_self
            self._combo_soln[indices] = Path.with_self
        return self._best_values[indices]
# Records produced by BlastParser.parse.

# One BLAST query (an <Iteration>) and the list of its Hit records.
Query = namedtuple('Query', 'query_def query_length hits')

# One subject sequence hit; `hsps` maps a strand pair such as '++' or '+-'
# to the list of Hsp records on that strand pair.
Hit = namedtuple('Hit', 'subject_length subject_def hsps')

# One high-scoring segment pair, with 1-based coordinates.
Hsp = namedtuple(
    'Hsp',
    'query_start query_end query_frame '
    'subject_start subject_end subject_frame '
    'alignment_length bitscore',
)
class BlastParser(object):
    '''
    Parse BLAST XML output into Query / Hit / Hsp namedtuples.

    Each hit's HSPs are grouped by strand pair (e.g. '++', '+-'). HSPs on a
    negative frame have their coordinates for that sequence mapped onto the
    reverse-complemented sequence, so that start <= end always holds.
    '''
    def _determine_strand(self, frame):
        '''Map a BLAST frame number to '+' (positive), '-' (negative) or '.' (zero).'''
        if frame > 0:
            return '+'
        if frame < 0:
            return '-'
        return '.'

    def _reverse_complement_coords(self, start, end, length):
        '''
        Convert a (start, end) pair on a sequence of `length` into the
        coordinates of the same span on the reverse complement. The endpoints
        may be given in either order; the result has start <= end.
        '''
        if start > end:
            start, end = end, start
        return (length - end + 1, length - start + 1)

    def _parse_hsp(self, hsp_elem, query_length, subject_length):
        '''Build an Hsp from an <Hsp> element, normalizing negative-frame coordinates.'''
        def field(tag):
            return hsp_elem.find(tag).text

        hsp = Hsp(
            query_frame=int(field('Hsp_query-frame')),
            subject_frame=int(field('Hsp_hit-frame')),
            alignment_length=int(field('Hsp_align-len')),
            query_start=int(field('Hsp_query-from')),
            query_end=int(field('Hsp_query-to')),
            subject_start=int(field('Hsp_hit-from')),
            subject_end=int(field('Hsp_hit-to')),
            bitscore=float(field('Hsp_bit-score')),
        )
        # Ensure start <= end on each sequence (not the case for minus-strand
        # coordinates as BLAST reports them).
        if hsp.query_frame < 0:
            start, end = self._reverse_complement_coords(
                hsp.query_start, hsp.query_end, query_length)
            hsp = hsp._replace(query_start=start, query_end=end)
        if hsp.subject_frame < 0:
            start, end = self._reverse_complement_coords(
                hsp.subject_start, hsp.subject_end, subject_length)
            hsp = hsp._replace(subject_start=start, subject_end=end)
        return hsp

    def parse(self, blast_xml):
        '''
        Parse `blast_xml` (a filename or file object accepted by
        xml.etree.ElementTree.parse) and return a list of Query records, one
        per <Iteration>, in file order.

        Every query is returned, even one without an <Iteration_hits> element
        (its `hits` is empty). Otherwise the scorers' output line n would stop
        corresponding to query n.
        '''
        root = ET.parse(blast_xml).getroot()
        queries = []
        for iteration_elem in root.find('BlastOutput_iterations').iter('Iteration'):
            query = Query(
                query_length=int(iteration_elem.find('Iteration_query-len').text),
                query_def=iteration_elem.find('Iteration_query-def').text,
                hits=[],
            )
            queries.append(query)
            hits_elem = iteration_elem.find('Iteration_hits')
            if hits_elem is None:
                continue
            for hit_elem in hits_elem.iter('Hit'):
                hit = Hit(
                    subject_length=int(hit_elem.find('Hit_len').text),
                    subject_def=hit_elem.find('Hit_def').text,
                    hsps=defaultdict(list),
                )
                hsps_elem = hit_elem.find('Hit_hsps')
                if hsps_elem is not None:
                    for hsp_elem in hsps_elem.iter('Hsp'):
                        hsp = self._parse_hsp(
                            hsp_elem, query.query_length, hit.subject_length)
                        key = (self._determine_strand(hsp.query_frame)
                               + self._determine_strand(hsp.subject_frame))
                        hit.hsps[key].append(hsp)
                query.hits.append(hit)
        return queries
class Scorer(object):
    '''
    Base class for per-query coverage reports. Subclasses implement
    `_calc_query_proportion(hsps, query_len)`, which scores one group of HSPs
    (one hit, one strand pair) as a fraction of the query length.
    '''
    def print_query_scores(self, queries):
        '''
        For each query, in order, print the best proportion over all of its
        hits and strand pairs. Print 0 when the query has no HSP groups.
        '''
        for query in queries:
            # The leading 0 is the floor. max() keeps the first of equal
            # values, so 0 is printed unless a group scores strictly higher.
            candidates = [0] + [
                self._calc_query_proportion(hsp_group, query.query_length)
                for hit in query.hits
                for hsp_group in hit.hsps.values()
            ]
            print(max(candidates))
class CompatibleHspScorer(Scorer):
    '''
    Score an HSP group by the query fraction covered by its highest-scoring
    set of mutually compatible HSPs (found with SmartHspSet; DumbHspSet is a
    brute-force alternative with the same contract).
    '''
    def _calc_query_proportion(self, hsps, query_len):
        _, best_hsps = SmartHspSet(hsps).find_best_combo()
        covered = 0
        for hsp in best_hsps:
            # Coordinates are inclusive at both ends.
            covered += hsp.query_end - hsp.query_start + 1
        return covered / query_len
class HspUnionScorer(Scorer):
    '''
    Score an HSP group by the fraction of query positions covered by the
    union of all its HSPs. Overlap and crossing between HSPs are ignored.
    '''
    def _calc_query_proportion(self, hsps, query_len):
        intervals = [(hsp.query_start, hsp.query_end) for hsp in hsps]
        interval_union = self._compute_interval_union(intervals)
        return self._compute_interval_size_sum(interval_union) / query_len

    def _compute_interval_union(self, intervals):
        '''
        Compute the union of integer intervals whose endpoints are both
        inclusive. Overlapping intervals and immediately adjacent ones (e.g.
        (1, 15) and (16, 20)) are merged.

        Returns a list of disjoint (start, end) tuples sorted by start. An
        empty input returns [] (previously this raised IndexError).
        '''
        merged_union = []
        for start, end in sorted(intervals):
            # Merge if this interval overlaps or directly follows the last one.
            if merged_union and start <= merged_union[-1][1] + 1:
                last_start, last_end = merged_union[-1]
                merged_union[-1] = (last_start, max(last_end, end))
            else:
                merged_union.append((start, end))
        return merged_union

    def _compute_interval_size_sum(self, intervals):
        '''Return the total number of positions in inclusive (start, end) intervals.'''
        return sum(end - start + 1 for start, end in intervals)
class TestSmartHspSet(unittest.TestCase):
    '''Regression tests for SmartHspSet.find_best_combo.'''

    def _assert_best(self, hsps, expected_score, expected_soln):
        '''Solve `hsps` with SmartHspSet and compare (score, soln) to expectations.'''
        score, soln = SmartHspSet(hsps).find_best_combo()
        self.assertAlmostEqual(expected_score, score)
        self.assertEqual(expected_soln, soln)

    def test_simple(self):
        # The first two HSPs overlap on the query (244 <= 249), and the third
        # crosses both (earlier on the query, later on the subject). So the
        # single highest-scoring HSP wins.
        hsps = [
            Hsp(query_start=119, query_end=249, subject_start=1157, subject_end=1287,
                bitscore=243.031, subject_frame=1, query_frame=1, alignment_length=5),
            Hsp(query_start=244, query_end=345, subject_start=845, subject_end=946,
                bitscore=189.479, subject_frame=1, query_frame=1, alignment_length=5),
            Hsp(query_start=17, query_end=118, subject_start=1901, subject_end=2002,
                bitscore=189.479, subject_frame=1, query_frame=1, alignment_length=5),
        ]
        self._assert_best(hsps, 243.031, [hsps[0]])

    def test_simple_two(self):
        # The HSPs are ordered differently on query and subject (they cross),
        # so only one can be kept.
        hsps = [
            Hsp(query_start=346, query_end=854, subject_start=9427, subject_end=9935,
                bitscore=470.169, subject_frame=1, query_frame=1, alignment_length=5),
            Hsp(query_start=1, query_end=339, subject_start=231024, subject_end=231362,
                bitscore=178.399, subject_frame=1, query_frame=1, alignment_length=5),
        ]
        self._assert_best(hsps, 470.169, [hsps[0]])

    def _test_against_blast_xml(self, xml_filename, strand_pair, hsp_len, expected_score, expected_soln):
        '''
        Parse a single-query, single-hit BLAST XML fixture from the
        test-cases/ directory next to this file, check its shape, and check
        the best combo for `strand_pair`.
        '''
        xml_path = os.path.join(os.path.dirname(__file__), 'test-cases', xml_filename)
        queries = BlastParser().parse(xml_path)
        self.assertEqual(1, len(queries))
        hits = queries[0].hits
        self.assertEqual(1, len(hits))
        hsps_by_strand = hits[0].hsps
        self.assertEqual(1, len(hsps_by_strand.keys()))
        hsps = hsps_by_strand[strand_pair]
        self.assertEqual(hsp_len, len(hsps))
        self._assert_best(hsps, expected_score, expected_soln)

    def test_reverse_complement(self):
        expected_hsp = Hsp(
            query_start=5,
            query_end=8,
            query_frame=1,
            subject_start=22,
            subject_end=22,
            subject_frame=-1,
            alignment_length=275,
            bitscore=505.562,
        )
        self._test_against_blast_xml(
            xml_filename='reverse_complement.blast.xml',
            strand_pair='+-',
            hsp_len=2,
            expected_score=505.562,
            expected_soln=[expected_hsp],
        )

    def test_problematic(self):
        first_hsp = Hsp(
            query_start=1631,
            query_end=1949,
            query_frame=1,
            subject_start=1,
            subject_end=319,
            subject_frame=1,
            alignment_length=275,
            bitscore=573.582,
        )
        second_hsp = Hsp(
            query_start=2687,
            query_end=3406,
            query_frame=1,
            subject_start=373,
            subject_end=1075,
            subject_frame=1,
            alignment_length=275,
            bitscore=1125.73,
        )
        self._test_against_blast_xml(
            xml_filename='problematic.blast.xml',
            strand_pair='++',
            hsp_len=4,
            expected_score=1699.312,
            expected_soln=[first_hsp, second_hsp],
        )
class TestHspUnionScorer(unittest.TestCase):
    '''Tests for HspUnionScorer's interval-union and size helpers.'''

    def test_basic(self):
        scorer = HspUnionScorer()
        # Each case: (input intervals, expected union, expected covered size).
        # Both interval endpoints are inclusive.
        cases = [
            # Mix of nested, overlapping and disjoint intervals.
            ([(1, 16), (7, 15), (5, 10), (13, 16), (19, 21), (19, 20)],
             [(1, 16), (19, 21)],
             19),
            # Intervals sharing an endpoint (15) are merged.
            ([(1, 15), (15, 20)], [(1, 20)], 20),
            # Adjacent intervals (15 then 16) are merged.
            ([(1, 15), (16, 20)], [(1, 20)], 20),
            # Overlapping intervals are merged.
            ([(1, 15), (14, 20)], [(1, 20)], 20),
        ]
        for intervals, expected_union, expected_size_sum in cases:
            result_union = scorer._compute_interval_union(intervals)
            self.assertEqual(expected_union, result_union)
            self.assertEqual(expected_size_sum,
                             scorer._compute_interval_size_sum(result_union))
def test_smart_vs_dumb_algorithm(queries, max_hsps=20, decimal_places=7):
    '''
    Cross-check SmartHspSet against the brute-force DumbHspSet on every
    (hit, strand pair) HSP group in `queries`.

    Groups with more than `max_hsps` HSPs are skipped, because the brute
    force enumerates up to 2^n subsets. If the two best scores differ after
    rounding to `decimal_places`, the offending HSPs are printed to stderr
    and an Exception is raised.
    '''
    for query in queries:
        for hit in query.hits:
            for segregated_hsps in hit.hsps.values():
                if len(segregated_hsps) > max_hsps:
                    continue
                # Previously called find_best_combo() on an undefined name
                # (`hsp_set`), so this function always raised NameError.
                smart_score, _ = SmartHspSet(segregated_hsps).find_best_combo()
                dumb_score, _ = DumbHspSet(segregated_hsps).find_best_combo()
                if round(dumb_score - smart_score, decimal_places) != 0:
                    from pprint import pprint
                    pprint(segregated_hsps, stream=sys.stderr)
                    raise Exception('Scores unequal (smart=%s, dumb=%s)' % (smart_score, dumb_score))
def main():
    '''
    Script entry point. The mode is chosen by the `run_type` assignment below:
      - normal: read BLAST XML from stdin and print, per query, the fraction
        covered by its best compatible HSP set (CompatibleHspScorer).
      - compare_dp_vs_dumb: read BLAST XML from stdin and cross-check the
        dynamic-programming solver against the brute-force one.
      - test_cases: run this file's unittest test cases.
    HspUnionScorer can be swapped in for CompatibleHspScorer to score by the
    plain union of HSPs instead.
    '''
    class RunType(Enum):
        normal = 1
        compare_dp_vs_dumb = 2
        test_cases = 3

    run_type = RunType.normal

    if run_type is RunType.test_cases:
        unittest.main()
        return

    queries = BlastParser().parse(sys.stdin)
    if run_type is RunType.compare_dp_vs_dumb:
        test_smart_vs_dumb_algorithm(queries)
    else:
        CompatibleHspScorer().print_query_scores(queries)
# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment