Skip to content

Instantly share code, notes, and snippets.

@alexgarel
Created January 31, 2018 19:43
Show Gist options
  • Save alexgarel/cb727d413532ba454ccad7d01e88e449 to your computer and use it in GitHub Desktop.
Save alexgarel/cb727d413532ba454ccad7d01e88e449 to your computer and use it in GitHub Desktop.
Filtering values from a list in a NumPy array (similar to pandas' `isin()`)
from unittest import TestCase

import numpy as np
def np_arg_in(a, values, sorter=None):
    """Find indices of `a` whose element equals one of `values`.

    Each entry of `values` is located in `a` by binary search
    (:func:`numpy.searchsorted`). When the element found at the insertion
    point equals the value, its index in `a` is returned.

    The result follows the order of (flattened) `values`, not the order of
    `a`. A value repeated in `values` yields its index repeatedly, and values
    absent from `a` are skipped. If `a` contains duplicates, only one index
    per value is returned: the leftmost one in sorted order.

    :param a: 1-D array-like to search in. It must already be sorted in
        ascending order unless `sorter` is given.
    :param values: array-like (or scalar) of values to look for. It is
        flattened before searching.
    :param sorter: optional array-like of integer indices that sort `a` into
        ascending order, e.g. ``np.argsort(a)``.
    :return: 1-D integer array of indices into `a`. When nothing matches it
        is empty but still integer-typed, so it remains usable as an index.
    """
    a = np.asarray(a)
    values = np.asarray(values).ravel()
    # Integer dtype on purpose: ``np.array([])`` would be float64 and raise
    # IndexError when used to index an array.
    empty = np.empty(0, dtype=np.intp)
    if not (values.size and a.shape[0]):
        return empty
    # Insertion points: where each value would go to keep `a` sorted.
    nearest_indices = np.searchsorted(a, values, sorter=sorter)
    # Values larger than every element get position len(a), which is out of range.
    in_range = nearest_indices < a.shape[0]
    nearest_indices = nearest_indices[in_range]
    values = values[in_range]
    if not nearest_indices.size:
        return empty
    if sorter is not None:
        # Map positions in sorted order back to positions in `a`; asarray so a
        # plain list sorter (accepted by searchsorted) can be fancy-indexed too.
        nearest_indices = np.asarray(sorter)[nearest_indices]
    # An insertion point is only a match if the element there equals the value.
    return nearest_indices[a[nearest_indices] == values]
## some tests
class TestUtils(TestCase):
    """Unit tests for ``np_arg_in``.

    ``TestCase`` must be imported at module level: importing it inside the
    class body (after the base-class list) raises NameError.
    """

    def test_np_arg_in(self):
        """Matches follow the order of `values`, with or without `sorter`."""
        X = np.array([1, 5, 3, 4])
        sorter = np.argsort(X)
        self.assertEqual(
            list(np_arg_in(X, [], sorter=sorter)),
            [],
        )
        self.assertEqual(
            list(np_arg_in(X, X, sorter=sorter)),
            [0, 1, 2, 3],
        )
        # Missing values (0, 2) are skipped; a repeated value (3) repeats its index.
        self.assertEqual(
            list(np_arg_in(X, [0, 5, 2, 3, 3], sorter=sorter)),
            [1, 2, 2],
        )
        # Without sorter: the searched array must already be sorted.
        sortedX = np.array([1, 3, 4, 5])
        self.assertEqual(
            list(np_arg_in(sortedX, [0, 5, 2, 3, 3])),
            [3, 1, 1],
        )

    def test_np_arg_in_out_bound(self):
        """Values below the minimum or above the maximum of `a` are ignored."""
        X = np.array([1, 5, 3, 4])
        sorter = np.argsort(X)
        self.assertEqual(
            list(np_arg_in(X, [-1, 4, 10], sorter=sorter)),
            [3],
        )
        self.assertEqual(
            list(np_arg_in(X, [-1, 10], sorter=sorter)),
            [],
        )

    def test_np_arg_in_empty_array(self):
        """Searching an empty array yields no indices."""
        X = np.array([])
        self.assertEqual(
            list(np_arg_in(X, [1, 2])),
            [],
        )
        self.assertEqual(
            list(np_arg_in(X, [1, 2], sorter=X)),
            [],
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment