Skip to content

Instantly share code, notes, and snippets.

@agoose77
Last active August 2, 2023 11:12
Show Gist options
  • Save agoose77/313bf6e24146c9fd0f92b160b6cb0e5e to your computer and use it in GitHub Desktop.
Save agoose77/313bf6e24146c9fd0f92b160b6cb0e5e to your computer and use it in GitHub Desktop.
import awkward as ak
import numpy as np
import numba as nb
@nb.njit()
def _strings_in_strings_kernel(needle, haystack):
    """Return a boolean array marking which needle entries occur in haystack.

    Every entry of ``needle`` is compared with ``==`` against the entries of
    ``haystack`` until a match is found, so the work grows with
    ``len(needle) * len(haystack)`` in the worst case.
    """
    found_mask = np.empty(len(needle), np.bool_)
    for position, wanted in enumerate(needle):
        matched = False
        for candidate in haystack:
            if wanted == candidate:
                # Stop scanning at the first hit.
                matched = True
                break
        found_mask[position] = matched
    return found_mask
def is_in(needle, haystack):
    """Test every leaf element of ``needle`` for membership in ``haystack``.

    ``haystack`` is flattened to one dimension when it is nested; ``needle``
    keeps its list structure and the membership test is applied to each of
    its one-dimensional (``purelist_depth == 1``) layouts.

    Missing-value rules implemented below:

    * a missing needle element matches only if the haystack contains at
      least one missing element;
    * missing haystack elements never affect non-missing needle elements;
    * if either side has unknown type, nothing matches.

    Args:
        needle: array-like convertible with ``ak.to_layout``; the values to
            look up.
        haystack: array-like convertible with ``ak.to_layout``; the values to
            search in.

    Returns:
        The result of ``ak.transform`` over ``needle``: booleans at the
        leaves, with the surrounding list structure of ``needle``.

    Raises:
        NotImplementedError: if ``needle`` converts to an
            ``ak.record.Record``.
        AssertionError: if a one-dimensional needle/haystack pair is neither
            option, unknown, both string-like, nor both NumPy-backed.
    """
    needle_layout = ak.to_layout(needle, allow_record=True, allow_other=False)
    haystack_layout = ak.to_layout(haystack, allow_record=False, allow_other=False)

    if isinstance(needle_layout, ak.record.Record):
        raise NotImplementedError("does not support ak.record.Record objects")

    def is_in_1d_needle_all_string(needle_layout, haystack_layout):
        """Both sides are strings/bytestrings: defer to the numba kernel."""
        out = _strings_in_strings_kernel(
            ak.Array(needle_layout), ak.Array(haystack_layout)
        )
        return ak.to_layout(out)

    def is_in_1d_needle_any_option(needle_layout, haystack_layout):
        """At least one side is option-typed: strip options, then fix up."""
        if needle_layout.is_option and not haystack_layout.is_option:
            # Haystack cannot hold a missing value, so missing needles stay
            # False; only the valid needles are actually looked up.
            projected = is_in_1d_needle(needle_layout.project(), haystack_layout)
            result = np.full(needle_layout.length, False)
            result[needle_layout.mask_as_bool(True)] = projected.to_backend_array()
            return ak.contents.NumpyArray(result)
        elif not needle_layout.is_option and haystack_layout.is_option:
            # No missing needles: missing haystack entries are irrelevant.
            return is_in_1d_needle(needle_layout, haystack_layout.project())
        else:
            # Both are option-typed. First resolve the non-missing needles
            # against the non-missing haystack values.
            tmp_result = is_in_1d_needle(needle_layout, haystack_layout.project())
            assert isinstance(tmp_result, ak.contents.NumpyArray)
            # Missing needles match only if the haystack really contains a
            # missing entry. ``mask_as_bool(False)`` is True at *missing*
            # positions; testing the valid mask here would wrongly match
            # ``None`` against an option-typed haystack without any ``None``
            # and miss it for an all-``None`` haystack.
            if np.any(haystack_layout.mask_as_bool(False)):
                tmp_result = ak.to_layout(
                    np.logical_or(
                        tmp_result.to_backend_array(),
                        needle_layout.mask_as_bool(False),
                    )
                )
            return tmp_result

    def is_in_1d_needle_any_unknown(needle_layout, haystack_layout):
        """Either side has unknown type: report no matches."""
        return ak.to_layout(np.full(len(needle_layout), False))

    def is_in_1d_needle_all_numpy(needle_layout, haystack_layout):
        """Both sides are NumPy-backed: use ``np.isin`` directly."""
        return ak.to_layout(
            np.isin(
                needle_layout.to_backend_array(), haystack_layout.to_backend_array()
            )
        )

    def is_in_1d_needle(needle_layout, haystack_layout):
        """Dispatch a one-dimensional membership test by layout kind."""
        if needle_layout.is_option or haystack_layout.is_option:
            return is_in_1d_needle_any_option(needle_layout, haystack_layout)
        elif needle_layout.is_unknown or haystack_layout.is_unknown:
            return is_in_1d_needle_any_unknown(needle_layout, haystack_layout)
        elif (needle_layout.parameter("__array__") in {"string", "bytestring"}) and (
            haystack_layout.parameter("__array__") in {"string", "bytestring"}
        ):
            return is_in_1d_needle_all_string(needle_layout, haystack_layout)
        elif needle_layout.is_numpy and haystack_layout.is_numpy:
            return is_in_1d_needle_all_numpy(needle_layout, haystack_layout)
        else:
            raise AssertionError

    # 1D array is not broadcast
    if haystack_layout.purelist_depth != 1:
        haystack_layout = ak.ravel(haystack_layout, highlevel=False)

    def apply(layout, **kwargs):
        # Returning None for deeper layouts lets ak.transform keep descending.
        if layout.purelist_depth == 1:
            return is_in_1d_needle(layout, haystack_layout)

    return ak.transform(apply, needle_layout)
# Import-time self-checks. Each row is (needle, haystack, expected result).
_IS_IN_CASES = (
    # Numeric needles against a numeric haystack.
    (
        [[1, 3, 4, 2], [4], [1], []],
        [1, 2],
        [[True, False, False, True], [False], [True], []],
    ),
    # An empty haystack matches nothing.
    (
        [[1, 3, 4, 2], [4], [1], []],
        [],
        [[False, False, False, False], [False], [False], []],
    ),
    # A missing haystack entry does not change non-missing needles.
    (
        [[1, 3, 4, 2], [4], [1], []],
        [None, 1, 2],
        [[True, False, False, True], [False], [True], []],
    ),
    # A missing sublist in the needle stays missing in the output.
    (
        [[1, 3, 4, 2], [4], [1], [], None],
        [1, 2],
        [[True, False, False, True], [False], [True], [], None],
    ),
    # A missing needle element is False when the haystack has no None.
    (
        [[1, 3, 4, 2], [4], [1], [], [None]],
        [1, 2],
        [[True, False, False, True], [False], [True], [], [False]],
    ),
    # A missing needle element is True when the haystack contains None.
    (
        [[1, 3, 4, 2], [4], [1], [None]],
        [None, 1, 2],
        [[True, False, False, True], [False], [True], [True]],
    ),
    # The same four situations for strings.
    (
        [["hi", "bye"]],
        ["this", "bye"],
        [[False, True]],
    ),
    (
        [["hi", "bye", None]],
        ["this", "bye"],
        [[False, True, False]],
    ),
    (
        [["hi", "bye"]],
        ["this", "bye", None],
        [[False, True]],
    ),
    (
        [["hi", "bye", None]],
        ["this", "bye", None],
        [[False, True, True]],
    ),
)

for _needle, _haystack, _expected in _IS_IN_CASES:
    assert ak.almost_equal(is_in(_needle, _haystack), _expected)
@agoose77
Copy link
Author

agoose77 commented Aug 2, 2023

I think _strings_in_strings_kernel should be faster than anything using NumPy's string kernels, because it can (IIRC) operate on non-contiguous string arrays. However, this implementation scans the haystack once per needle element, so its cost grows with the product of the needle and haystack lengths. If required, we could sort the haystack first and use a binary search instead.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment