Skip to content

Instantly share code, notes, and snippets.

@agoose77
Last active August 7, 2023 13:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save agoose77/28e5bb0250678e454356a85861a16368 to your computer and use it in GitHub Desktop.
Save agoose77/28e5bb0250678e454356a85861a16368 to your computer and use it in GitHub Desktop.
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
__all__ = ("join",)
import awkward as ak
from awkward._behavior import behavior_of
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
@high_level_function
def join(array, separator, *, highlevel=True, behavior=None):
    """
    Args:
        array: Array-like data (anything #ak.to_layout recognizes).
        separator (Array-like data (anything #ak.to_layout recognizes), str,
            or bytes): The text placed between neighbouring strings. When an
            array-like is given, it is broadcast against `array`, so every
            list of strings may have its own separator.
        highlevel (bool): If True, return an #ak.Array; otherwise, return
            a low-level #ak.contents.Content subclass.
        behavior (None or dict): Custom #ak.behavior for the output array, if
            high-level.

    Joins every list of strings in `array` into a single string, placing
    `separator` between consecutive elements.

    Note: no error is raised when `array` holds no string or bytestring data.

    Requires the pyarrow library and calls
    [pyarrow.compute.binary_join](https://arrow.apache.org/docs/python/generated/pyarrow.compute.binary_join.html).
    """
    # Dispatch: hand the array-like arguments to the dispatch machinery first.
    yield (array, separator)

    # Implementation: delegate the real work to the private helper.
    return _impl(array, separator, highlevel=highlevel, behavior=behavior)
def _is_maybe_optional_list_of_string(layout):
    """Report whether `layout` is a string/bytestring list, possibly wrapped.

    Option and indexed layers are peeled off one at a time (via `.content`)
    until either a list tagged with the "string" or "bytestring" `__array__`
    parameter is found (True) or a layer that is neither option nor indexed
    is reached (False).
    """
    node = layout
    while True:
        if node.is_list and node.parameter("__array__") in {"string", "bytestring"}:
            return True
        if not (node.is_option or node.is_indexed):
            return False
        # Look through the option/indexed wrapper at what it contains.
        node = node.content
def _impl(array, separator, highlevel, behavior):
    """Join each list of strings in `array`, inserting `separator` between them.

    Args:
        array: Array-like data converted with #ak.to_layout.
        separator: Either a `str`/`bytes` scalar applied to every list, or
            array-like data that is broadcast against `array`.
        highlevel (bool): Forwarded to `wrap_layout` to choose between a
            high-level array and a low-level layout.
        behavior (None or dict): Custom behavior merged with that of the
            inputs via `behavior_of`.

    Returns:
        The result of `wrap_layout` applied to the joined layout.

    Raises:
        TypeError: If an array-like `separator` is, after broadcasting, not a
            (maybe option/indexed-wrapped) string or bytestring.
        ImportError: If pyarrow is not installed.
    """
    # `pc` was previously referenced without ever being defined in this
    # module; import it here so the function fails only if pyarrow itself is
    # missing rather than with a NameError.
    import pyarrow.compute as pc

    def is_list_of_strings(layout):
        # A list whose items are (maybe option/indexed type wrapping) strings.
        return (
            layout.is_list
            and layout.purelist_depth == 2
            and _is_maybe_optional_list_of_string(layout.content)
        )

    def to_arrow(layout):
        return ak.to_arrow(
            layout,
            extensionarray=False,
            # This kernel requires non-large string/bytestrings
            string_to32=True,
            bytestring_to32=True,
        )

    def lists_to_arrow(layout):
        # Arrow needs an option type here
        return to_arrow(
            layout.copy(content=ak.contents.UnmaskedArray.simplified(layout.content))
        )

    def apply_unary(layout, **kwargs):
        # Returning None tells the traversal to keep descending.
        if not is_list_of_strings(layout):
            return None

        return ak.from_arrow(
            pc.binary_join(lists_to_arrow(layout), separator),
            highlevel=False,
        )

    def apply_binary(layouts, **kwargs):
        layout, separator_layout = layouts
        # Returning None tells the broadcast to keep descending.
        if not is_list_of_strings(layout):
            return None

        if not _is_maybe_optional_list_of_string(separator_layout):
            raise TypeError(
                f"separator must be a list of strings, not {type(separator_layout)}"
            )

        return (
            ak.from_arrow(
                pc.binary_join(lists_to_arrow(layout), to_arrow(separator_layout)),
                highlevel=False,
            ),
        )

    layout = ak.to_layout(array, allow_record=False, allow_other=True)
    behavior = behavior_of(array, separator, behavior=behavior)

    if isinstance(separator, (bytes, str)):
        # Scalar separator: a single pass over `array` is enough.
        out = ak._do.recursively_apply(layout, apply_unary, behavior=behavior)
    else:
        # Array-like separator: broadcast it against `array` first.
        separator_layout = ak.to_layout(separator, allow_record=False, allow_other=True)
        (out,) = ak._broadcasting.broadcast_and_apply(
            (layout, separator_layout), apply_binary, behavior
        )

    return wrap_layout(out, highlevel=highlevel, behavior=behavior)
# Demonstration: join every list of strings with one scalar separator.
array = ak.Array(
    [
        ["this", "that"],
        ["foo", "bar", "baz"],
    ]
)
result = join(array, "-")
# Print the value computed above instead of calling join() a second time.
print(result)

# Demonstration: an array of separators gives each list its own separator.
separator = ak.Array(["→", "OO"])
print(join(array, separator))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment