Skip to content

Instantly share code, notes, and snippets.

@shaunc
Created August 12, 2023 23:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shaunc/d86854d74f4518935beb4d595bbac8e6 to your computer and use it in GitHub Desktop.
Save shaunc/d86854d74f4518935beb4d595bbac8e6 to your computer and use it in GitHub Desktop.
from collections import abc
from typing import Any, Iterator, cast
import numpy as np
from numba import cuda # type: ignore
TField = tuple[str, Any] | tuple[str, Any, Any]
_EXCLUDED_SEQ_TYPES = (str, bytes, bytearray, memoryview)
def named_tuples_to_records(
    *args: Any, records_to_device: bool = False, align: bool = False
) -> tuple[Any, ...]:
    """
    Replace every named tuple in ``args`` with its record equivalent.

    Any argument exposing ``_asdict`` is converted with ``to_record`` and the
    single converted element (index ``[0]``) takes its place. Every other
    argument is passed through unchanged (same object).

    Args:
        *args: positional arguments to convert.
        records_to_device: if True, each converted record is additionally
            passed through ``cuda.to_device``.
        align: forwarded to ``to_record``.

    Returns:
        A tuple with the same length and order as ``args``.
    """

    def convert(candidate: Any) -> Any:
        if not hasattr(candidate, "_asdict"):
            return candidate
        record = to_record(candidate, align=align)[0]
        if records_to_device:
            record = cuda.to_device(record)  # type: ignore
        return record

    return tuple(convert(candidate) for candidate in args)
def to_record(obj: Any, align: bool = False) -> np.record:
    """
    Convert an object to a single-element record array.

    The object's fields are copied into a one-element numpy record array.
    This is useful for passing complex objects to CUDA kernels. NamedTuples
    passed to kernels are passed by value, not by reference, so they are
    copied on every call, and their complexity is limited to 512 elements
    (including all nested elements). Record arrays are passed by reference,
    so they are not copied.

    Fields are taken, in order of precedence, from ``obj._asdict()`` (named
    tuples), from ``obj.__slots__``, or from ``vars(obj)``. A sequence
    ``obj`` (other than str/bytes/bytearray/memoryview) is instead converted
    directly to an array with a leading length-1 axis. Each field value is
    handled as follows:

    * sequences (other than the excluded types and named tuples) become
      sub-arrays;
    * numpy arrays become sub-array fields with the array's dtype and shape;
    * scalars become scalar fields (``str``/``bytes`` fields are sized to
      fit their content);
    * any other object is converted recursively into a nested record.

    Only a limited number of types are supported.

    Can be used like::

        rec = record.to_record(obj)

    Then, to send it for a one-way trip::

        rec_dev = cuda.to_device(rec)
        some_kernel[grid, block](rec_dev[0])

    or for a round trip::

        with cuda.pinned(rec):
            some_kernel[grid, block](rec[0])

    Args:
        obj: object to convert.
        align: if True, lay the fields out with ``fixup_dtype_alignment``
            and verify the result with ``is_aligned``.

    Returns:
        An array whose first axis has length 1; index ``[0]`` to get the
        converted element.

    Raises:
        RuntimeError: if a field holds an object with no fields (empty
            nested records are unsupported), or if ``align`` is True and
            the resulting record is not aligned.
    """
    items: abc.Iterable[tuple[str, Any]]
    if hasattr(obj, "_asdict"):
        # assume NamedTuple
        items = obj._asdict().items()  # type: ignore
    elif hasattr(obj, "__slots__"):
        slots = obj.__slots__  # type: ignore
        if isinstance(slots, str):
            # ``__slots__ = "name"`` declares a single slot; iterating the
            # string would yield one bogus slot per character
            slots = (slots,)
        items = ((name, getattr(obj, name)) for name in slots)
    elif isinstance(obj, abc.Sequence) and not isinstance(
        obj, _EXCLUDED_SEQ_TYPES
    ):
        # convert sequence to array and return directly, adding a leading
        # axis because callers index the result with [0]
        return _convert_seq_to_array(obj, align=align)[None]  # type: ignore
    else:
        items = vars(obj).items()

    fields: list[TField] = []
    values: list[Any] = []
    for name, value in items:
        if (
            isinstance(value, abc.Sequence)
            and not isinstance(value, _EXCLUDED_SEQ_TYPES)
            and not hasattr(value, "_asdict")
        ):
            # convert (non-NamedTuple) sequence to a sub-array
            value = _convert_seq_to_array(value, align=align)  # type: ignore
        if isinstance(value, np.ndarray):
            fields.append((name, value.dtype, value.shape))
        elif np.isscalar(value):
            fields.append((name, _scalar_field_type(value)))
        else:
            # any other object (including NamedTuples) becomes a nested record
            value = to_record(value, align=align)[0]
            if not value.dtype.fields:
                # TODO: empty sub-record is not supported by numba; also, an
                # empty value has shape (0,) while the dtype has shape (),
                # which causes problems in numpy
                raise RuntimeError(
                    "empty sub-record is not supported by numba, "
                    "and numpy has creation glitches"
                )
            fields.append((name, _unpositioned_field_specs(value.dtype)))
        values.append(value)

    if not values:
        # no fields at all: empty record array with the leading axis
        return np.rec.recarray(0, dtype=[])[None]  # type: ignore
    dtype = np.dtype(fields)
    if align:
        dtype = fixup_dtype_alignment(dtype)
    rec = cast(
        np.record,
        np.rec.fromarrays(values, dtype=dtype)[None],  # type: ignore
    )
    if align and not is_aligned(rec):
        raise RuntimeError("record is not aligned")
    return rec


def _scalar_field_type(value: Any) -> Any:
    """Return the field type ``to_record`` uses for a scalar value."""
    if isinstance(value, (str, bytes)):
        # np.dtype(str) / np.dtype(bytes) have length 0, which would silently
        # truncate the stored value to empty; size the field from the value.
        return np.asarray(value).dtype
    return type(value)


def _unpositioned_field_specs(dtype: np.dtype[Any]) -> "list[TField]":
    """Rebuild ``dtype``'s fields as a field-list spec without offsets."""
    specs: list[TField] = []
    for field_name in dtype.names or ():
        # fields[name] is (dtype, offset[, title]); keep only the dtype. The
        # full dtype (not its scalar ``.type``) is kept so that sized types
        # such as '<U3' and nested structures keep their item size.
        field_dtype = dtype.fields[field_name][0]  # type: ignore
        if field_dtype.subdtype is None:
            specs.append((field_name, field_dtype))
        else:
            base, shape = field_dtype.subdtype
            specs.append((field_name, base, shape))
    return specs
def _convert_seq_to_array(
    seq: abc.Sequence[Any], align: bool = False
) -> np.ndarray[Any, Any]:
    """
    Convert a sequence to a numpy array.

    A sequence whose elements are all scalars (per ``np.isscalar``) is
    converted directly with ``np.asarray``. Otherwise each element is
    converted with ``to_record`` and the results are stacked. All elements
    are expected to produce the same dtype and shape. An empty sequence
    yields an empty ``int32`` array.

    Args:
        seq: sequence to convert.
        align: forwarded to ``to_record`` for non-scalar elements.

    Returns:
        Array of shape ``(len(seq),) + element_shape``.
    """
    if len(seq) == 0:
        return np.zeros(0, dtype=np.int32)
    if all(np.isscalar(item) for item in seq):
        # to_record() falls back to vars() for values without _asdict or
        # __slots__, which fails for plain ints/floats, so convert directly.
        return np.asarray(seq)
    sub_values = [to_record(item, align=align)[0] for item in seq]
    first = sub_values[0]
    shape = (len(sub_values),) + first.shape
    return np.array(sub_values, dtype=first.dtype).reshape(shape)
def is_aligned(rec: np.record) -> bool:
    """
    Check alignment of a (possibly nested) record array.

    Non-structured arrays are checked with numpy's ``ALIGNED`` flag. For
    structured arrays, every field offset must be a multiple of the field's
    required alignment, and each field must itself be aligned, recursively.

    The required alignment of a non-structured field is its item size capped
    at 8 bytes (at least 1). That of a structured field is the largest
    requirement among its members, as for a C struct, not its total size.

    Args:
        rec: record array to check.

    Returns:
        True if the record is aligned.
    """
    names = rec.dtype.names
    if names is None:
        return bool(rec.flags["ALIGNED"])
    for name in names:
        offset = rec.dtype.fields[name][1]  # type: ignore
        field = rec[name]
        if offset % _required_alignment(field.dtype) != 0:
            return False
        if not is_aligned(field):
            return False
    return True


def _required_alignment(dtype: np.dtype[Any]) -> int:
    """Return the byte alignment ``is_aligned`` requires for ``dtype``."""
    if dtype.subdtype is not None:
        return _required_alignment(dtype.subdtype[0])
    if dtype.names is not None:
        # a structure is aligned like its most demanding member
        return max(
            (_required_alignment(dtype.fields[n][0]) for n in dtype.names),
            default=1,
        )
    # natural alignment capped at 8 bytes; zero-size types need none
    return max(1, min(dtype.itemsize, 8))
def fixup_dtype_alignment(dtype: np.dtype[Any]) -> np.dtype[Any]:
    """
    Return ``dtype`` with its fields laid out on aligned offsets.

    Public wrapper around ``_fixup_dtype_alignment``. It discards the
    element size that the helper reports alongside the rebuilt dtype.
    This is a workaround for https://github.com/numpy/numpy/issues/24339.

    Args:
        dtype: dtype to fix up.

    Returns:
        The dtype with alignment fixed up.
    """
    fixed_dtype, _element_size = _fixup_dtype_alignment(dtype)
    return fixed_dtype
def _fixup_dtype_alignment(dtype: np.dtype[Any]) -> tuple[np.dtype[Any], int]:
    """
    Fix up alignment of ``dtype`` recursively.

    Each field is placed at the first offset, after the previous field,
    that is a multiple of the field's alignment unit. The total item size is
    then padded to a multiple of the largest unit. This keeps every
    element's contents aligned given what precedes it in the structure.

    A field's unit is the value this function returns for it: a scalar's
    item size, a sub-array's base unit, or a nested structure's largest
    member unit. That value is raised to at least numpy's
    ``dtype.alignment``. The lower bound leaves ordinary numeric and nested
    fields unchanged, but it keeps zero-size fields (such as ``'<U0'``) from
    causing a division by zero.

    Args:
        dtype: dtype to fix up.

    Returns:
        The fixed dtype and its alignment unit (0 for a structure without
        fields). Structures of type ``np.record`` stay record-typed.
    """
    if dtype.names is None:
        if dtype.subdtype is not None:
            # sub-array: realign the base dtype, keep the shape
            base, shape = dtype.subdtype[:2]
            base, unit = _fixup_dtype_alignment(base)
            return np.dtype((base, shape)), unit
        return dtype, dtype.itemsize
    if not dtype.names:
        # a structure without fields has nothing to lay out
        return dtype, 0

    names: list[str] = []
    formats: list[Any] = []
    offsets: list[int] = []
    max_unit = 0
    offset = 0
    for name in dtype.names:
        field_dtype, size = _fixup_dtype_alignment(
            cast(np.dtype[Any], dtype.fields[name][0])  # type: ignore
        )
        unit = max(size, field_dtype.alignment)
        max_unit = max(max_unit, unit)
        offset += -offset % unit  # pad up to the next multiple of unit
        names.append(name)
        formats.append(field_dtype)
        offsets.append(offset)
        offset += field_dtype.itemsize
    # pad the total size so consecutive elements stay aligned
    offset += -offset % max_unit

    spec = dict(names=names, formats=formats, offsets=offsets, itemsize=offset)
    if dtype.type == np.record:
        return np.dtype((np.record, spec), align=True), max_unit
    return np.dtype(spec, align=True), max_unit
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment