Skip to content

Instantly share code, notes, and snippets.

@flying-sheep
Last active August 31, 2023 14:03
Show Gist options
  • Save flying-sheep/99f2ceafdc494f97424222611b4f9474 to your computer and use it in GitHub Desktop.
Save flying-sheep/99f2ceafdc494f97424222611b4f9474 to your computer and use it in GitHub Desktop.
Pandas ExtensionArray / ExtensionDType for UUID
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Official docs: https://pandas.pydata.org/docs/development/extending.html\n",
"- StackOverflow: https://stackoverflow.com/a/68972163/247482\n",
"- Arrow integration: https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import builtins\n",
"from uuid import UUID\n",
"from typing import TYPE_CHECKING, ClassVar, Self, Never, get_args\n",
"from collections.abc import Sequence, Iterable\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from numpy.typing import NDArray\n",
"from pandas.api.indexers import check_array_indexer\n",
"from pandas.api.extensions import ExtensionDtype, ExtensionArray\n",
"from pandas.core.ops.common import unpack_zerodim_and_defer\n",
"from pandas.core.algorithms import take\n",
"\n",
"if TYPE_CHECKING:\n",
" import pyarrow\n",
" from pandas.core.arrays import BooleanArray"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"UuidLike = UUID | bytes | int | str\n",
"\n",
"# 16 void bytes: 128 bit, every pattern valid, no funky behavior like 0 stripping.\n",
"_UuidNumpyDtype = np.dtype(\"V16\")\n",
"_UuidScalar = _UuidNumpyDtype.type\n",
"\n",
"\n",
"def _to_uuid(v: UuidLike) -> UUID:\n",
" match v:\n",
" case UUID():\n",
" return v\n",
" case bytes():\n",
" return UUID(bytes=v)\n",
" case int():\n",
" return UUID(int=v)\n",
" case str():\n",
" return UUID(v)\n",
" msg = f\"Unknown type for Uuid: {type(v)} is not {get_args(UuidLike)}\"\n",
" raise TypeError(msg)\n",
"\n",
"\n",
"class UuidDtype(ExtensionDtype):\n",
" # ExtensionDtype essential API (3 class attrs and methods)\n",
"\n",
" name: ClassVar[str] = \"uuid\"\n",
" type: ClassVar[builtins.type[UUID]] = UUID\n",
"\n",
" @classmethod\n",
" def construct_array_type(cls) -> type[UuidExtensionArray]:\n",
" return UuidExtensionArray\n",
"\n",
" # ExtensionDtype overrides\n",
"\n",
" kind: ClassVar[str] = _UuidNumpyDtype.kind\n",
" # index_class: ClassVar[type[pd.Index]] = pd.Index\n",
"\n",
" @property\n",
" def na_value(self) -> Never:\n",
" # TODO: figure this out\n",
" raise NotImplementedError()\n",
"\n",
" # IO\n",
"\n",
" def __from_arrow__(self, array: pyarrow.Array) -> ExtensionArray:\n",
" ...\n",
"\n",
"\n",
"class UuidExtensionArray(ExtensionArray):\n",
" # Implementation details and convenience\n",
"\n",
" _data: NDArray[_UuidScalar]\n",
"\n",
" def __init__(self, values: Iterable[UuidLike], *, copy: bool = False) -> None:\n",
" if isinstance(values, np.ndarray):\n",
" self._data = values.astype(_UuidNumpyDtype, copy=copy)\n",
" else:\n",
" # TODO: more efficient\n",
" self._data = np.array(\n",
" [_to_uuid(x).bytes for x in values], dtype=_UuidNumpyDtype\n",
" )\n",
"\n",
" if self._data.ndim != 1:\n",
" raise ValueError(\"Array only supports 1-d arrays\")\n",
"\n",
" # ExtensionArray essential API (11 class attrs and methods)\n",
"\n",
" dtype: ClassVar[UuidDtype] = UuidDtype()\n",
"\n",
" @classmethod\n",
" def _from_sequence(\n",
" cls,\n",
" data: Iterable[UuidLike],\n",
" dtype: UuidDtype | None = None,\n",
" copy: bool = False,\n",
" ) -> Self:\n",
" if dtype is None:\n",
" dtype = UuidDtype()\n",
"\n",
" if not isinstance(dtype, UuidDtype):\n",
" msg = f\"'{cls.__name__}' only supports 'UuidDtype' dtype\"\n",
" raise TypeError(msg)\n",
" return cls(data, copy=copy)\n",
"\n",
" def __getitem__(self, index) -> Self | UUID:\n",
" if isinstance(index, int):\n",
" return UUID(bytes=self._data[index].tobytes())\n",
" index = check_array_indexer(self, index)\n",
" return self._simple_new(self._data[index])\n",
"\n",
" # def __setitem__(self, index, value):\n",
"\n",
" def __len__(self) -> int:\n",
" return len(self._data)\n",
"\n",
" @unpack_zerodim_and_defer(\"__eq__\")\n",
" def __eq__(self, other):\n",
" return self._cmp(\"eq\", other)\n",
"\n",
" def nbytes(self) -> int:\n",
" return self._data.nbytes\n",
"\n",
" def isna(self) -> NDArray[np.bool_]:\n",
" return pd.isna(self._data)\n",
"\n",
" def take(\n",
" self, indexer, *, allow_fill: bool = False, fill_value: UUID | None = None\n",
" ) -> Self:\n",
" if allow_fill and fill_value is None:\n",
" fill_value = self.dtype.na_value\n",
"\n",
" result = take(self._data, indexer, allow_fill=allow_fill, fill_value=fill_value)\n",
" return self._simple_new(result)\n",
"\n",
" def copy(self) -> Self:\n",
" return self._simple_new(self._data.copy())\n",
"\n",
" @classmethod\n",
" def _concat_same_type(cls, to_concat: Sequence[Self]) -> Self:\n",
" return cls._simple_new(np.concatenate([x._data for x in to_concat]))\n",
"\n",
" # Helpers\n",
"\n",
" @classmethod\n",
" def _simple_new(cls, values: NDArray[_UuidScalar]) -> Self:\n",
" result = UuidExtensionArray.__new__(cls)\n",
" result._data = values\n",
" return result\n",
"\n",
" def _cmp(self, op: str, other) -> BooleanArray:\n",
" if isinstance(other, UuidExtensionArray):\n",
" other = other._data\n",
" elif isinstance(other, Sequence):\n",
" other = np.asarray(other)\n",
" if other.ndim > 1:\n",
" raise NotImplementedError(\"can only perform ops with 1-d structures\")\n",
" if len(self) != len(other):\n",
" raise ValueError(\"Lengths must match to compare\")\n",
"\n",
" method = getattr(self._data, f\"__{op}__\")\n",
" result = method(other)\n",
"\n",
" # if result is NotImplemented:\n",
" # result = invalid_comparison(self._data, other, op)\n",
"\n",
" rv: BooleanArray = pd.array(result, dtype=\"boolean\") # type: ignore\n",
" return rv\n",
"\n",
" # IO\n",
"\n",
" def __arrow_array__(self, type=None):\n",
" \"\"\"convert the underlying array values to a pyarrow Array\"\"\"\n",
" import pyarrow\n",
"\n",
" return pyarrow.array(..., type=type)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pandas has a bug around void dtypes: https://github.com/pandas-dev/pandas/issues/54810\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from contextlib import contextmanager\n",
"\n",
"\n",
"@contextmanager\n",
"def patch_pandas_constructors():\n",
" @classmethod\n",
" def _validate_dtype(\n",
" cls, dtype: np.dtype | ExtensionDtype\n",
" ) -> np.dtype | ExtensionDtype | None:\n",
" if dtype is None:\n",
" return None\n",
"\n",
" from pandas.core.dtypes.common import pandas_dtype\n",
"\n",
" dtype = pandas_dtype(dtype)\n",
" # a compound dtype\n",
" if getattr(dtype, \"fields\", None) is not None:\n",
" raise NotImplementedError(\n",
" \"compound dtypes are not implemented \"\n",
" f\"in the {cls.__name__} constructor\"\n",
" )\n",
"\n",
" return dtype\n",
"\n",
" from unittest.mock import patch\n",
" from pandas.core.generic import NDFrame\n",
"\n",
" with patch.object(NDFrame, \"_validate_dtype\", _validate_dtype):\n",
" yield"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<UuidExtensionArray>\n",
"[UUID('00000000-0000-0000-0000-000000000000'), UUID('5748c75a-4418-4bc2-81b8-59703a6ba0cd')]\n",
"Length: 2, dtype: uuid"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from uuid import uuid4\n",
"\n",
"UuidExtensionArray([0, uuid4()])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 8be4f607-afaa-4f34-868d-727bc9088999\n",
"Name: s, dtype: uuid"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with patch_pandas_constructors():\n",
" s = pd.Series([uuid4()], dtype=UuidDtype(), name=\"s\")\n",
"s"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>s</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8be4f607-afaa-4f34-868d-727bc9088999</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" s\n",
"0 8be4f607-afaa-4f34-868d-727bc9088999"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(s)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Numeric Python",
"language": "python",
"name": "numeric"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
[project]
name = "numeric-notebooks"
version = "0.1.0"
dependencies = ["numpy", "scipy", "pandas", "pyarrow"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment