Last active
August 31, 2023 14:03
-
-
Save flying-sheep/99f2ceafdc494f97424222611b4f9474 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Pandas ExtensionArray / ExtensionDtype for UUID |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/.venv/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"- Official docs: https://pandas.pydata.org/docs/development/extending.html\n", | |
"- StackOverflow: https://stackoverflow.com/a/68972163/247482\n", | |
"- Arrow integration: https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from __future__ import annotations\n", | |
"\n", | |
"import builtins\n", | |
"from uuid import UUID\n", | |
"from typing import TYPE_CHECKING, ClassVar, Self, Never, get_args\n", | |
"from collections.abc import Sequence, Iterable\n", | |
"\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from numpy.typing import NDArray\n", | |
"from pandas.api.indexers import check_array_indexer\n", | |
"from pandas.api.extensions import ExtensionDtype, ExtensionArray\n", | |
"from pandas.core.ops.common import unpack_zerodim_and_defer\n", | |
"from pandas.core.algorithms import take\n", | |
"\n", | |
"if TYPE_CHECKING:\n", | |
" import pyarrow\n", | |
" from pandas.core.arrays import BooleanArray" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"UuidLike = UUID | bytes | int | str\n", | |
"\n", | |
"# 16 void bytes: 128 bit, every pattern valid, no funky behavior like 0 stripping.\n", | |
"_UuidNumpyDtype = np.dtype(\"V16\")\n", | |
"_UuidScalar = _UuidNumpyDtype.type\n", | |
"\n", | |
"\n", | |
"def _to_uuid(v: UuidLike) -> UUID:\n", | |
" match v:\n", | |
" case UUID():\n", | |
" return v\n", | |
" case bytes():\n", | |
" return UUID(bytes=v)\n", | |
" case int():\n", | |
" return UUID(int=v)\n", | |
" case str():\n", | |
" return UUID(v)\n", | |
" msg = f\"Unknown type for Uuid: {type(v)} is not {get_args(UuidLike)}\"\n", | |
" raise TypeError(msg)\n", | |
"\n", | |
"\n", | |
"class UuidDtype(ExtensionDtype):\n", | |
" # ExtensionDtype essential API (3 class attrs and methods)\n", | |
"\n", | |
" name: ClassVar[str] = \"uuid\"\n", | |
" type: ClassVar[builtins.type[UUID]] = UUID\n", | |
"\n", | |
" @classmethod\n", | |
" def construct_array_type(cls) -> type[UuidExtensionArray]:\n", | |
" return UuidExtensionArray\n", | |
"\n", | |
" # ExtensionDtype overrides\n", | |
"\n", | |
" kind: ClassVar[str] = _UuidNumpyDtype.kind\n", | |
" # index_class: ClassVar[type[pd.Index]] = pd.Index\n", | |
"\n", | |
" @property\n", | |
" def na_value(self) -> Never:\n", | |
" # TODO: figure this out\n", | |
" raise NotImplementedError()\n", | |
"\n", | |
" # IO\n", | |
"\n", | |
" def __from_arrow__(self, array: pyarrow.Array) -> ExtensionArray:\n", | |
" ...\n", | |
"\n", | |
"\n", | |
"class UuidExtensionArray(ExtensionArray):\n", | |
" # Implementation details and convenience\n", | |
"\n", | |
" _data: NDArray[_UuidScalar]\n", | |
"\n", | |
" def __init__(self, values: Iterable[UuidLike], *, copy: bool = False) -> None:\n", | |
" if isinstance(values, np.ndarray):\n", | |
" self._data = values.astype(_UuidNumpyDtype, copy=copy)\n", | |
" else:\n", | |
" # TODO: more efficient\n", | |
" self._data = np.array(\n", | |
" [_to_uuid(x).bytes for x in values], dtype=_UuidNumpyDtype\n", | |
" )\n", | |
"\n", | |
" if self._data.ndim != 1:\n", | |
" raise ValueError(\"Array only supports 1-d arrays\")\n", | |
"\n", | |
" # ExtensionArray essential API (11 class attrs and methods)\n", | |
"\n", | |
" dtype: ClassVar[UuidDtype] = UuidDtype()\n", | |
"\n", | |
" @classmethod\n", | |
" def _from_sequence(\n", | |
" cls,\n", | |
" data: Iterable[UuidLike],\n", | |
" dtype: UuidDtype | None = None,\n", | |
" copy: bool = False,\n", | |
" ) -> Self:\n", | |
" if dtype is None:\n", | |
" dtype = UuidDtype()\n", | |
"\n", | |
" if not isinstance(dtype, UuidDtype):\n", | |
" msg = f\"'{cls.__name__}' only supports 'UuidDtype' dtype\"\n", | |
" raise TypeError(msg)\n", | |
" return cls(data, copy=copy)\n", | |
"\n", | |
" def __getitem__(self, index) -> Self | UUID:\n", | |
" if isinstance(index, int):\n", | |
" return UUID(bytes=self._data[index].tobytes())\n", | |
" index = check_array_indexer(self, index)\n", | |
" return self._simple_new(self._data[index])\n", | |
"\n", | |
" # def __setitem__(self, index, value):\n", | |
"\n", | |
" def __len__(self) -> int:\n", | |
" return len(self._data)\n", | |
"\n", | |
" @unpack_zerodim_and_defer(\"__eq__\")\n", | |
" def __eq__(self, other):\n", | |
" return self._cmp(\"eq\", other)\n", | |
"\n", | |
" def nbytes(self) -> int:\n", | |
" return self._data.nbytes\n", | |
"\n", | |
" def isna(self) -> NDArray[np.bool_]:\n", | |
" return pd.isna(self._data)\n", | |
"\n", | |
" def take(\n", | |
" self, indexer, *, allow_fill: bool = False, fill_value: UUID | None = None\n", | |
" ) -> Self:\n", | |
" if allow_fill and fill_value is None:\n", | |
" fill_value = self.dtype.na_value\n", | |
"\n", | |
" result = take(self._data, indexer, allow_fill=allow_fill, fill_value=fill_value)\n", | |
" return self._simple_new(result)\n", | |
"\n", | |
" def copy(self) -> Self:\n", | |
" return self._simple_new(self._data.copy())\n", | |
"\n", | |
" @classmethod\n", | |
" def _concat_same_type(cls, to_concat: Sequence[Self]) -> Self:\n", | |
" return cls._simple_new(np.concatenate([x._data for x in to_concat]))\n", | |
"\n", | |
" # Helpers\n", | |
"\n", | |
" @classmethod\n", | |
" def _simple_new(cls, values: NDArray[_UuidScalar]) -> Self:\n", | |
" result = UuidExtensionArray.__new__(cls)\n", | |
" result._data = values\n", | |
" return result\n", | |
"\n", | |
" def _cmp(self, op: str, other) -> BooleanArray:\n", | |
" if isinstance(other, UuidExtensionArray):\n", | |
" other = other._data\n", | |
" elif isinstance(other, Sequence):\n", | |
" other = np.asarray(other)\n", | |
" if other.ndim > 1:\n", | |
" raise NotImplementedError(\"can only perform ops with 1-d structures\")\n", | |
" if len(self) != len(other):\n", | |
" raise ValueError(\"Lengths must match to compare\")\n", | |
"\n", | |
" method = getattr(self._data, f\"__{op}__\")\n", | |
" result = method(other)\n", | |
"\n", | |
" # if result is NotImplemented:\n", | |
" # result = invalid_comparison(self._data, other, op)\n", | |
"\n", | |
" rv: BooleanArray = pd.array(result, dtype=\"boolean\") # type: ignore\n", | |
" return rv\n", | |
"\n", | |
" # IO\n", | |
"\n", | |
" def __arrow_array__(self, type=None):\n", | |
" \"\"\"convert the underlying array values to a pyarrow Array\"\"\"\n", | |
" import pyarrow\n", | |
"\n", | |
" return pyarrow.array(..., type=type)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
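"Pandas has a bug around void dtypes (https://github.com/pandas-dev/pandas/issues/54810) that can make its constructors reject `V16`-backed dtypes as \"compound dtypes\". The next cell works around it by temporarily patching `NDFrame._validate_dtype` so that only dtypes that actually have fields are rejected.\n" | |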
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from contextlib import contextmanager\n", | |
"\n", | |
"\n", | |
"@contextmanager\n", | |
"def patch_pandas_constructors():\n", | |
" @classmethod\n", | |
" def _validate_dtype(\n", | |
" cls, dtype: np.dtype | ExtensionDtype\n", | |
" ) -> np.dtype | ExtensionDtype | None:\n", | |
" if dtype is None:\n", | |
" return None\n", | |
"\n", | |
" from pandas.core.dtypes.common import pandas_dtype\n", | |
"\n", | |
" dtype = pandas_dtype(dtype)\n", | |
" # a compound dtype\n", | |
" if getattr(dtype, \"fields\", None) is not None:\n", | |
" raise NotImplementedError(\n", | |
" \"compound dtypes are not implemented \"\n", | |
" f\"in the {cls.__name__} constructor\"\n", | |
" )\n", | |
"\n", | |
" return dtype\n", | |
"\n", | |
" from unittest.mock import patch\n", | |
" from pandas.core.generic import NDFrame\n", | |
"\n", | |
" with patch.object(NDFrame, \"_validate_dtype\", _validate_dtype):\n", | |
" yield" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<UuidExtensionArray>\n", | |
"[UUID('00000000-0000-0000-0000-000000000000'), UUID('5748c75a-4418-4bc2-81b8-59703a6ba0cd')]\n", | |
"Length: 2, dtype: uuid" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from uuid import uuid4\n", | |
"\n", | |
"# Mixed UUID-likes are accepted; the int 0 becomes the nil UUID (see output).\n", | |
"UuidExtensionArray([0, uuid4()])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0 8be4f607-afaa-4f34-868d-727bc9088999\n", | |
"Name: s, dtype: uuid" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Build the Series inside patch_pandas_constructors() to work around the pandas\n", | |
"# void-dtype bug (pandas-dev/pandas#54810) mentioned above.\n", | |
"with patch_pandas_constructors():\n", | |
"    s = pd.Series([uuid4()], dtype=UuidDtype(), name=\"s\")\n", | |
"s" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>s</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>8be4f607-afaa-4f34-868d-727bc9088999</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" s\n", | |
"0 8be4f607-afaa-4f34-868d-727bc9088999" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Wrap the uuid Series from the previous cell in a DataFrame (outside the patch).\n", | |
"pd.DataFrame(s)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Numeric Python", | |
"language": "python", | |
"name": "numeric" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.5" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[project] | |
name = "numeric-notebooks" | |
version = "0.1.0" | |
dependencies = ["numpy", "scipy", "pandas", "pyarrow"] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment