Skip to content

Instantly share code, notes, and snippets.

@AlenkaF
Last active January 26, 2023 08:56
Show Gist options
  • Save AlenkaF/95fb41f461fb792396bb20dd502b4112 to your computer and use it in GitHub Desktop.
Save AlenkaF/95fb41f461fb792396bb20dd502b4112 to your computer and use it in GitHub Desktop.
Example of tensor extension with tests in PyArrow
import ast
import json
import math
import numpy as np
import pyarrow as pa
class TensorType(pa.ExtensionType):
    """Arrow extension type ``arrow.tensor`` for fixed-shape tensors.

    Each element is one tensor stored as a fixed-size list of
    ``prod(shape)`` values of ``value_type``. The shape and the layout of
    the flattened values (``order``) are written to the type's serialized
    metadata so they can be restored when the type is deserialized.

    Parameters
    ----------
    value_type : pa.DataType
        Arrow type of the individual tensor elements.
    shape : sequence of int
        Shape of every tensor; normalized to a tuple of Python ints.
    order : {'C', 'F'}
        Layout of the flattened values: 'C' (row-major) or 'F'
        (column-major).

    Raises
    ------
    ValueError
        If ``order`` is not 'C' or 'F'.
    """

    def __init__(self, value_type, shape, order):
        if order not in ('C', 'F'):
            raise ValueError(f"order must be 'C' or 'F', got {order!r}")
        self._value_type = value_type
        # Normalize to a tuple of plain ints: a list would compare unequal to
        # the tuple read back from metadata, and NumPy integer scalars can
        # repr as e.g. ``np.int64(2)``, which ast.literal_eval cannot parse.
        self._shape = tuple(int(dim) for dim in shape)
        self._order = order
        size = math.prod(self._shape)
        pa.ExtensionType.__init__(self, pa.list_(self._value_type, size),
                                  'arrow.tensor')

    @property
    def dtype(self):
        """Arrow type of the tensor elements."""
        return self._value_type

    @property
    def shape(self):
        """Shape of each tensor, as a tuple of ints."""
        return self._shape

    @property
    def order(self):
        """
        Memory layout of the flattened values: 'C' (row-major) or
        'F' (column-major).
        """
        return self._order

    def __arrow_ext_serialize__(self):
        """Encode shape and order as UTF-8 JSON metadata bytes."""
        metadata = {"shape": str(self._shape),
                    "order": self._order}
        return json.dumps(metadata).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild a tensor type from its storage type and metadata bytes.

        Raises
        ------
        ValueError
            If ``serialized`` is not the JSON produced by
            ``__arrow_ext_serialize__``.
        """
        # Validate explicitly (not with ``assert``, which ``-O`` strips) and
        # without depending on the key order inside the JSON text.
        try:
            metadata = json.loads(serialized.decode())
            shape = ast.literal_eval(metadata["shape"])
            order = metadata["order"]
        except (ValueError, KeyError, TypeError, SyntaxError) as err:
            raise ValueError(
                f"invalid arrow.tensor metadata: {serialized!r}") from err
        # The storage is a fixed-size list; its value_type is the element type.
        return cls(storage_type.value_type, shape, order)

    def __arrow_ext_class__(self):
        """Return the ExtensionArray subclass used for arrays of this type."""
        return TensorArray
class TensorArray(pa.ExtensionArray):
    """Extension array for ``TensorType`` with NumPy conversion helpers."""

    def to_numpy_tensor(self):
        """Return the stored values as a NumPy array of ``self.type.shape``.

        All storage values are flattened and reshaped to the single tensor
        shape using the type's recorded ``order``. The reshape therefore
        only succeeds when the array holds exactly ``prod(shape)`` values,
        i.e. one tensor.
        """
        flat_array = self.storage.flatten().to_numpy()
        return flat_array.reshape(self.type.shape, order=self.type.order)

    @staticmethod
    def from_numpy_tensor(obj):
        """Wrap a NumPy array as a one-element tensor extension array.

        Parameters
        ----------
        obj : array_like
            The tensor. Its shape becomes the type's shape, and its memory
            layout ('F' if Fortran-ordered, else 'C') becomes the type's
            order.

        Returns
        -------
        pa.ExtensionArray
            A length-1 array of ``TensorType`` whose single element holds
            the flattened tensor.
        """
        obj = np.asarray(obj)
        arrow_type = pa.from_numpy_dtype(obj.dtype)
        order = 'F' if np.isfortran(obj) else 'C'
        # Flatten in the same order that is recorded in the type, so that
        # to_numpy_tensor's reshape(..., order=order) restores ``obj``.
        # ``ndarray.flatten()`` defaults to 'C' order, which would scramble
        # Fortran-ordered input.
        flat = obj.flatten(order=order)
        return pa.ExtensionArray.from_storage(
            TensorType(arrow_type, obj.shape, order),
            pa.array([flat], pa.list_(arrow_type, obj.size)),
        )
import numpy as np
import pyarrow as pa
import pytest
@pytest.fixture
def registered_tensor_type():
    """Register an ``arrow.tensor`` type for one test, then unregister it.

    Yields a ``(tensor_type, tensor_class)`` pair: an int8 TensorType of
    shape (2, 2, 3) in 'C' order, and the array class that type reports.
    """
    ext_type = TensorType(pa.int8(), (2, 2, 3), 'C')
    ext_array_cls = ext_type.__arrow_ext_class__()
    pa.register_extension_type(ext_type)

    yield ext_type, ext_array_cls

    # Teardown: a KeyError here means the name is no longer registered,
    # so there is nothing left to clean up.
    try:
        pa.unregister_extension_type('arrow.tensor')
    except KeyError:
        pass
def test_generic_ext_type():
    """A TensorType has the 'arrow.tensor' name and fixed-size-list storage."""
    rows, cols = 2, 3
    ext_type = TensorType(pa.int8(), (rows, cols), 'C')
    assert ext_type.extension_name == "arrow.tensor"
    assert ext_type.storage_type == pa.list_(pa.int8(), rows * cols)
def test_tensor_ext_class_methods():
    """Convert a TensorArray to NumPy and build one back from NumPy."""
    flat_values = [1, 2, 3, 4, 5, 6] * 2
    float_tensor_type = TensorType(pa.float32(), (2, 2, 3), 'C')
    storage = pa.array([flat_values], pa.list_(pa.float32(), 12))
    ext_arr = pa.ExtensionArray.from_storage(float_tensor_type, storage)

    block_2x3 = [[1, 2, 3], [4, 5, 6]]
    expected = np.array([block_2x3, block_2x3], dtype=np.float32)

    # Extension array -> NumPy tensor.
    np.testing.assert_array_equal(ext_arr.to_numpy_tensor(), expected)

    # NumPy tensor -> extension array: the type mirrors the source array.
    rebuilt_type = TensorArray.from_numpy_tensor(expected).type
    assert isinstance(rebuilt_type, TensorType)
    assert rebuilt_type.dtype == pa.float32()
    assert rebuilt_type.shape == (2, 2, 3)
    assert rebuilt_type.order == 'C'
def ipc_write_batch(batch):
    """Serialize ``batch`` into an in-memory Arrow IPC stream.

    Parameters
    ----------
    batch : pa.RecordBatch
        The batch to write; its schema is used as the stream schema.

    Returns
    -------
    pa.Buffer
        The bytes of the complete IPC stream.
    """
    sink = pa.BufferOutputStream()
    # The context manager closes the writer even if write_batch raises,
    # so the writer is never left open.
    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
        writer.write_batch(batch)
    # Read the buffer only after the writer has been closed.
    return sink.getvalue()
def ipc_read_batch(buf):
    """Decode and return the first record batch of the IPC stream in ``buf``."""
    stream_reader = pa.RecordBatchStreamReader(buf)
    first_batch = stream_reader.read_next_batch()
    return first_batch
def test_generic_ext_type_ipc(registered_tensor_type):
    """Tensor extension arrays survive an IPC round trip.

    The round trip is checked twice: once with the registered type, and once
    with a TensorType built with different parameters (uint8, shape (2, 3)).
    Each time, the deserialized array must have the expected type
    parameters, the expected array class, and the original data.
    """
    tensor_type, tensor_class = registered_tensor_type
    values = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]
    storage = pa.array([values], pa.list_(pa.int8(), 12))
    arr = pa.ExtensionArray.from_storage(tensor_type, storage)
    batch = pa.RecordBatch.from_arrays([arr], ["ext"])
    # check the built array has exactly the expected class
    assert type(arr) is tensor_class

    buf = ipc_write_batch(batch)
    del batch
    batch = ipc_read_batch(buf)
    result = batch.column(0)
    # check the deserialized array class is the expected one
    assert type(result) is tensor_class
    assert result.type.extension_name == "arrow.tensor"
    # check the round-tripped data (not the original array)
    assert result.storage.to_pylist() == [values]
    # we get back an actual TensorType
    assert isinstance(result.type, TensorType)
    assert result.type.dtype == pa.int8()
    assert result.type.shape == (2, 2, 3)
    assert result.type.order == 'C'

    # using a different parametrization than the one that was registered
    tensor_type_uint = type(tensor_type)(pa.uint8(), (2, 3), 'C')
    assert tensor_type_uint.extension_name == "arrow.tensor"
    assert tensor_type_uint.dtype == pa.uint8()
    assert tensor_type_uint.shape == (2, 3)
    assert tensor_type_uint.order == 'C'

    uint_values = [1, 2, 3, 4, 5, 6]
    storage = pa.array([uint_values], pa.list_(pa.uint8(), 6))
    arr = pa.ExtensionArray.from_storage(tensor_type_uint, storage)
    batch = pa.RecordBatch.from_arrays([arr], ["ext"])
    buf = ipc_write_batch(batch)
    del batch
    batch = ipc_read_batch(buf)
    result = batch.column(0)
    assert isinstance(result.type, TensorType)
    assert result.type.dtype == pa.uint8()
    assert result.type.shape == (2, 3)
    assert result.type.order == 'C'
    assert result.storage.to_pylist() == [uint_values]
    assert type(result) is tensor_class
# def test_generic_ext_type_ipc_unknown(registered_tensor_type):
# def test_generic_ext_type_equality():
# def test_generic_ext_type_register(registered_tensor_type):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment