Skip to content

Instantly share code, notes, and snippets.

@mjm522
Last active January 10, 2022 21:46
Show Gist options
  • Save mjm522/c263db0f2b535222d748f8dfa8d260da to your computer and use it in GitHub Desktop.
Save mjm522/c263db0f2b535222d748f8dfa8d260da to your computer and use it in GitHub Desktop.
Compare two 2D arrays row by row. Return the list of row indices of the first array that have a similar row (within a threshold) in the second array.
import numpy as np
def two_d_array_compare(array_a, array_b, thresh):
    """
    Return the indices of the rows of array_a that have a similar row in array_b.

    A row of array_a is considered similar to a row of array_b when every
    element of it lies within ``thresh`` (inclusive) of the corresponding
    element of the array_b row.

    :param array_a: m x k array; a 1-D input is treated as a single row
    :param array_b: n x k array; a 1-D input is treated as a single row
    :param thresh: maximum allowed per-element difference (its absolute
        value is used)
    :return: sorted list of unique row indices (plain ints) into array_a
    """
    thresh = abs(thresh)
    array_a = np.asarray(array_a)
    array_b = np.asarray(array_b)
    if array_b.ndim == 1:
        array_b = array_b[None, :]
    if array_a.ndim == 1:
        array_a = array_a[None, :]

    # One flag per row of array_a; a flag is set once any row of array_b
    # matches, so duplicates never need to be removed afterwards.
    matched = np.zeros(array_a.shape[0], dtype=bool)
    for row_b in array_b:
        within = (row_b - thresh <= array_a) & (array_a <= row_b + thresh)
        # Always reduce over the column axis of the un-squeezed 2-D mask.
        # Squeezing first would drop size-1 axes, which breaks single-column
        # and single-row inputs and yields bogus indices for an empty array_a.
        matched |= np.all(within, axis=1)
    return np.flatnonzero(matched).tolist()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment