Skip to content

Instantly share code, notes, and snippets.

@jorisvandenbossche
Created February 1, 2023 13:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jorisvandenbossche/fa2f72f07be0a328ab2454605ff34d77 to your computer and use it in GitHub Desktop.
Save jorisvandenbossche/fa2f72f07be0a328ab2454605ff34d77 to your computer and use it in GitHub Desktop.
import json
import pyarrow as pa
class InnerType(pa.ExtensionType):
    """Parameter-free extension type backed by ``int64`` storage.

    The type passes ``'test.inner_type'`` as its extension name and carries
    no state of its own, so it serializes to an empty byte string.
    """

    def __init__(self):
        super().__init__(pa.int64(), 'test.inner_type')

    def __arrow_ext_serialize__(self):
        """Return the serialized parameters (empty: there are none)."""
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild the type; ``storage_type`` and ``serialized`` are unused."""
        return InnerType()

    def __eq__(self, other):
        """Compare equal to extension types of exactly the same class.

        Returns ``NotImplemented`` for anything that is not a
        ``pa.BaseExtensionType`` so Python can try the reflected comparison.
        """
        if not isinstance(other, pa.BaseExtensionType):
            return NotImplemented
        return type(self) is type(other)

    def __hash__(self):
        """Hash on the class, matching the class-only comparison in ``__eq__``.

        Overriding ``__eq__`` without ``__hash__`` makes Python set
        ``__hash__`` to ``None``; defining it keeps instances hashable.
        """
        return hash(type(self))
class OuterAnnotatedType(pa.ExtensionType):
    """Extension type that attaches an annotation to an arbitrary storage type.

    The annotation (``metadata``) is serialized as the JSON object
    ``{"metadata": <value>}``, so it has to be JSON-serializable. The type
    passes ``'test.outer_annotated_type'`` as its extension name.

    Parameters
    ----------
    storage_type
        The type this extension type wraps; forwarded to the base class.
    metadata
        The annotation stored on the type and exposed as ``.metadata``.
    """

    def __init__(self, storage_type, metadata):
        # Keep the attribute assignment ahead of the base initializer:
        # pyarrow documents that the base __init__ should run last.
        self._metadata = metadata
        super().__init__(storage_type, 'test.outer_annotated_type')

    @property
    def metadata(self):
        """The annotation passed to the constructor."""
        return self._metadata

    def __arrow_ext_serialize__(self):
        """Return the annotation encoded as JSON bytes."""
        payload = {"metadata": self._metadata}
        return json.dumps(payload).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild the type from ``storage_type`` and the JSON bytes."""
        payload = json.loads(serialized.decode())
        return OuterAnnotatedType(storage_type, payload["metadata"])

    def __eq__(self, other):
        """Compare equal when class and annotation match.

        The storage type does not take part in the comparison. Returns
        ``NotImplemented`` for anything that is not a
        ``pa.BaseExtensionType``.
        """
        if not isinstance(other, pa.BaseExtensionType):
            return NotImplemented
        # Short-circuit keeps ``other.metadata`` from being read on
        # extension types of a different class.
        return type(self) is type(other) and self.metadata == other.metadata

    def __hash__(self):
        """Hash on the class only, which is always consistent with ``__eq__``.

        Overriding ``__eq__`` without ``__hash__`` makes Python set
        ``__hash__`` to ``None``. The annotation is deliberately left out
        of the hash because it may be an unhashable JSON value (list, dict).
        """
        return hash(type(self))
def _roundtrip_ipc(batch):
    """Write *batch* to an in-memory IPC stream and read the first batch back."""
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
        writer.write_batch(batch)
    return pa.RecordBatchStreamReader(sink.getvalue()).read_next_batch()


# Register one instance of each extension type.
t1 = InnerType()
pa.register_extension_type(t1)
t2 = OuterAnnotatedType(pa.int64(), "metadata")
pa.register_extension_type(t2)

# Build a nested extension array: the outer type's storage is itself an
# extension array of the inner type.
storage = pa.array([1, 2, 3, 4], pa.int64())
arr1 = pa.ExtensionArray.from_storage(t1, storage)
t2 = OuterAnnotatedType(t1, "custom_info")
arr2 = pa.ExtensionArray.from_storage(t2, arr1)
batch = pa.RecordBatch.from_arrays([arr2], ["ext"])

# passthrough IPC
batch2 = _roundtrip_ipc(batch)
print(batch["ext"].type.storage_type)
# extension<test.inner_type<InnerType>>
print(batch2["ext"].type.storage_type)
# DataType(int64)

# Repeat the roundtrip with both extension types unregistered.
pa.unregister_extension_type('test.inner_type')
pa.unregister_extension_type('test.outer_annotated_type')
batch2 = _roundtrip_ipc(batch)
# now the main type is already int64
print(batch2["ext"].type)
# DataType(int64)
# the field metadata only has the metadata for the outer extension type
# (printed explicitly: a bare expression shows nothing when run as a script)
print(batch2.schema.field("ext").metadata)
# {b'ARROW:extension:metadata': b'{"metadata": "custom_info"}',
# b'ARROW:extension:name': b'test.outer_annotated_type'}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment