Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Python fast sorted list intersection
import bisect
def bisect_index(arr, start, end, x):
    """Return the index of ``x`` in ``arr[start:end]``, or -1 if it is absent.

    ``arr[start:end]`` must be sorted in ascending order; the lookup is a
    single ``bisect_left`` binary search over that slice.
    """
    i = bisect.bisect_left(arr, x, lo=start, hi=end)
    if i != end and arr[i] == x:
        return i
    return -1


def _gallop(arr, start, x):
    """Return the first index ``k >= start`` with ``arr[k] >= x`` (``len(arr)`` if none).

    ``arr[start:]`` must be sorted ascending and ``start`` must be >= 0.
    Probes offsets 1, 2, 4, ... from ``start`` to bracket the answer, then
    finishes with one binary search inside the bracket, so the work grows with
    the log of the distance skipped rather than with ``len(arr)``.
    """
    n = len(arr)
    if start >= n or arr[start] >= x:
        return min(start, n)
    # Invariant: arr[start + bound // 2] < x (for bound == 1 that is arr[start]).
    bound = 1
    while start + bound < n and arr[start + bound] < x:
        bound *= 2
    # The answer lies in (start + bound // 2, min(start + bound, n)].
    return bisect.bisect_left(
        arr, x, lo=start + bound // 2 + 1, hi=min(start + bound, n)
    )


def exponential_search(arr, start, x):
    """Return the index of ``x`` in ``arr``, searching only ``arr[start:]``.

    Returns -1 when ``x`` does not occur at or after ``start``, including when
    ``start >= len(arr)``.  ``arr`` must be sorted ascending.
    """
    k = _gallop(arr, start, x)
    if k < len(arr) and arr[k] == x:
        return k
    return -1


def compute_intersection_list(l1, l2):
    """Return a new ascending list of the values present in both ``l1`` and ``l2``.

    Both inputs must be sorted ascending; the algorithm assumes they hold no
    duplicates.  Each element of the smaller list is located in the larger one
    by galloping forward from the previous position, following
    https://stackoverflow.com/a/40538162/145349
    """
    # B is the smaller list (walked element by element), A the larger one
    # (searched with exponential jumps).
    B = l1 if len(l1) < len(l2) else l2
    A = l2 if l1 is B else l1

    n = len(A)
    intersection_list = []
    # Every element of A before index j is smaller than the current x, so each
    # search resumes where the previous one stopped (hit or miss).
    j = 0
    for x in B:
        j = _gallop(A, j, x)
        if j == n:
            break  # the rest of B is larger than every element of A
        if A[j] == x:
            intersection_list.append(x)
            j += 1
    return intersection_list
# Smoke tests: the sample from the gist plus edge cases (swapped arguments,
# empty inputs, disjoint inputs, one list contained in a much larger one).
l1 = [1, 3, 4, 6, 7, 8, 9, 10]
l2 = [0, 2, 3, 6, 7, 9]
assert compute_intersection_list(l1, l2) == sorted(set(l1) & set(l2))
assert compute_intersection_list(l2, l1) == sorted(set(l1) & set(l2))
assert compute_intersection_list([], l2) == []
assert compute_intersection_list(l1, []) == []
assert compute_intersection_list([2, 4], [1, 3, 5]) == []
assert compute_intersection_list([5, 50], list(range(100))) == [5, 50]
@ismael-elatifi

This comment has been minimized.

Copy link

@ismael-elatifi ismael-elatifi commented Nov 6, 2020

Here is a simpler implementation that uses bisect search for both lists to advance pointers.
Here `i` iterates over A (the smaller list) and `j` over B. In your implementation only the index into the larger list skips ahead. Here both indices can jump forward several positions at once ("exponential jumps") thanks to `bisect`.
Of course it also assumes both input lists are sorted and contain no duplicates.

import bisect

def compute_intersection_list(l1, l2):
    """Intersect two ascending, duplicate-free lists by leapfrogging two cursors.

    One cursor walks the smaller input and one walks the larger.  When the
    values under the cursors differ, the cursor holding the smaller value jumps
    forward with ``bisect_left`` to the first position not below the other
    value; when they match, the value is recorded and both cursors step on.

    Returns a new list of the shared values in ascending order.
    """
    small = l1 if len(l1) < len(l2) else l2
    large = l2 if l1 is small else l1

    common = []
    p, q = 0, 0
    while p < len(small) and q < len(large):
        a = small[p]
        b = large[q]
        if a == b:
            common.append(a)
            p, q = p + 1, q + 1
            continue
        if a < b:
            # small[p] lags behind: skip every element of `small` below b.
            p = bisect.bisect_left(small, b, lo=p + 1)
        else:
            # large[q] lags behind: skip every element of `large` below a.
            q = bisect.bisect_left(large, a, lo=q + 1)
    return common


# test on many random cases
import random

MM = 100  # max value (exclusive upper bound of the sampling range)

# Dedicated, seeded generator so that a failing case can be reproduced
# exactly; change SEED to explore other random cases.
SEED = 0
rng = random.Random(SEED)

for _ in range(10000):
    M1 = rng.randint(0, MM)  # random exclusive upper bound for values of a
    N1 = rng.randint(0, M1)  # random number of values
    M2 = rng.randint(0, MM)  # random exclusive upper bound for values of b
    N2 = rng.randint(0, M2)  # random number of values
    a = sorted(rng.sample(range(M1), N1))  # sampling without replacement to have no duplicates
    b = sorted(rng.sample(range(M2), N2))
    expected = sorted(set(a).intersection(b))
    # Show the offending inputs on failure instead of a bare AssertionError.
    assert compute_intersection_list(a, b) == expected, (a, b)
@fjsj

This comment has been minimized.

Copy link
Owner Author

@fjsj fjsj commented Nov 6, 2020

Nice, thanks!
Leaving a reference here: for NumPy arrays (especially sorted ones), `numpy.intersect1d` has a really fast implementation: https://numpy.org/doc/stable/reference/generated/numpy.intersect1d.html

After de-duplicating both inputs, concatenating them into `aux` and sorting `aux`, its core is simply:

    # Excerpt from numpy.intersect1d; `aux` is not defined in this snippet.
    # NOTE(review): presumably `aux` is the sorted concatenation of the two
    # (already de-duplicated) inputs - confirm against the linked numpy source.
    # Compare every element with its right-hand neighbour: True marks a
    # position whose value equals the next one in `aux`.
    mask = aux[1:] == aux[:-1]
    # Keep the left element of each equal adjacent pair.
    int1d = aux[:-1][mask]

https://github.com/numpy/numpy/blob/92ebe1e9a6aeb47a881a1226b08218175776f9ea/numpy/lib/arraysetops.py#L429-L430

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.