Created
January 23, 2022 12:32
-
-
Save ap--/6df60657fb8cfb958482575d24f19eba to your computer and use it in GitHub Desktop.
simple merge intervals function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import itertools | |
def merge_intervals(
    intervals: list[tuple[int, int]],
    *,
    merge_adjacent: bool = False,
    inplace: bool = False,
) -> list[tuple[int, int]]:
    """Merge overlapping intervals.

    Each interval is a ``(start, end)`` pair with ``start < end``. The result
    is sorted by start and no two of its intervals overlap.

    Args:
        intervals: The intervals to merge. With ``inplace=False`` (default)
            any iterable of pairs is accepted and is not modified. With
            ``inplace=True`` it must be a list; it is sorted, merged and
            truncated in place.
        merge_adjacent: If True, intervals that merely touch (one's end
            equals the next one's start) are merged too. If False, only
            intervals whose start lies strictly before the current end are
            merged.
        inplace: If True, reuse ``intervals`` as the output list instead of
            allocating a new one.

    Returns:
        The merged intervals sorted by start: ``intervals`` itself when
        ``inplace`` is True, otherwise a new list.

    Raises:
        ValueError: If any interval does not satisfy ``start < end``.
    """
    # Work on a private copy unless asked to mutate the caller's list.
    # Copying first also lets non-list iterables (tuples, generators) be
    # validated and sorted without being consumed twice.
    merged = intervals if inplace else list(intervals)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    for start, end in merged:
        if not start < end:
            raise ValueError(
                f"invalid interval ({start!r}, {end!r}): start must be < end"
            )

    if len(merged) < 2:
        return merged  # nothing to merge (and nothing to sort)

    # Tuples sort by start, then end: (1, 5), (1, 3) ==> (1, 3), (1, 5).
    # After sorting, each new interval starts no earlier than the last
    # merged one, so only its start vs. the last end needs comparing.
    merged.sort()

    # ``last`` indexes the most recent merged interval. It never passes the
    # loop's read position, so overwriting ``merged[last]`` while iterating
    # over ``merged`` is safe.
    last = 0
    for interval in itertools.islice(merged, 1, None):
        last_start, last_end = merged[last]
        start, end = interval
        if start < last_end or (merge_adjacent and start == last_end):
            # Overlapping (or touching, if requested): extend the last one.
            merged[last] = (last_start, max(last_end, end))
        else:
            last += 1
            merged[last] = interval
    # Drop the stale tail left behind by intervals that were merged away.
    del merged[last + 1 :]
    return merged
# === tests ===========================================================
import pytest | |
@pytest.mark.parametrize(
    "intervals,output",
    [
        ([], []),
        ([(1, 2)], [(1, 2)]),
        ([(1, 2), (3, 4)], [(1, 2), (3, 4)]),
        ([(1, 4), (3, 4)], [(1, 4)]),
        ([(10, 12), (1, 20), (2, 30)], [(1, 30)]),
        ([(1, 2), (2, 4)], [(1, 2), (2, 4)]),
        # unsorted, disjoint input comes back sorted
        ([(5, 6), (1, 2)], [(1, 2), (5, 6)]),
        # intervals fully contained in another are absorbed
        ([(1, 10), (2, 3), (4, 5)], [(1, 10)]),
        # equal starts: the longer interval wins
        ([(1, 5), (1, 3)], [(1, 5)]),
    ],
    ids=repr,
)
def test_merging_intervals(intervals, output):
    """Default mode merges overlaps and leaves the caller's list untouched."""
    snapshot = list(intervals)
    result = merge_intervals(intervals)
    assert result == output
    # The non-inplace mode must return a new list and not mutate the input.
    assert result is not intervals
    assert intervals == snapshot
@pytest.mark.parametrize(
    ("intervals", "merge_adjacent", "output"),
    [
        ([(1, 3), (3, 5)], True, [(1, 5)]),
        ([(1, 3), (3, 5)], False, [(1, 3), (3, 5)]),
    ],
)
def test_merging_adjacent(intervals, merge_adjacent, output):
    """Touching intervals are merged only when ``merge_adjacent`` is set."""
    merged = merge_intervals(intervals, merge_adjacent=merge_adjacent)
    assert merged == output
@pytest.mark.parametrize(
    ("intervals", "output"),
    [
        ([(1, 3), (3, 5)], [(1, 3), (3, 5)]),
        ([(3, 5), (1, 3)], [(1, 3), (3, 5)]),
        ([(10, 12), (1, 20), (2, 30)], [(1, 30)]),
    ],
)
def test_merging_inplace(intervals, output):
    """``inplace=True`` hands back the very list it was given, now merged."""
    result = merge_intervals(intervals, inplace=True)
    assert result is intervals
    assert result == output
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment