Last active: August 2, 2023, 11:12
Save agoose77/313bf6e24146c9fd0f92b160b6cb0e5e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import awkward as ak | |
import numpy as np | |
import numba as nb | |
@nb.njit()
def _strings_in_strings_kernel(needle, haystack):
    """Return a boolean mask marking which ``needle`` strings occur in ``haystack``.

    Compiled with Numba. Each element of ``needle`` is compared with ``==``
    against the elements of ``haystack`` until a match is found, so the
    worst-case cost is O(len(needle) * len(haystack)).

    Parameters
    ----------
    needle, haystack : ak.Array
        One-dimensional string (or bytestring) arrays, as passed by ``is_in``.

    Returns
    -------
    np.ndarray of np.bool_
        ``out[k]`` is True iff ``needle[k]`` equals some element of ``haystack``.
    """
    out = np.empty(len(needle), np.bool_)
    for position, candidate in enumerate(needle):
        found = False
        for other in haystack:
            if candidate == other:
                found = True
                break
        out[position] = found
    return out
def _is_in_1d_needle_all_string(needle_layout, haystack_layout):
    """Membership test for two 1-D (byte)string layouts, using the Numba kernel."""
    out = _strings_in_strings_kernel(
        ak.Array(needle_layout), ak.Array(haystack_layout)
    )
    return ak.to_layout(out)


def _is_in_1d_needle_any_option(needle_layout, haystack_layout):
    """Membership test when at least one of the two 1-D layouts is option-typed.

    Present needle values are matched against present haystack values.
    A missing needle value counts as found only if the haystack also
    contains a missing value.
    """
    if needle_layout.is_option and not haystack_layout.is_option:
        # Test only the present needle values; missing positions stay False.
        projected = _is_in_1d_needle(needle_layout.project(), haystack_layout)
        result = np.full(needle_layout.length, False)
        result[needle_layout.mask_as_bool(True)] = projected.to_backend_array()
        return ak.contents.NumpyArray(result)
    elif not needle_layout.is_option and haystack_layout.is_option:
        # Missing haystack entries can never equal a present needle value.
        return _is_in_1d_needle(needle_layout, haystack_layout.project())
    else:
        # Both sides are option-typed. First match the present values; this
        # recurses into the branch above, so missing needle entries are False.
        tmp_result = _is_in_1d_needle(needle_layout, haystack_layout.project())
        assert isinstance(tmp_result, ak.contents.NumpyArray)
        # Then, if the haystack itself holds at least one missing entry, mark
        # the missing needle entries as found. mask_as_bool(False) is True
        # exactly where an entry is missing (the original checked
        # mask_as_bool(True), i.e. "has any valid entry", which is inverted).
        if np.any(haystack_layout.mask_as_bool(False)):
            tmp_result = ak.to_layout(
                np.logical_or(
                    tmp_result.to_backend_array(), needle_layout.mask_as_bool(False)
                )
            )
        return tmp_result


def _is_in_1d_needle_any_unknown(needle_layout, haystack_layout):
    """An unknown-typed layout (e.g. from ``[]``) on either side: nothing matches."""
    return ak.to_layout(np.full(len(needle_layout), False))


def _is_in_1d_needle_all_numpy(needle_layout, haystack_layout):
    """Membership test for two 1-D NumpyArray layouts using ``np.isin``."""
    return ak.to_layout(
        np.isin(needle_layout.to_backend_array(), haystack_layout.to_backend_array())
    )


def _is_in_1d_needle(needle_layout, haystack_layout):
    """Dispatch a 1-D membership test on the kinds of the two layouts.

    Raises
    ------
    AssertionError
        If the combination of layout kinds is not supported.
    """
    # Option handling must come first so the other helpers see plain content.
    if needle_layout.is_option or haystack_layout.is_option:
        return _is_in_1d_needle_any_option(needle_layout, haystack_layout)
    elif needle_layout.is_unknown or haystack_layout.is_unknown:
        return _is_in_1d_needle_any_unknown(needle_layout, haystack_layout)
    elif needle_layout.parameter("__array__") in {
        "string",
        "bytestring",
    } and haystack_layout.parameter("__array__") in {"string", "bytestring"}:
        return _is_in_1d_needle_all_string(needle_layout, haystack_layout)
    elif needle_layout.is_numpy and haystack_layout.is_numpy:
        return _is_in_1d_needle_all_numpy(needle_layout, haystack_layout)
    else:
        raise AssertionError(
            "is_in: unsupported layout combination "
            f"(needle: {type(needle_layout).__name__}, "
            f"haystack: {type(haystack_layout).__name__})"
        )


def is_in(needle, haystack):
    """Return, elementwise, whether each innermost value of ``needle`` is in ``haystack``.

    ``needle`` may be nested. The result is built with ``ak.transform``, so it
    keeps the needle's list structure (including missing lists) and replaces
    each innermost 1-D layout with booleans. ``haystack`` is flattened to a
    single dimension first, so it is not broadcast against ``needle``.
    A missing needle value counts as found only if the haystack also contains
    a missing value.

    Parameters
    ----------
    needle : array-like
        Anything ``ak.to_layout`` accepts. A single record is rejected.
    haystack : array-like
        Anything ``ak.to_layout`` accepts, except a record.

    Returns
    -------
    ak.Array
        Boolean array with the structure of ``needle``.

    Raises
    ------
    NotImplementedError
        If ``needle`` is an ``ak.record.Record``.
    AssertionError
        If the innermost needle/haystack layouts are an unsupported combination.
    """
    needle_layout = ak.to_layout(needle, allow_record=True, allow_other=False)
    haystack_layout = ak.to_layout(haystack, allow_record=False, allow_other=False)
    if isinstance(needle_layout, ak.record.Record):
        raise NotImplementedError("does not support ak.record.Record objects")

    # The haystack is treated as a flat 1-D set of values (it is not
    # broadcast), so remove any nesting up front.
    if haystack_layout.purelist_depth != 1:
        haystack_layout = ak.ravel(haystack_layout, highlevel=False)

    def apply(layout, **kwargs):
        # Replace only innermost (depth-1) layouts; returning None lets
        # ak.transform keep descending through outer lists and options.
        if layout.purelist_depth == 1:
            return _is_in_1d_needle(layout, haystack_layout)

    return ak.transform(apply, needle_layout)
# Worked examples, checked at import time, as (needle, haystack, expected)
# triples. They cover plain numbers, an empty haystack, missing values on
# either or both sides, and strings.
_IS_IN_EXAMPLES = [
    (
        [[1, 3, 4, 2], [4], [1], []],
        [1, 2],
        [[True, False, False, True], [False], [True], []],
    ),
    (
        [[1, 3, 4, 2], [4], [1], []],
        [],
        [[False, False, False, False], [False], [False], []],
    ),
    (
        [[1, 3, 4, 2], [4], [1], []],
        [None, 1, 2],
        [[True, False, False, True], [False], [True], []],
    ),
    (
        [[1, 3, 4, 2], [4], [1], [], None],
        [1, 2],
        [[True, False, False, True], [False], [True], [], None],
    ),
    (
        [[1, 3, 4, 2], [4], [1], [], [None]],
        [1, 2],
        [[True, False, False, True], [False], [True], [], [False]],
    ),
    (
        [[1, 3, 4, 2], [4], [1], [None]],
        [None, 1, 2],
        [[True, False, False, True], [False], [True], [True]],
    ),
    (
        [["hi", "bye"]],
        ["this", "bye"],
        [[False, True]],
    ),
    (
        [["hi", "bye", None]],
        ["this", "bye"],
        [[False, True, False]],
    ),
    (
        [["hi", "bye"]],
        ["this", "bye", None],
        [[False, True]],
    ),
    (
        [["hi", "bye", None]],
        ["this", "bye", None],
        [[False, True, True]],
    ),
]

for _needle, _haystack, _expected in _IS_IN_EXAMPLES:
    assert ak.almost_equal(is_in(_needle, _haystack), _expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
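I think `_strings_in_strings_kernel` should be faster than anything using NumPy's string kernels, because (IIRC) it can operate on non-contiguous string arrays. However, this implementation scans the whole haystack for every needle element, so its cost grows as len(needle) × len(haystack) rather than linearly in either. If required, we could sort the haystack once and use a binary search for each needle element instead.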