Instantly share code, notes, and snippets.

# jwintersinger/find-best-hit.py Created Jul 14, 2014

Multiple methods for finding the best set of mutually compatible BLAST hits
 #!/usr/bin/env python3
'''
For each query listed in a BLAST XML results file, determine how much of the
query sequence is covered by the "best" hit. The best hit is deemed to be the
one in which the summed bitscores of its compatible HSPs is highest. Mutually
compatible HSPs are deemed to be ones that neither overlap nor "cross" each
other.

Usage:
  cat blast_results.xml | find-best-hit.py
  cat blast_results.xml | find-best-hit.py --compare-dp-vs-dumb
  find-best-hit.py --test

Output:
  Values in range [0, 1], with one per line. The nth such value represents
  the fraction of query n covered by the mutually compatible HSPs composing
  its best hit. Queries without any hits are reported as 0.
'''

import sys
import xml.etree.ElementTree as ET
from collections import defaultdict, namedtuple
import json
from functools import lru_cache
from enum import Enum
import unittest
import os


class Path(Enum):
    '''Which branch of the DP recurrence produced a sub-problem's optimum.'''
    without_self = 1  # Optimum excludes the sub-problem's last HSP.
    with_self = 2     # Optimum includes the sub-problem's last HSP.


class BaseHspSet(object):
    '''Holds a collection of HSPs sorted by query start coordinate.'''

    def __init__(self, hsps):
        self._hsps = sorted(hsps, key=lambda a: a.query_start)


class DumbHspSet(BaseHspSet):
    '''Brute-force solver that enumerates every mutually compatible subset.

    Runs in time exponential in the number of HSPs; intended only as a
    reference against which SmartHspSet can be cross-checked.
    '''

    def find_best_combo(self):
        '''Return ``(score, combo)`` for the highest-scoring compatible subset.

        ``combo`` is a list of HSPs sorted by query start. With no HSPs (or no
        subset scoring above 0) the result is ``(0, [])``, matching
        SmartHspSet.
        '''
        highest_score = 0
        best_combo = []
        for valid_combo in self._compute_hsp_combos(self._hsps):
            score = sum(i.bitscore for i in valid_combo)
            if score > highest_score:
                highest_score = score
                best_combo = valid_combo
        return (highest_score, best_combo)

    def _compute_hsp_combos(self, L):
        '''Return every mutually compatible subset of the HSPs in ``L``.'''
        sorted_hsps = sorted(L, key=lambda a: a.query_start)
        return self._combo_r([], sorted_hsps)

    def _combo_r(self, base, remaining):
        '''Return every compatible extension of ``base`` by HSPs in ``remaining``.

        ``base`` is a compatible chain sorted by query start, and ``remaining``
        follows it in that order. A candidate joins the chain only if it starts
        strictly after the chain's last HSP ends on both query and subject.
        Assuming start <= end for every HSP, that last HSP has the chain's
        largest end coordinates, so comparing against ``base[-1]`` suffices.
        '''
        if not remaining:
            return [base]
        first = remaining[0]
        rest = remaining[1:]
        if base:
            base_query_end = base[-1].query_end
            base_subject_end = base[-1].subject_end
        else:
            base_query_end = 0
            base_subject_end = 0
        if first.query_start > base_query_end and first.subject_start > base_subject_end:
            # Branch: both with and without `first`.
            return self._combo_r(base + [first], rest) + self._combo_r(base, rest)
        return self._combo_r(base, rest)


class SmartHspSet(BaseHspSet):
    '''Dynamic-programming solver for the best mutually compatible HSP subset.

    A sub-problem is a tuple of indices into ``self._hsps`` (sorted by query
    start). For a sub-problem whose last index is ``k``, the optimum either
    excludes ``k`` (solve ``indices[:-1]``) or includes it (solve the members
    of ``indices`` compatible with ``k`` and add ``k``'s bitscore). Results are
    memoised in per-instance dictionaries that ``find_best_combo`` resets.
    '''

    def __init__(self, hsps):
        super().__init__(hsps)
        self._reset_caches()

    def _reset_caches(self):
        # Per-instance memo tables. These replace functools.lru_cache on the
        # methods, which keyed on `self` (keeping every instance alive) and
        # survived across find_best_combo calls while _combo_soln did not.
        self._compatible_cache = {}
        self._best_values = {}
        self._combo_soln = {}

    # Segment HSPs into smaller, independent sets. No performance benefit
    # appears in practice -- using segment takes almost exactly as long as not
    # using it -- and so don't bother.
    def _segment(self):
        '''
        Return partition of self._hsps, such that no element (i.e., subset of
        self._hsps) contains an HSP that overlaps another HSP from a different
        element. This is useful because the problem of finding the
        highest-scoring subset of HSPs can then be solved independently for
        each subset forming the partition without impairing the answer's
        optimality.
        '''
        segments = []
        i = 0
        n = len(self._hsps)
        while i < n:
            candidate = self._hsps[i]
            max_query_end = candidate.query_end
            max_subject_end = candidate.subject_end
            conflicting_query_end = candidate.query_end
            conflicting_subject_end = candidate.subject_end
            last_conflicting = i
            # Scan to the end (no early break): a later HSP can still
            # conflict with this segment, e.g. by crossing on the subject.
            for j in range(i + 1, n):
                other = self._hsps[j]
                max_query_end = max(max_query_end, other.query_end)
                max_subject_end = max(max_subject_end, other.subject_end)
                if (other.query_start <= conflicting_query_end
                        or other.subject_start <= conflicting_subject_end):
                    # Pulling in j pulls in everything between, so the
                    # segment's extent becomes the running max up to j.
                    conflicting_query_end = max_query_end
                    conflicting_subject_end = max_subject_end
                    last_conflicting = j
            next_compatible = last_conflicting + 1
            segments.append(tuple(range(i, next_compatible)))
            i = next_compatible
        return segments

    def _find_compatible(self, target, indices):
        '''Find all HSPs in ``indices`` that don't overlap or cross ``target``.

        Such HSPs are not guaranteed to be mutually compatible, but are
        guaranteed only to be compatible with ``target``. Returns a tuple of
        indices preserving the order of ``indices``.
        '''
        indices = tuple(indices)
        key = (target, indices)
        cached = self._compatible_cache.get(key)
        if cached is not None:
            return cached

        target_hsp = self._hsps[target]
        compatible = []
        for index in indices:
            # Don't define target as overlapping itself.
            if index == target:
                continue
            test_hsp = self._hsps[index]
            if target_hsp.query_start <= test_hsp.query_start:
                first, second = target_hsp, test_hsp
            else:
                first, second = test_hsp, target_hsp
            # The query-later HSP must also be subject-later, or they cross.
            overlap = (second.query_start <= first.query_end
                       or second.subject_start <= first.subject_end)
            if not overlap:
                compatible.append(index)

        result = tuple(compatible)
        self._compatible_cache[key] = result
        return result

    def _reconstruct_soln(self, indices):
        '''Return the chosen indices, in ascending order, from recorded DP choices.

        Must be called after ``_find_best_combo_r(indices)`` has populated
        ``self._combo_soln``. Iterative, so reconstruction cannot hit the
        recursion limit.
        '''
        indices = tuple(indices)
        chosen = []
        while len(indices) > 1:
            if self._combo_soln[indices] == Path.without_self:
                indices = indices[:-1]
            else:
                last_index = indices[-1]
                chosen.append(last_index)
                indices = self._find_compatible(last_index, indices)
        # Zero or one index remains; a lone HSP is always taken.
        chosen.extend(indices)
        # Indices were collected from the back, so restore ascending order.
        chosen.reverse()
        return tuple(chosen)

    def find_best_combo(self):
        '''Return ``(score, hsps)`` for the highest-scoring compatible subset.

        ``hsps`` is a list sorted by query start; ``(0, [])`` when empty. Safe
        to call repeatedly on the same instance.
        '''
        self._reset_caches()
        indices = tuple(range(len(self._hsps)))
        score = self._find_best_combo_r(indices)
        soln = [self._hsps[i] for i in self._reconstruct_soln(indices)]
        return (score, soln)

    def _find_best_combo_r(self, indices):
        '''Return the best achievable summed bitscore using only ``indices``.

        Records the winning branch for each sub-problem in
        ``self._combo_soln``. Recursion depth grows with ``len(indices)``
        (the without-self branch drops one index per level), so very large HSP
        sets may reach Python's recursion limit.
        '''
        # Must check for len == 0, which will occur if we're looking at only
        # two hsps that are incompatible with each other. When looking at the
        # second, the call self._find_best_combo_r(compatible) will recurse on
        # an empty tuple (since "compatible" doesn't include the self, and it
        # won't include the preceding (incompatible) hsp).
        if len(indices) == 0:
            return 0
        if len(indices) == 1:
            self._combo_soln[indices] = Path.with_self
            return self._hsps[indices[0]].bitscore
        if indices in self._best_values:
            return self._best_values[indices]

        last_index = indices[-1]
        # `compatible` keeps the order of `indices`, so it stays sorted.
        compatible = self._find_compatible(last_index, indices)
        best_without_self = self._find_best_combo_r(indices[:-1])
        best_with_self = self._find_best_combo_r(compatible) + self._hsps[last_index].bitscore

        # Ties favour excluding the last HSP.
        if best_without_self >= best_with_self:
            self._best_values[indices] = best_without_self
            self._combo_soln[indices] = Path.without_self
        else:
            self._best_values[indices] = best_with_self
            self._combo_soln[indices] = Path.with_self
        return self._best_values[indices]


Query = namedtuple('Query', [
    'query_def',
    'query_length',
    'hits',          # list of Hit
])
Hit = namedtuple('Hit', [
    'subject_length',
    'subject_def',
    'hsps',          # dict: strand-pair key such as '+-' -> list of Hsp
])
Hsp = namedtuple('Hsp', [
    'query_start',
    'query_end',
    'query_frame',
    'subject_start',
    'subject_end',
    'subject_frame',
    'alignment_length',
    'bitscore',
])


class BlastParser(object):
    '''Parse BLAST XML output into Query / Hit / Hsp records.'''

    def _determine_strand(self, frame):
        if frame > 0:
            return '+'
        elif frame < 0:
            return '-'
        else:
            return '.'

    def _parse_hsp(self, hsp_elem):
        '''Build an Hsp from an <Hsp> element, coordinates as reported.'''
        return Hsp(
            query_frame=int(hsp_elem.find('Hsp_query-frame').text),
            subject_frame=int(hsp_elem.find('Hsp_hit-frame').text),
            alignment_length=int(hsp_elem.find('Hsp_align-len').text),
            query_start=int(hsp_elem.find('Hsp_query-from').text),
            query_end=int(hsp_elem.find('Hsp_query-to').text),
            subject_start=int(hsp_elem.find('Hsp_hit-from').text),
            subject_end=int(hsp_elem.find('Hsp_hit-to').text),
            bitscore=float(hsp_elem.find('Hsp_bit-score').text),
        )

    def _normalize_minus_strand(self, hsp, query_length, subject_length):
        '''Return ``hsp`` with negative-frame coordinates mapped to the reverse strand.

        Each negative-frame coordinate pair is first ordered so start <= end.
        It is then reflected (position p -> length - p + 1), so the pair
        increases along the alignment like plus-strand coordinates do.
        '''
        lengths = {'query': query_length, 'subject': subject_length}
        for seq_type in ('query', 'subject'):
            if getattr(hsp, '%s_frame' % seq_type) >= 0:
                continue
            start_key = '%s_start' % seq_type
            end_key = '%s_end' % seq_type
            start = getattr(hsp, start_key)
            end = getattr(hsp, end_key)
            if start > end:
                start, end = end, start
            length = lengths[seq_type]
            hsp = hsp._replace(**{
                start_key: length - end + 1,
                end_key: length - start + 1,
            })
        return hsp

    def parse(self, blast_xml):
        '''Parse ``blast_xml`` (a filename or file object) into a list of Query.

        One Query is returned per <Iteration>, in document order, including
        queries without hits (their ``hits`` list is empty). Each Hit's HSPs
        are grouped by strand pair, e.g. ``'+-'``.
        NOTE(review): ElementTree is not hardened against malicious XML;
        only feed it trusted BLAST output.
        '''
        tree = ET.parse(blast_xml)
        root = tree.getroot()
        queries = []

        for iteration_elem in root.find('BlastOutput_iterations').iter('Iteration'):
            query = Query(
                query_length=int(iteration_elem.find('Iteration_query-len').text),
                query_def=iteration_elem.find('Iteration_query-def').text,
                hits=[],
            )
            # Record every query, even hit-less ones, so output lines stay
            # aligned with query order.
            queries.append(query)

            hit_elems = iteration_elem.find('Iteration_hits')
            if hit_elems is None:
                continue
            for hit_elem in hit_elems.iter('Hit'):
                hit = Hit(
                    subject_length=int(hit_elem.find('Hit_len').text),
                    subject_def=hit_elem.find('Hit_def').text,
                    hsps=defaultdict(list),
                )
                for hsp_elem in hit_elem.find('Hit_hsps').iter('Hsp'):
                    hsp = self._normalize_minus_strand(
                        self._parse_hsp(hsp_elem),
                        query.query_length,
                        hit.subject_length,
                    )
                    key = self._determine_strand(hsp.query_frame) + \
                        self._determine_strand(hsp.subject_frame)
                    hit.hsps[key].append(hsp)
                query.hits.append(hit)

        return queries


class Scorer(object):
    '''Report, per query, the best query-coverage proportion over its hits.

    Subclasses implement ``_calc_query_proportion(hsps, query_len)``.
    '''

    def query_scores(self, queries):
        '''Yield the best proportion for each query, in order.

        The maximum is taken over every hit and each strand-pair group of that
        hit's HSPs; queries with no hits yield 0.
        '''
        for query in queries:
            best_query_proportion = 0
            for hit in query.hits:
                for segregated_hsps in hit.hsps.values():
                    query_proportion = self._calc_query_proportion(
                        segregated_hsps, query.query_length)
                    if query_proportion > best_query_proportion:
                        best_query_proportion = query_proportion
            yield best_query_proportion

    def print_query_scores(self, queries, out=None):
        '''Print one score per line to ``out`` (standard output when None).'''
        for score in self.query_scores(queries):
            print(score, file=out)

    def _calc_query_proportion(self, hsps, query_len):
        raise NotImplementedError


class CompatibleHspScorer(Scorer):
    '''Coverage by the best-scoring set of mutually compatible HSPs.'''

    def _calc_query_proportion(self, hsps, query_len):
        # DumbHspSet gives the same score but takes exponential time.
        hsp_set = SmartHspSet(hsps)
        score, soln = hsp_set.find_best_combo()
        hsps_size = sum(h.query_end - h.query_start + 1 for h in soln)
        return hsps_size / query_len


class HspUnionScorer(Scorer):
    '''Coverage by the union of all HSPs' query intervals (ignores compatibility).'''

    def _calc_query_proportion(self, hsps, query_len):
        intervals = [(hsp.query_start, hsp.query_end) for hsp in hsps]
        interval_union = self._compute_interval_union(intervals)
        size_sum = self._compute_interval_size_sum(interval_union)
        return size_sum / query_len

    def _compute_interval_union(self, intervals):
        '''
        Compute union of overlapping intervals, also merging intervals that are
        immediately adjacent. Assume both endpoints are inclusive. Returns a
        list of disjoint (start, end) tuples sorted by start; [] for no input.
        Single-position intervals (start == end) are kept.
        '''
        merged_union = []
        for start, end in sorted(intervals):
            # With inclusive ends, start == last_end + 1 means adjacent.
            if merged_union and start <= merged_union[-1][1] + 1:
                last_start, last_end = merged_union[-1]
                merged_union[-1] = (last_start, max(last_end, end))
            else:
                merged_union.append((start, end))
        return merged_union

    def _compute_interval_size_sum(self, intervals):
        '''Total length of inclusive (start, end) intervals.'''
        return sum(end - start + 1 for start, end in intervals)


class TestSmartHspSet(unittest.TestCase):
    def test_simple(self):
        hsps = [
            Hsp(query_start=119, query_end=249, subject_start=1157, subject_end=1287, bitscore=243.031, subject_frame=1, query_frame=1, alignment_length=5),
            Hsp(query_start=244, query_end=345, subject_start=845, subject_end=946, bitscore=189.479, subject_frame=1, query_frame=1, alignment_length=5),
            Hsp(query_start=17, query_end=118, subject_start=1901, subject_end=2002, bitscore=189.479, subject_frame=1, query_frame=1, alignment_length=5),
        ]
        hsp_set = SmartHspSet(hsps)
        score, soln = hsp_set.find_best_combo()
        self.assertAlmostEqual(243.031, score)
        self.assertEqual([hsps[0]], soln)

    def test_simple_two(self):
        hsps = [
            Hsp(query_start=346, query_end=854, subject_start=9427, subject_end=9935, bitscore=470.169, subject_frame=1, query_frame=1, alignment_length=5),
            Hsp(query_start=1, query_end=339, subject_start=231024, subject_end=231362, bitscore=178.399, subject_frame=1, query_frame=1, alignment_length=5),
        ]
        hsp_set = SmartHspSet(hsps)
        score, soln = hsp_set.find_best_combo()
        self.assertAlmostEqual(470.169, score)
        self.assertEqual([hsps[0]], soln)

    def _test_against_blast_xml(self, xml_filename, strand_pair, hsp_len, expected_score, expected_soln):
        xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test-cases', xml_filename)
        if not os.path.exists(xml_path):
            self.skipTest('missing test fixture %s' % xml_path)
        queries = BlastParser().parse(xml_path)
        self.assertEqual(1, len(queries))
        query = queries[0]
        self.assertEqual(1, len(query.hits))
        hit = query.hits[0]
        self.assertEqual(1, len(hit.hsps.keys()))
        hsps = hit.hsps[strand_pair]
        self.assertEqual(hsp_len, len(hsps))
        hsp_set = SmartHspSet(hsps)
        score, soln = hsp_set.find_best_combo()
        self.assertAlmostEqual(expected_score, score)
        self.assertEqual(expected_soln, soln)

    def test_reverse_complement(self):
        self._test_against_blast_xml(
            xml_filename='reverse_complement.blast.xml',
            strand_pair='+-',
            hsp_len=2,
            expected_score=505.562,
            expected_soln=[
                Hsp(query_start=5, query_end=8, query_frame=1, subject_start=22, subject_end=22, subject_frame=-1, alignment_length=275, bitscore=505.562),
            ],
        )

    def test_problematic(self):
        self._test_against_blast_xml(
            xml_filename='problematic.blast.xml',
            strand_pair='++',
            hsp_len=4,
            expected_score=1699.312,
            expected_soln=[
                Hsp(query_start=1631, query_end=1949, query_frame=1, subject_start=1, subject_end=319, subject_frame=1, alignment_length=275, bitscore=573.582),
                Hsp(query_start=2687, query_end=3406, query_frame=1, subject_start=373, subject_end=1075, subject_frame=1, alignment_length=275, bitscore=1125.73),
            ],
        )


class TestHspUnionScorer(unittest.TestCase):
    def test_basic(self):
        cases = (
            # Basic example
            (
                [(1, 16), (7, 15), (5, 10), (13, 16), (19, 21), (19, 20)],
                [(1, 16), (19, 21)],
                19,
            ),
            # Correctly handle endpoints, assuming start-inclusive, end-exclusive
            (
                [(1, 15), (15, 20)],
                [(1, 20)],
                20,
            ),
            # Correctly handle endpoints, assuming start-inclusive, end-inclusive
            (
                [(1, 15), (16, 20)],
                [(1, 20)],
                20,
            ),
            # Ensure overlapping intervals handled properly
            (
                [(1, 15), (14, 20)],
                [(1, 20)],
                20,
            ),
            # Single-position intervals must not be dropped
            (
                [(5, 5), (10, 12)],
                [(5, 5), (10, 12)],
                4,
            ),
            # No intervals at all
            (
                [],
                [],
                0,
            ),
        )
        for intervals, expected_union, expected_size_sum in cases:
            scorer = HspUnionScorer()
            result_union = scorer._compute_interval_union(intervals)
            result_size_sum = scorer._compute_interval_size_sum(result_union)
            self.assertEqual(expected_union, result_union)
            self.assertEqual(expected_size_sum, result_size_sum)


def test_smart_vs_dumb_algorithm(queries):
    '''Raise if SmartHspSet and DumbHspSet disagree on any HSP group's best score.

    Not a pytest test despite its name: it needs parsed ``queries``.
    '''
    decimal_places = 7
    for query in queries:
        for hit in query.hits:
            for segregated_hsps in hit.hsps.values():
                # Dumb algorithm takes exponential time (given 2^n combinations),
                # so avoid computing more than 2^20 combos.
                if len(segregated_hsps) > 20:
                    continue
                smart_hsp_set = SmartHspSet(segregated_hsps)
                dumb_hsp_set = DumbHspSet(segregated_hsps)
                smart_score, smart_soln = smart_hsp_set.find_best_combo()
                dumb_score, dumb_soln = dumb_hsp_set.find_best_combo()
                if round(dumb_score - smart_score, decimal_places) != 0:
                    from pprint import pprint
                    pprint(segregated_hsps, stream=sys.stderr)
                    raise Exception('Scores unequal (smart=%s, dumb=%s)' % (smart_score, dumb_score))


# Keep pytest from collecting this helper as a test (it has no fixture args).
test_smart_vs_dumb_algorithm.__test__ = False


def main(argv=None):
    '''Run the script; ``argv`` defaults to ``sys.argv[1:]``.

    With no flags, parse BLAST XML from stdin and print one score per query.
    ``--compare-dp-vs-dumb`` cross-checks SmartHspSet against DumbHspSet on
    stdin's HSPs; ``--test`` runs the unit tests.
    '''
    class RunType(Enum):
        normal = 1
        compare_dp_vs_dumb = 2
        test_cases = 3

    args = sys.argv[1:] if argv is None else list(argv)
    if '--compare-dp-vs-dumb' in args:
        run_type = RunType.compare_dp_vs_dumb
    elif '--test' in args:
        run_type = RunType.test_cases
    else:
        run_type = RunType.normal

    if run_type == RunType.normal:
        queries = BlastParser().parse(sys.stdin)
        # HspUnionScorer() is an alternative that counts every query position
        # covered by any HSP, ignoring compatibility.
        scorer = CompatibleHspScorer()
        scorer.print_query_scores(queries)
    elif run_type == RunType.compare_dp_vs_dumb:
        queries = BlastParser().parse(sys.stdin)
        test_smart_vs_dumb_algorithm(queries)
    elif run_type == RunType.test_cases:
        # Hand unittest only the arguments it understands.
        unittest.main(argv=[sys.argv[0]] + [a for a in args if a != '--test'])


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
You can’t perform that action at this time.