Created
August 12, 2023 23:49
-
-
Save shaunc/d86854d74f4518935beb4d595bbac8e6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import abc
from typing import Any, Iterator, cast

import numpy as np
from numba import cuda  # type: ignore
TField = tuple[str, Any] | tuple[str, Any, Any] | |
_EXCLUDED_SEQ_TYPES = (str, bytes, bytearray, memoryview) | |
def named_tuples_to_records(
    *args: Any, records_to_device: bool = False, align: bool = False
) -> tuple[Any, ...]:
    """
    Convert named tuples in args to records.

    Every argument that exposes ``_asdict`` (i.e. looks like a named tuple)
    is replaced by the first element of ``to_record(arg, align=align)``.
    All other arguments are passed through untouched, in their original
    positions.

    Args:
        *args: arbitrary positional arguments.
        records_to_device: if True, each converted record is additionally
            copied to the device with ``cuda.to_device``.
        align: forwarded to ``to_record``.

    Returns:
        arg list (as a tuple) with named tuples converted to records.
    """
    converted: list[Any] = []
    for arg in args:
        if hasattr(arg, "_asdict"):
            arg = to_record(arg, align=align)[0]
            if records_to_device:
                arg = cuda.to_device(arg)  # type: ignore
        converted.append(arg)
    return tuple(converted)
def to_record(obj: Any, align: bool = False) -> np.record:
    """
    Convert object to record array.

    We copy an object to a single-element record array. This is useful for
    passing complex objects to CUDA kernels. NamedTuples passed to kernels
    are passed by value, not by reference, so they are copied on every call,
    and their complexity is limited to 512 elements (including all nested
    elements). Record arrays are passed by reference, so they are not copied.

    Currently a limited number of types are supported: named tuples
    (anything with ``_asdict``), objects with ``__slots__``, sequences
    (other than str/bytes/bytearray/memoryview) and objects accepted by
    ``vars()``. Field values may be numpy arrays, scalars (per
    ``np.isscalar``), sequences, or nested objects of the kinds above.

    Can be used like:
    ```
    rec = record.to_record(obj)
    ```
    Then to send, for a one-way trip:
    ```
    rec_dev = cuda.to_device(rec)
    some_kernel[grid, block](rec_dev[0])
    ```
    or for round trip:
    ```
    with cuda.pinned(rec):
        some_kernel[grid, block](rec[0])
    ```

    Args:
        obj: object to convert
        align: if True, the dtype is passed through
            ``fixup_dtype_alignment`` and the result is verified with
            ``is_aligned``.

    Returns:
        record array with a leading axis of length 1 (the record itself is
        element ``[0]``). For a sequence ``obj`` the converted array is
        returned with a leading axis added instead.

    Raises:
        RuntimeError: if a nested object converts to an empty sub-record,
            or if ``align`` is True and the result is not aligned.
    """
    fields: list[TField] = []
    values: list[Any] = []

    if hasattr(obj, "_asdict"):  # type: ignore
        # assume NamedTuple
        items = cast(
            Iterator[tuple[str, Any]], obj._asdict().items()  # type: ignore
        )
    elif hasattr(obj, "__slots__"):  # type: ignore
        items = cast(
            Iterator[tuple[str, Any]],
            ((name, getattr(obj, name)) for name in obj.__slots__),
        )
    elif isinstance(obj, abc.Sequence) and not isinstance(  # type: ignore
        obj, _EXCLUDED_SEQ_TYPES
    ):
        # convert sequence to array and return directly, adding dimension
        # as caller expects a record
        value = _convert_seq_to_array(obj, align=align)  # type: ignore
        return value[None]  # type: ignore
    else:
        items = cast(
            Iterator[tuple[str, Any]], vars(obj).items()  # type: ignore
        )

    for name, value in items:
        if (
            isinstance(value, abc.Sequence)
            and not isinstance(value, _EXCLUDED_SEQ_TYPES)
            and not hasattr(value, "_asdict")  # type: ignore
        ):
            # convert sequence to sub-array
            value = _convert_seq_to_array(value, align=align)  # type: ignore

        if isinstance(value, np.ndarray):
            fields.append((name, value.dtype, value.shape))  # type: ignore
            values.append(value)
        elif np.isscalar(value):  # type: ignore
            fields.append((name, type(value)))
            values.append(value)
        else:
            # Anything else is treated as a nested object and converted
            # recursively; its record becomes a nested struct field.
            value = to_record(value, align=align)[0]
            sub_items = value.dtype.fields.items()
            if len(sub_items) == 0:
                # TODO: empty sub-record is not supported by numba
                # also, an empty value has shape (0,) while the dtype
                # has shape (), which causes problems in numpy
                raise RuntimeError(
                    "empty sub-record is not supported by numba, "
                    "and numpy has creation glitches"
                )
            sub_fields: list[TField] = []
            for sub_name, p_sub_dtype in sub_items:
                # in field list, 2nd element is dtype w/ position info
                # we want dtype w/o position info
                sub_dtype = p_sub_dtype[0]
                if sub_dtype.shape == ():
                    sub_type = (
                        sub_dtype.type
                        if sub_dtype.type != np.void
                        else sub_dtype
                    )
                    sub_fields.append((sub_name, sub_type))
                else:
                    sub_type = sub_dtype.subdtype[0].type
                    if sub_type == np.record:
                        sub_type = sub_dtype.subdtype[0]
                    sub_fields.append((sub_name, sub_type, sub_dtype.shape))
            fields.append((name, sub_fields))
            values.append(value)

    if len(values) == 0:
        return np.rec.recarray(0, dtype=[])[None]  # type: ignore

    dtype = np.dtype(fields)
    if align:
        dtype = fixup_dtype_alignment(dtype)
    rec = cast(
        np.record,
        np.rec.fromarrays(values, dtype=dtype)[None],  # type: ignore
    )
    if align and not is_aligned(rec):
        raise RuntimeError("record is not aligned")
    return rec
def _convert_seq_to_array(
    seq: abc.Sequence[Any], align: bool = False
) -> np.ndarray[Any, Any]:
    """
    Convert sequence to array.

    Each element is converted with ``to_record`` and the resulting records
    are stacked along a new leading axis. The dtype and trailing shape are
    taken from the first converted element.

    Args:
        seq: sequence to convert
        align: forwarded to ``to_record`` for every element

    Returns:
        array of shape ``(len(seq),) + element_shape``; for an empty
        sequence, an empty ``int32`` array of shape ``(0,)``.
    """
    records = [to_record(item, align=align)[0] for item in seq]  # type: ignore
    if not records:
        # No element to take a dtype from, so use a placeholder int32 dtype.
        return np.zeros(0, dtype=np.int32)
    first = records[0]
    stacked_shape = (len(records),) + first.shape
    return np.array(records, dtype=first.dtype).reshape(stacked_shape)
def is_aligned(rec: np.record) -> bool:
    """
    Check alignment of record.

    For a non-structured value the numpy ``ALIGNED`` flag is returned.
    For a structured value every field offset must be a multiple of the
    field's item size (capped at 8 bytes), and every field must itself be
    aligned (checked recursively).

    Zero-sized fields (for example a ``str``-typed field, which numpy
    stores as ``<U0``) have no item size to divide by; for those the
    dtype's own ``alignment`` is used instead of raising
    ``ZeroDivisionError``.

    Args:
        rec: record to check

    Returns: True if record is aligned
    """
    names = rec.dtype.names
    if names is None:
        return bool(rec.flags["ALIGNED"])
    for name in names:
        offset = rec.dtype.fields[name][1]  # type: ignore
        field = rec[name]
        # ``or`` kicks in only when the item size is 0.
        unit = min(field.dtype.itemsize, 8) or field.dtype.alignment
        if offset % unit != 0:
            return False
        if not is_aligned(field):
            return False
    return True
def fixup_dtype_alignment(dtype: np.dtype[Any]) -> np.dtype[Any]:
    """
    Fix up alignment of dtype.

    A workaround for: https://github.com/numpy/numpy/issues/24339

    Args:
        dtype: dtype to fix up

    Returns: dtype with alignment fixed up
    """
    fixed, _ = _fixup_dtype_alignment(dtype)
    return fixed


def _fixup_dtype_alignment(dtype: np.dtype[Any]) -> tuple[np.dtype[Any], int]:
    """
    Fix up alignment of dtype recursively.

    Aligns each element so that its contents are aligned, given what
    is before it in the structure. Each field is placed at the next offset
    that is a multiple of its alignment unit, and the total item size is
    padded to a multiple of the largest unit in the structure.

    The alignment unit of a plain dtype is its item size; for a sub-array
    dtype it is the unit of the base dtype; for a structure it is the
    largest unit among its fields. Zero-sized dtypes (e.g. ``<U0``) fall
    back to ``dtype.alignment`` and an empty structure is returned as-is
    with unit 1, so the unit is never 0 and the modulo arithmetic below
    cannot divide by zero.

    Args:
        dtype: dtype to fix up

    Returns fixed dtype and largest individual element size.
    """
    if dtype.names is None:
        if dtype.subdtype is not None:
            # Sub-array: fix the base dtype and rebuild with the same shape.
            base, shape = dtype.subdtype[:2]
            base, unit = _fixup_dtype_alignment(base)
            return np.dtype((base, shape)), unit
        return dtype, dtype.itemsize or dtype.alignment
    if not dtype.names:
        # Empty structure: nothing to lay out.
        return dtype, 1

    names: list[str] = []
    formats: list[Any] = []
    offsets: list[int] = []
    element_size = 0
    offset = 0
    for name in dtype.names:
        field = dtype.fields[name]  # type: ignore
        sub_dtype = cast(np.dtype[Any], field[0])
        sub_dtype, size = _fixup_dtype_alignment(sub_dtype)
        element_size = max(element_size, size)
        # Round the running offset up to the field's alignment unit.
        if offset % size != 0:
            offset += size - offset % size
        names.append(name)
        formats.append(sub_dtype)
        offsets.append(offset)
        offset += sub_dtype.itemsize
    # Pad the total size so consecutive records stay aligned.
    if offset % element_size != 0:
        offset += element_size - offset % element_size

    spec = dict(names=names, formats=formats, offsets=offsets, itemsize=offset)
    if dtype.type == np.record:
        # Preserve the record scalar type of the incoming dtype.
        new_dtype: np.dtype[Any] = np.dtype((np.record, spec), align=True)
    else:
        new_dtype = np.dtype(spec, align=True)
    return (new_dtype, element_size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment