Skip to content

Instantly share code, notes, and snippets.

Last active January 20, 2023 19:16
Show Gist options
  • Save blink1073/289ef971393a2b0b1c3035d64c2dfe66 to your computer and use it in GitHub Desktop.
Save blink1073/289ef971393a2b0b1c3035d64c2dfe66 to your computer and use it in GitHub Desktop.
Pandas Extension Types for BSON
from __future__ import annotations
from bson import ObjectId, Decimal128, Binary
from pandas.api.extensions import ExtensionDtype, ExtensionArray, register_extension_dtype
from pandas._typing import type_t
import pyarrow as pa
from typing import Union, Any
import numpy as np
import pandas as pd
import numbers
class ObjectIdScalar(pa.ExtensionScalar):
    """Arrow extension scalar whose Python value is a ``bson.ObjectId``."""

    def as_py(self):
        """Return the scalar as an ``ObjectId``, or ``None`` when it is null."""
        storage = self.value
        # A null scalar must stay null: ``ObjectId(None)`` would silently
        # generate a brand-new id instead of signalling a missing value.
        if storage is None:
            return None
        raw = storage.as_py()
        if raw is None:
            return None
        return ObjectId(raw)
class ObjectIdType(pa.PyExtensionType):
    """Arrow extension type that maps a column to ``bson.ObjectId`` values."""

    def __init__(self):
        # NOTE(review): the body of __init__ was missing from the recovered
        # source. A 12-byte fixed-size binary storage type is assumed here
        # because the accompanying test feeds ``ObjectId().binary`` values --
        # confirm against the original gist.
        pa.PyExtensionType.__init__(self, pa.binary(12))

    def __reduce__(self):
        """Support pickling: the type takes no constructor arguments."""
        return ObjectIdType, ()

    def to_pandas_dtype(self):
        """Return the pandas extension dtype used when converting to pandas."""
        return PandasObjectId()

    def __arrow_ext_scalar_class__(self):
        """Return the scalar class used for elements of this type."""
        return ObjectIdScalar
class BSONDtype(ExtensionDtype):
    """Common base for the pandas extension dtypes that wrap a BSON type.

    Subclasses set ``type`` to the BSON scalar class and implement
    ``construct_array_type``.
    """

    # Missing values are represented by NaN.
    na_value = np.nan

    # NOTE(review): the decorator line was missing from the recovered source;
    # ``@property`` is assumed because pandas reads ``dtype.name`` as an
    # attribute -- confirm against the original gist.
    @property
    def name(self) -> str:
        """Return the dtype name, built from the wrapped scalar type."""
        return f'bson_{self.type}'

    def __from_arrow__(
        self, array: Union[pa.Array, pa.ChunkedArray]
    ) -> ExtensionArray:
        """Build this dtype's extension array from an Arrow array.

        Parameters
        ----------
        array
            A ``pyarrow.Array`` or ``pyarrow.ChunkedArray``.

        Returns
        -------
        ExtensionArray
            One array of ``construct_array_type()`` holding every chunk's
            values; an empty array when there are no chunks.
        """
        if isinstance(array, pa.Array):
            chunks = [array]
        else:
            # pyarrow.ChunkedArray
            chunks = array.chunks

        arr_type = self.construct_array_type()
        typ = self.type
        results = []
        for chunk in chunks:
            # Convert low level values to the desired type.
            vals = []
            for val in np.array(chunk):
                if pd.isna(val):
                    # Normalise Arrow nulls (None) to the dtype's NA marker.
                    val = self.na_value
                elif not isinstance(val, typ):
                    val = typ(val)
                vals.append(val)
            # Pass the plain list on: routing it through ``np.array(vals)``
            # would coerce bytes subclasses to a fixed-width bytes dtype and
            # lose the BSON wrapper type.
            results.append(arr_type._from_sequence(vals))

        if results:
            return arr_type._concat_same_type(results)
        return arr_type(np.array([], dtype="object"))
class BaseExtensionArray(ExtensionArray):
    """Object-backed ExtensionArray shared by the BSON array types.

    Values are kept in a NumPy array on ``self.data``. Every element is
    either an instance of ``self.dtype.type`` or a missing value (anything
    ``pd.isna`` accepts). Subclasses provide ``dtype`` as a class attribute.

    NOTE(review): this class was recovered from a listing that had lost its
    indentation, its decorators and every ``self.data`` reference. Those
    have been restored; the ``nbytes`` body is an assumption.
    """

    def __init__(self, values, dtype=None, copy=False) -> None:
        """Wrap *values* after checking every element.

        Parameters
        ----------
        values : np.ndarray
            Elements must be ``self.dtype.type`` instances or NA.
        dtype, copy
            Accepted for pandas API compatibility; not used.

        Raises
        ------
        TypeError
            If *values* is not a NumPy array.
        ValueError
            If an element is neither of the dtype's type nor NA.
        """
        if not isinstance(values, np.ndarray):
            raise TypeError("Need to pass a numpy array as values")
        for val in values:
            if not isinstance(val, self.dtype.type) and not pd.isna(val):
                raise ValueError(f'Values must be either {self.dtype.type} or NA')
        self.data = values

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Build an array from a sequence of scalars."""
        # Fill a pre-allocated object array so elements stay Python objects.
        data = np.empty(len(scalars), dtype=object)
        data[:] = scalars
        return cls(data)

    @classmethod
    def _from_factorized(cls, values, original):
        """Rebuild an array from factorized *values*."""
        return cls(values, dtype=original.dtype)

    def __getitem__(self, item):
        """Return a scalar for an integer *item*, otherwise a new array."""
        if isinstance(item, numbers.Integral):
            return self.data[item]
        # slice, list-like, mask
        item = pd.api.indexers.check_array_indexer(self, item)
        return type(self)(self.data[item])

    def __setitem__(self, item, value):
        """Assign *value* at *item*, rejecting scalars of the wrong type."""
        if not hasattr(value, '__iter__') and not isinstance(value, self.dtype.type) and not pd.isna(value):
            raise ValueError(f'Value must be of type {self.dtype.type} or nan')
        if not isinstance(item, numbers.Integral):
            # slice, list-like, mask
            item = pd.api.indexers.check_array_indexer(self, item)
        elif not isinstance(value, self.dtype.type) and not pd.isna(value):
            raise ValueError(f'Array element must be of type {self.dtype.type} or nan')
        self.data[item] = value

    def __len__(self) -> int:
        """Return the number of elements."""
        return len(self.data)

    def isna(self):
        """Return a boolean NumPy mask that is True where an element is NA."""
        typ = self.dtype.type
        # pd.isna (rather than np.isnan) also copes with None.
        return np.array(
            [not isinstance(x, typ) and pd.isna(x) for x in self.data], dtype=bool
        )

    def __eq__(self, other):
        """Compare the backing array with *other* element-wise."""
        return self.data == other

    @property
    def nbytes(self):
        """Return the size in bytes of the backing array.

        NOTE(review): the original body was lost; ``self.data.nbytes``
        (the object pointers only, not the referenced values) is assumed.
        """
        return self.data.nbytes

    def take(self, indexer, allow_fill=False, fill_value=None):
        """Select elements by position.

        Parameters
        ----------
        indexer : sequence of int
            Positions to take.
        allow_fill : bool, default False
            When True, ``-1`` in *indexer* yields *fill_value* and any
            value below ``-1`` raises ``ValueError``.
        fill_value : object, optional
            Defaults to ``self.dtype.na_value`` when *allow_fill* is True.

        Raises
        ------
        ValueError
            If *allow_fill* is True and *indexer* has a value below ``-1``.
        IndexError
            If a position is out of bounds.
        """
        # re-implement here, since NumPy has trouble setting
        # sized objects like UserDicts into scalar slots of
        # an ndarray.
        indexer = np.asarray(indexer)
        msg = (
            "Index is out of bounds or cannot do a "
            "non-empty take from an empty array."
        )

        if allow_fill:
            if fill_value is None:
                fill_value = self.dtype.na_value
            # bounds check
            if (indexer < -1).any():
                raise ValueError
            try:
                output = [
                    self.data[loc] if loc != -1 else fill_value for loc in indexer
                ]
            except IndexError as err:
                raise IndexError(msg) from err
        else:
            try:
                output = [self.data[loc] for loc in indexer]
            except IndexError as err:
                raise IndexError(msg) from err

        return self._from_sequence(output)

    def copy(self):
        """Return a new array backed by a copy of the data."""
        return type(self)(self.data.copy())

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate several arrays of this type into one."""
        data = np.concatenate([x.data for x in to_concat])
        return cls(data)
class PandasObjectId(BSONDtype):
    """pandas extension dtype whose scalars are ``bson.ObjectId``."""

    type = ObjectId

    @classmethod
    def construct_array_type(cls) -> type_t[PandasObjectIdArray]:
        """Return the array class associated with this dtype."""
        return PandasObjectIdArray
class PandasDecimal128(BSONDtype):
    """pandas extension dtype whose scalars are ``bson.Decimal128``."""

    type = Decimal128

    @classmethod
    def construct_array_type(cls) -> type_t[PandasDecimal128Array]:
        """Return the array class associated with this dtype."""
        return PandasDecimal128Array
class PandasBinary(BSONDtype):
    """pandas extension dtype whose scalars are ``bson.Binary``."""

    type = Binary

    @classmethod
    def construct_array_type(cls) -> type_t[PandasBinaryArray]:
        """Return the array class associated with this dtype."""
        return PandasBinaryArray
class PandasObjectIdArray(BaseExtensionArray):
    """Extension array holding ``bson.ObjectId`` values."""

    # Shared dtype instance for every array of this class.
    dtype = PandasObjectId()
class PandasDecimal128Array(BaseExtensionArray):
    """Extension array holding ``bson.Decimal128`` values."""

    # Shared dtype instance for every array of this class.
    dtype = PandasDecimal128()
class PandasBinaryArray(BaseExtensionArray):
    """Extension array holding ``bson.Binary`` values.

    Overrides conversion, equality and membership so that elements are
    handled one by one in Python.

    NOTE(review): the ``self.data`` references below were missing from the
    recovered source and have been restored.
    """

    # Shared dtype instance for every array of this class.
    dtype = PandasBinary()

    def __array__(self, dtype=None):
        """Return the backing data as a NumPy array of *dtype*."""
        return np.array(self.data, dtype)

    def __eq__(self, other):
        """Compare every element with *other*, returning a NumPy array."""
        return np.array([a == other for a in self.data])

    def __contains__(self, item: object) -> bool | np.bool_:
        """Return whether *item* occurs in the array.

        An NA *item* only matches when it is a float (e.g. ``np.nan``);
        any other NA-like object is reported as absent.
        """
        if pd.isna(item):
            if not isinstance(item, float):
                return False
            return np.any([pd.isna(a) for a in self.data])
        return np.any([a == item for a in self.data])
from pandas.tests.extension import base
import pytest
# def make_data():
# return [ObjectId() for _ in range(8)] + [np.nan] + [ObjectId() for _ in range(88)] + [np.nan] + [ObjectId(), ObjectId()]
# @pytest.fixture
# def dtype():
# return PandasObjectId()
def make_datum():
    """Return a ``Binary`` (subtype 10) wrapping the text of a random float."""
    value = np.random.rand()
    return Binary(str(value).encode('utf8'), 10)
def make_data():
    """Return 100 values: random data with NaN at positions 8 and 97."""
    return [make_datum() for _ in range(8)] + [np.nan] + [make_datum() for _ in range(88)] + [np.nan] + [make_datum(), make_datum()]
@pytest.fixture
def dtype():
    """Provide the extension dtype under test."""
    return PandasBinary()
@pytest.fixture
def data(dtype):
    """Provide an array of *dtype* built from ``make_data()``."""
    return pd.array(make_data(), dtype=dtype)
@pytest.fixture
def data_for_twos(dtype):
    """Provide an array of *dtype* built from 100 ones."""
    return pd.array(np.ones(100), dtype=dtype)
@pytest.fixture
def data_missing(dtype):
    """Provide a length-2 array: one missing value, then one valid value."""
    return pd.array([np.nan, make_datum()], dtype=dtype)
@pytest.fixture
def data_for_sorting(dtype):
    """Provide a length-3 array of random valid values."""
    return pd.array([make_datum(), make_datum(), make_datum()], dtype=dtype)
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Provide a length-3 array whose middle element is missing."""
    return pd.array([make_datum(), np.nan, make_datum()], dtype=dtype)
@pytest.fixture
def na_value():
    """Provide the scalar used for missing values (NaN)."""
    return np.nan
@pytest.fixture
def na_cmp():
    """Provide a two-argument callable that is true when both are NaN."""

    def cmp(a, b):
        return np.isnan(a) and np.isnan(b)

    return cmp
@pytest.fixture(params=[True, False])
def box_in_series(request):
    """Whether to box the data in a Series"""
    return request.param
@pytest.fixture(params=[True, False])
def as_array(request):
    """
    Boolean fixture to support ExtensionDtype _from_sequence method testing.
    """
    return request.param
@pytest.fixture(params=["ffill", "bfill"])
def fillna_method(request):
    """
    Parametrized fixture giving method parameters 'ffill' and 'bfill' for
    Series.fillna(method=<method>) testing.
    """
    return request.param
@pytest.fixture
def invalid_scalar(data):
    """
    A scalar that *cannot* be held by this ExtensionArray.

    The default should work for most subclasses, but is not guaranteed.

    If the array can hold any item (i.e. object dtype), then use pytest.skip.
    """
    return object.__new__(object)
class TestDtype(base.BaseDtypeTests):
    """Run pandas' shared dtype tests with this module's fixtures."""


class TestInterface(base.BaseInterfaceTests):
    """Run pandas' shared interface tests with this module's fixtures."""


class TestConstructors(base.BaseConstructorsTests):
    """Run pandas' shared constructor tests with this module's fixtures."""


class TestGetitem(base.BaseGetitemTests):
    """Run pandas' shared getitem tests with this module's fixtures."""


class TestSetitem(base.BaseSetitemTests):
    """Run pandas' shared setitem tests with this module's fixtures."""


class TestIndex(base.BaseIndexTests):
    """Run pandas' shared index tests with this module's fixtures."""


class TestMissing(base.BaseMissingTests):
    """Run pandas' shared missing-data tests with this module's fixtures."""
def test_to_pandas():
    """Convert an Arrow table with an ObjectId column to pandas."""
    schema = pa.schema([("data", ObjectIdType())])
    # NOTE(review): the closing arguments of this call were missing from the
    # recovered source; passing the schema built above is assumed.
    table = pa.Table.from_pydict(
        {"data": [ObjectId().binary, ObjectId().binary, ObjectId().binary, None]},
        schema=schema,
    )
    df = table.to_pandas()
    # The leftover ``pdb.set_trace()`` would hang unattended runs, so check
    # the conversion result instead.
    assert len(df) == 4
    assert isinstance(df["data"].dtype, PandasObjectId)
Copy link

blink1073 commented Jan 14, 2023

Rollout plan:

  • Switch to pytest so we can use the pandas fixtures
  • Implement Binary extension types
  • Implement ObjectId pandas extension type
  • Implement Decimal128 as binary(16), with as_py returning Decimal128; add a pandas Decimal128 extension type; and deprecate Decimal128Str
  • Implement the other types: DBRef, Code, Int64, MaxKey, MinKey, Regex, Timestamp
  • Consider using the ArrowExtensionArray as the base class once it becomes stable

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment