Skip to content

Instantly share code, notes, and snippets.

@blink1073
Last active January 20, 2023 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save blink1073/289ef971393a2b0b1c3035d64c2dfe66 to your computer and use it in GitHub Desktop.
Save blink1073/289ef971393a2b0b1c3035d64c2dfe66 to your computer and use it in GitHub Desktop.
Pandas Extension Types for BSON
from __future__ import annotations
from bson import ObjectId, Decimal128, Binary
from pandas.api.extensions import ExtensionDtype, ExtensionArray, register_extension_dtype
from pandas._typing import type_t
import pyarrow as pa
from typing import Union, Any
import numpy as np
import pandas as pd
import numbers
class ObjectIdScalar(pa.ExtensionScalar):
    """Arrow extension scalar that converts to a ``bson.ObjectId`` in Python."""

    def as_py(self, **kwargs):
        """Return the scalar as an ``ObjectId``, or ``None`` for a null entry.

        Any keyword arguments are forwarded unchanged to the storage
        scalar's ``as_py``.
        """
        storage = self.value
        # NOTE(review): pyarrow is expected to expose a null extension scalar
        # with ``value`` set to None - confirm for the supported versions.
        if storage is None:
            return None
        raw = storage.as_py(**kwargs)
        # ``ObjectId(None)`` would generate a brand-new id, so a missing
        # payload must never reach the constructor.
        if raw is None:
            return None
        return ObjectId(raw)
class ObjectIdType(pa.PyExtensionType):
    """Arrow extension type for ObjectIds, stored as ``pa.binary(12)``."""

    def __init__(self):
        storage_type = pa.binary(12)
        super().__init__(storage_type)

    def __reduce__(self):
        # Rebuild by calling the class with no arguments.
        return (ObjectIdType, ())

    def to_pandas_dtype(self):
        """Return the pandas extension dtype used for columns of this type."""
        pandas_dtype = PandasObjectId()
        return pandas_dtype

    def __arrow_ext_scalar_class__(self):
        """Return the scalar class used for elements of this type."""
        return ObjectIdScalar
class BSONDtype(ExtensionDtype):
    """Common base for pandas extension dtypes that wrap a BSON scalar class.

    Concrete subclasses set ``type`` to the scalar class and implement
    ``construct_array_type``. Missing values are represented by ``np.nan``.
    """

    na_value = np.nan

    def __init_subclass__(cls, **kwargs) -> None:
        """Store a string ``name`` (``bson_<type>``) on every subclass."""
        super().__init_subclass__(**kwargs)
        # Stacking ``@classmethod`` on ``@property`` was deprecated in Python
        # 3.11 and removed in 3.13. A plain class attribute is readable from
        # both the class and its instances on every version.
        cls.name = f'bson_{cls.type}'

    def __from_arrow__(
        self, array: Union[pa.Array, pa.ChunkedArray]
    ) -> ExtensionArray:
        """Build the extension array for this dtype from Arrow data.

        Parameters
        ----------
        array : pa.Array or pa.ChunkedArray
            Arrow data; a chunked array is converted chunk by chunk.

        Returns
        -------
        ExtensionArray
            Instance of ``construct_array_type()``. Missing entries become
            ``na_value``; other entries that are not already ``self.type``
            are passed through ``self.type(...)``.
        """
        chunks = [array] if isinstance(array, pa.Array) else array.chunks
        arr_type = self.construct_array_type()
        typ = self.type
        missing = self.na_value

        results = []
        for chunk in chunks:
            vals = []
            for val in np.array(chunk):
                if pd.isna(val):
                    # Normalise nulls (e.g. None) to the dtype's NA value.
                    val = missing
                elif not isinstance(val, typ):
                    val = typ(val)
                vals.append(val)
            # Hand over the list itself: wrapping it in np.array() first lets
            # NumPy pick a dtype, which can turn bytes-like scalars into
            # plain fixed-width bytes and lose their Python class.
            results.append(arr_type._from_sequence(vals))

        if results:
            return arr_type._concat_same_type(results)
        return arr_type(np.array([], dtype="object"))
class BaseExtensionArray(ExtensionArray):
    """ExtensionArray backed by a NumPy array kept in ``self.data``.

    Subclasses must define a class-level ``dtype`` whose ``type`` is the
    scalar class. Every element is either an instance of that class or a
    missing value (anything ``pd.isna`` reports as missing).
    """

    def __init__(self, values, dtype=None, copy=False) -> None:
        """Wrap ``values`` after validating every element.

        Parameters
        ----------
        values : np.ndarray
            Elements to hold.
        dtype : optional
            Accepted for signature compatibility; not used.
        copy : bool, default False
            Store a copy of ``values`` instead of the array itself.

        Raises
        ------
        TypeError
            If ``values`` is not a NumPy array.
        ValueError
            If an element is neither a ``dtype.type`` instance nor missing.
        """
        if not isinstance(values, np.ndarray):
            raise TypeError("Need to pass a numpy array as values")
        scalar_type = self.dtype.type
        for val in values:
            if not isinstance(val, scalar_type) and not pd.isna(val):
                raise ValueError(f'Values must be either {self.dtype.type} or NA')
        self.data = values.copy() if copy else values

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Build an array from a sequence of scalars.

        The scalars are placed into a fresh object array, so ``dtype`` and
        ``copy`` are not needed and are ignored.
        """
        data = np.empty(len(scalars), dtype=object)
        data[:] = scalars
        return cls(data)

    @classmethod
    def _from_factorized(cls, values, original):
        """Rebuild an array from factorized ``values``."""
        return cls(values, dtype=original.dtype)

    def __getitem__(self, item):
        """Return a scalar for an integer ``item``, otherwise a new array."""
        if isinstance(item, numbers.Integral):
            return self.data[item]
        # slice, list-like, mask
        item = pd.api.indexers.check_array_indexer(self, item)
        return type(self)(self.data[item])

    def __setitem__(self, item, value):
        """Assign ``value`` at ``item``.

        Raises
        ------
        ValueError
            If a non-iterable ``value``, or any ``value`` assigned to an
            integer position, is neither a ``dtype.type`` instance nor
            missing.
        """
        scalar_type = self.dtype.type
        if (
            not hasattr(value, '__iter__')
            and not isinstance(value, scalar_type)
            and not pd.isna(value)
        ):
            raise ValueError(f'Value must be of type {self.dtype.type} or nan')
        if not isinstance(item, numbers.Integral):
            # slice, list-like, mask
            item = pd.api.indexers.check_array_indexer(self, item)
        elif not isinstance(value, scalar_type) and not pd.isna(value):
            raise ValueError(f'Array element must be of type {self.dtype.type} or nan')
        self.data[item] = value

    def __len__(self) -> int:
        return len(self.data)

    def isna(self):
        """Return a boolean ndarray that is True where an element is missing."""
        scalar_type = self.dtype.type
        # pd.isna, unlike np.isnan, also accepts None and other non-float
        # missing markers instead of raising TypeError.
        return np.array(
            [not isinstance(x, scalar_type) and pd.isna(x) for x in self.data],
            dtype=bool,
        )

    def __eq__(self, other):
        return self.data == other

    @property
    def nbytes(self):
        """Number of bytes reported by the backing NumPy array."""
        # Exposed as a property, matching the pandas ExtensionArray interface.
        return self.data.nbytes

    def take(self, indexer, allow_fill=False, fill_value=None):
        """Return a new array made of the elements at ``indexer``.

        Parameters
        ----------
        indexer : sequence of int
            Positions to take.
        allow_fill : bool, default False
            When True, ``-1`` marks a slot to fill with ``fill_value`` and
            any index below ``-1`` is rejected.
        fill_value : optional
            Fill for ``-1`` slots; defaults to ``dtype.na_value``.

        Raises
        ------
        ValueError
            If ``allow_fill`` is True and an index is below ``-1``.
        IndexError
            If an index is out of bounds.
        """
        indexer = np.asarray(indexer)
        msg = (
            "Index is out of bounds or cannot do a "
            "non-empty take from an empty array."
        )

        if allow_fill:
            if fill_value is None:
                fill_value = self.dtype.na_value
            # bounds check
            if (indexer < -1).any():
                raise ValueError("Indices must be >= -1 when allow_fill=True")
            try:
                output = [
                    self.data[loc] if loc != -1 else fill_value for loc in indexer
                ]
            except IndexError as err:
                raise IndexError(msg) from err
        else:
            try:
                output = [self.data[loc] for loc in indexer]
            except IndexError as err:
                raise IndexError(msg) from err

        return self._from_sequence(output)

    def copy(self):
        """Return a new array of the same class over a copy of the data."""
        return type(self)(self.data.copy())

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate the ``data`` of several arrays into one new array."""
        data = np.concatenate([x.data for x in to_concat])
        return cls(data)
@register_extension_dtype
class PandasObjectId(BSONDtype):
    """Pandas extension dtype whose scalar type is ``ObjectId``."""

    type = ObjectId

    @classmethod
    def construct_array_type(cls) -> type_t[PandasObjectIdArray]:
        """Return the array class that holds values of this dtype."""
        array_cls = PandasObjectIdArray
        return array_cls
@register_extension_dtype
class PandasDecimal128(BSONDtype):
    """Pandas extension dtype whose scalar type is ``Decimal128``."""

    type = Decimal128

    @classmethod
    def construct_array_type(cls) -> type_t[PandasDecimal128Array]:
        """Return the array class that holds values of this dtype."""
        array_cls = PandasDecimal128Array
        return array_cls
@register_extension_dtype
class PandasBinary(BSONDtype):
    """Pandas extension dtype whose scalar type is ``Binary``."""

    type = Binary

    @classmethod
    def construct_array_type(cls) -> type_t[PandasBinaryArray]:
        """Return the array class that holds values of this dtype."""
        array_cls = PandasBinaryArray
        return array_cls
class PandasObjectIdArray(BaseExtensionArray):
    """Extension array paired with the ``PandasObjectId`` dtype."""

    # Shared dtype instance for every array of this class.
    dtype = PandasObjectId()
class PandasDecimal128Array(BaseExtensionArray):
    """Extension array paired with the ``PandasDecimal128`` dtype."""

    # Shared dtype instance for every array of this class.
    dtype = PandasDecimal128()
class PandasBinaryArray(BaseExtensionArray):
    """Extension array paired with the ``PandasBinary`` dtype.

    Conversion to ndarray, equality and membership are overridden to work
    element by element over ``self.data``.
    """

    # Shared dtype instance for every array of this class.
    dtype = PandasBinary()

    def __array__(self, dtype=None):
        """Return the data as a NumPy array, optionally cast to ``dtype``."""
        return np.array(self.data, dtype)

    def __eq__(self, other):
        """Compare each element with ``other`` and return the results as an array."""
        matches = [element == other for element in self.data]
        return np.array(matches)

    def __contains__(self, item: object) -> bool | np.bool_:
        """Report whether ``item`` is in the array.

        A missing ``item`` matches only when it is a float and the array
        holds a missing value; any other missing marker is never contained.
        """
        if not pd.isna(item):
            return np.any([element == item for element in self.data])
        if isinstance(item, float):
            return np.any([pd.isna(element) for element in self.data])
        return False
from pandas.tests.extension import base
import pytest
# def make_data():
# return [ObjectId() for _ in range(8)] + [np.nan] + [ObjectId() for _ in range(88)] + [np.nan] + [ObjectId(), ObjectId()]
# @pytest.fixture
# def dtype():
# return PandasObjectId()
def make_datum():
    """Return a ``Binary`` (subtype 10) holding the UTF-8 text of a random float."""
    payload = str(np.random.rand()).encode('utf8')
    return Binary(payload, 10)
def make_data():
    """Return 100 items: random data with ``np.nan`` at positions 8 and 97."""
    head = [make_datum() for _ in range(8)]
    middle = [make_datum() for _ in range(88)]
    tail = [make_datum(), make_datum()]
    return [*head, np.nan, *middle, np.nan, *tail]
@pytest.fixture
def dtype():
    """Fixture: a fresh instance of the dtype under test."""
    dtype_under_test = PandasBinary()
    return dtype_under_test
@pytest.fixture
def data(dtype):
    """Fixture: array of ``dtype`` built from ``make_data()``."""
    values = make_data()
    return pd.array(values, dtype=dtype)
@pytest.fixture
def data_for_twos(dtype):
    """Fixture for arithmetic tests; always skips for this dtype.

    The previous body built the array from floats, which are neither
    instances of the dtype's scalar type nor missing. The pandas fixture
    contract says to call ``pytest.skip`` when the dtype does not support
    divmod, so dependent tests are skipped explicitly.
    """
    pytest.skip(f"{dtype} is not a numeric dtype")
@pytest.fixture
def data_missing(dtype):
    """Fixture: length-2 array holding a missing value, then a valid one."""
    values = [np.nan, make_datum()]
    return pd.array(values, dtype=dtype)
@pytest.fixture
def data_for_sorting(dtype):
    """Fixture: length-3 array in the known order ``[B, C, A]`` with A < B < C.

    Fixed payloads replace the previous random ones, whose relative order
    changed from run to run.
    """
    # Subtype 10 matches the values produced elsewhere in this module.
    values = [Binary(b"b", 10), Binary(b"c", 10), Binary(b"a", 10)]
    return pd.array(values, dtype=dtype)
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Fixture: length-3 array in the known order ``[B, NA, A]`` with A < B.

    Fixed payloads replace the previous random ones, whose relative order
    changed from run to run.
    """
    # Subtype 10 matches the values produced elsewhere in this module.
    values = [Binary(b"b", 10), np.nan, Binary(b"a", 10)]
    return pd.array(values, dtype=dtype)
@pytest.fixture
def na_value():
    """Fixture: the scalar used for missing values (``np.nan``)."""
    missing = np.nan
    return missing
@pytest.fixture
def na_cmp():
    """Fixture: a two-argument callable that is truthy when both are NaN."""

    def cmp(a, b):
        a_missing = np.isnan(a)
        # ``and`` short-circuits, so ``b`` is only inspected when ``a`` is NaN.
        return a_missing and np.isnan(b)

    return cmp
@pytest.fixture(params=[True, False])
def box_in_series(request):
    """Parametrized flag: whether the data should be boxed in a Series."""
    flag = request.param
    return flag
@pytest.fixture(params=[True, False])
def as_array(request):
    """Parametrized boolean flag used by ``_from_sequence`` tests."""
    flag = request.param
    return flag
@pytest.fixture(params=["ffill", "bfill"])
def fillna_method(request):
    """Parametrized fill method, ``'ffill'`` or ``'bfill'``, for fillna tests."""
    method = request.param
    return method
@pytest.fixture
def invalid_scalar(data):
    """Fixture: a scalar that this extension array cannot hold.

    A bare ``object`` instance is used; it is neither a BSON value nor a
    missing-value marker.
    """
    return object()
class TestDtype(base.BaseDtypeTests):
    """Inherit the pandas ``BaseDtypeTests`` suite without changes."""
class TestInterface(base.BaseInterfaceTests):
    """Inherit the pandas ``BaseInterfaceTests`` suite without changes."""
class TestConstructors(base.BaseConstructorsTests):
    """Inherit the pandas ``BaseConstructorsTests`` suite without changes."""
class TestGetitem(base.BaseGetitemTests):
    """Inherit the pandas ``BaseGetitemTests`` suite without changes."""
class TestSetitem(base.BaseSetitemTests):
    """Inherit the pandas ``BaseSetitemTests`` suite without changes."""
class TestIndex(base.BaseIndexTests):
    """Inherit the pandas ``BaseIndexTests`` suite without changes."""
class TestMissing(base.BaseMissingTests):
    """Inherit the pandas ``BaseMissingTests`` suite without changes."""
def test_to_pandas():
    """Convert an Arrow table with an ObjectId column (one null) to pandas."""
    schema = pa.schema([("data", ObjectIdType())])
    values = [ObjectId().binary, ObjectId().binary, ObjectId().binary, None]
    table = pa.Table.from_pydict({"data": values}, schema=schema)

    df = table.to_pandas()

    # The leftover ``pdb.set_trace()`` breakpoint blocked non-interactive
    # runs; the result is now checked with assertions instead.
    assert len(df) == 4
    assert isinstance(df["data"].dtype, PandasObjectId)
@blink1073
Copy link
Author

blink1073 commented Jan 14, 2023

Rollout plan:

  • Switch to pytest so we can use the pandas fixtures
  • Implement Binary extension types
  • Implement ObjectId pandas extension type
  • Implement Decimal128 as an Arrow binary(16) extension type whose as_py returns Decimal128, add the matching pandas Decimal128 extension type, and deprecate Decimal128Str
  • Implement the other types: DBRef, Code, Int64, MaxKey, MinKey, Regex, Timestamp
  • Consider using the ArrowExtensionArray as the base class once it becomes stable

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment