Skip to content

Instantly share code, notes, and snippets.

@Eastsun
Last active December 21, 2020 05:56
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Eastsun/a59fb0438f65e8643cd61d8c98ec4c08 to your computer and use it in GitHub Desktop.
Save Eastsun/a59fb0438f65e8643cd61d8c98ec4c08 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
from __future__ import annotations
from pandas.api.extensions import ExtensionDtype, ExtensionArray
from typing import Optional, Union, Dict, Tuple, Any, Sequence
import numpy as np
import re
class VectorDtype(ExtensionDtype):
    """Extension dtype whose scalars are fixed-length 1-D numpy vectors.

    A dtype is identified by ``subtype`` (the numpy dtype of the vector
    components) and ``size`` (the number of components).  Instances are
    interned: building the same ``(subtype, size)`` pair twice returns the
    same object.

    The string form is ``vector[<subtype>, <size>]`` (for example
    ``vector[float64, 3]``); a capital ``V`` is also accepted when parsing.
    """

    # Attribute names that define the identity of a dtype instance.
    _metadata = 'subtype', 'size'
    # Interned instances, keyed by (normalized subtype, size).
    _cache: Dict[Tuple[np.dtype, int], VectorDtype] = {}
    _pattern = re.compile(r'[Vv]ector\[(?P<subtype>[^\s,]+),\s*(?P<size>\d+)\]')

    def __new__(cls, subtype: Union[VectorDtype, np.dtype], size: Optional[int] = None) -> VectorDtype:
        """Return the interned dtype for ``(subtype, size)``.

        Parameters
        ----------
        subtype
            Either an existing ``VectorDtype`` (returned unchanged, ``size``
            is ignored) or anything ``np.dtype`` accepts (a ``np.dtype``, a
            scalar type such as ``np.float64``, or a string like
            ``'float64'``).
        size
            Number of components per vector.

        Notes
        -----
        ``subtype`` is normalized with ``np.dtype`` before it is stored or
        used as a cache key, so equivalent spellings share one instance and
        ``self.subtype`` is always a real ``np.dtype``.  Any error raised by
        ``np.dtype`` for an unrecognised ``subtype`` propagates.
        """
        if isinstance(subtype, VectorDtype):
            return subtype
        key = (np.dtype(subtype), size)
        instance = cls._cache.get(key)
        if instance is None:
            instance = super().__new__(cls)
            instance._subtype, instance._size = key
            cls._cache[key] = instance
        return instance

    def __reduce__(self):
        """Support pickling/copying by rebuilding through the constructor.

        ``__new__`` requires ``subtype``, so a reconstruction that calls
        ``__new__`` without arguments cannot work.  Going through the
        constructor also hands back the interned instance.
        """
        return self.__class__, (self._subtype, self._size)

    @classmethod
    def construct_from_string(cls, string: str) -> VectorDtype:
        """Parse ``'vector[<subtype>, <size>]'`` into a ``VectorDtype``.

        Raises
        ------
        TypeError
            If ``string`` is not a ``str`` or does not match the expected
            pattern in full.
        """
        match = cls._pattern.fullmatch(string) if isinstance(string, str) else None
        if match is None:
            raise TypeError(f'Could not construct VectorDtype from: {string}')
        return cls(subtype=np.dtype(match.group('subtype')), size=int(match.group('size')))

    @classmethod
    def construct_array_type(cls):
        """Return the array class associated with this dtype."""
        return VectorArray

    @property
    def type(self) -> type:
        """Scalar type of one element: a numpy array."""
        return np.ndarray

    @property
    def kind(self) -> str:
        """Numpy-style kind code; always ``'O'``."""
        return 'O'

    @property
    def name(self) -> str:
        """String form of the dtype, e.g. ``'vector[float64, 3]'``."""
        return f'vector[{self.subtype}, {self.size}]'

    @property
    def subtype(self) -> np.dtype:
        """Numpy dtype of the vector components."""
        return self._subtype

    @property
    def size(self) -> int:
        """Number of components in each vector."""
        return self._size

    @property
    def na_value(self) -> np.ndarray:
        """Vector used to represent a missing element.

        All-NaN for floating subtypes, all ``-1`` otherwise.
        """
        # NOTE(review): the -1 sentinel is only meaningful for signed
        # subtypes; unsigned/bool subtypes have no usable marker -- confirm
        # whether those subtypes are meant to be supported.
        if np.issubdtype(self.subtype, np.floating):
            return np.full(self.size, fill_value=np.nan, dtype=self.subtype)
        return np.full(self.size, fill_value=-1, dtype=self.subtype)
class VectorArray(ExtensionArray):
    """Extension array holding one fixed-length vector per element.

    The values live in a single 2-D numpy array of shape
    ``(n_elements, vector_size)``; row ``i`` is element ``i``.  Missing
    elements are encoded in-band: a row counts as missing when it contains a
    NaN (floating subtypes) or a negative value (any other subtype).
    """

    def __init__(self, data: np.ndarray, dtype: Union[str, VectorDtype, None] = None, copy: bool = False):
        """Wrap a 2-D numpy array.

        Parameters
        ----------
        data
            2-D array; each row is one vector.
        dtype
            ``None`` (inferred from ``data``), a dtype string such as
            ``'vector[float64, 3]'``, or a ``VectorDtype``.
        copy
            Copy ``data`` instead of keeping a reference to it.

        Raises
        ------
        ValueError
            If ``data`` is not a 2-D ndarray, ``dtype`` has an unsupported
            type, or ``dtype`` disagrees with the component dtype / row
            length of ``data``.
        """
        if not (isinstance(data, np.ndarray) and data.ndim == 2):
            raise ValueError('Need a 2D array')
        if dtype is None:
            dtype = VectorDtype(data.dtype, data.shape[1])
        elif isinstance(dtype, str):
            dtype = VectorDtype.construct_from_string(dtype)
        elif not isinstance(dtype, VectorDtype):
            raise ValueError(f'Un-compatible dtype of VectorArray: {dtype}')
        if data.dtype != dtype.subtype or data.shape[1] != dtype.size:
            raise ValueError(f'Un-compatible dtype: {dtype} vs ({data.dtype}, {data.shape})')
        self._dtype = dtype
        self._data = data.copy() if copy else data

    def copy(self, deep: bool = False) -> VectorArray:
        """Return a new array backed by a copy of the data.

        ``deep`` is accepted for signature compatibility and ignored; the
        underlying buffer is always copied.
        """
        return VectorArray(self._data, dtype=self.dtype, copy=True)

    @classmethod
    def _from_sequence(cls, scalars, dtype: Optional[VectorDtype] = None, copy: bool = False) -> VectorArray:
        """Build an array from a sequence of equal-length vectors.

        ``dtype`` may be ``None``, a ``VectorDtype`` or its string form.  An
        empty ``scalars`` is supported when ``dtype`` is given, since the row
        length is then known.
        """
        if isinstance(dtype, str):
            dtype = VectorDtype.construct_from_string(dtype)
        if isinstance(scalars, VectorArray):
            scalars = scalars._data
        # np.asarray copies only when it has to.  Passing ``copy=False`` to
        # np.array would *forbid* copying on NumPy >= 2 and raise for inputs
        # such as lists.  The constructor below honours ``copy=True``.
        data = np.asarray(scalars, dtype=None if dtype is None else dtype.subtype)
        if dtype is not None and data.ndim == 1 and data.size == 0:
            data = data.reshape(0, dtype.size)
        return cls(data, dtype=dtype, copy=copy)

    @classmethod
    def _concat_same_type(cls, to_concat: Sequence[VectorArray]) -> VectorArray:
        """Concatenate arrays of the same dtype row-wise."""
        data = np.concatenate([va._data for va in to_concat])
        return cls(data, copy=False)

    @property
    def dtype(self) -> VectorDtype:
        """The ``VectorDtype`` of this array."""
        return self._dtype

    @property
    def nbytes(self) -> int:
        """Bytes used by the underlying numpy buffer."""
        return self._data.nbytes

    def isna(self) -> np.ndarray:
        """Return a boolean mask, one entry per element, flagging missing rows.

        A row is missing if any component is NaN (floating subtypes) or any
        component is negative (other subtypes).
        """
        if np.issubdtype(self.dtype.subtype, np.floating):
            return np.isnan(self._data).any(axis=1)
        return (self._data < 0).any(axis=1)

    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any) -> None:
        """Item assignment is not supported."""
        raise NotImplementedError()

    def __getitem__(self, item: Union[int, slice, np.ndarray]) -> Union[np.ndarray, VectorArray]:
        """Return one vector (integer index) or a sub-array (anything else).

        Numpy integer scalars are treated like Python ints; without that they
        would be routed to the slice path and produce a 1-D result that the
        constructor rejects.
        """
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        return VectorArray(self._data[item, :], dtype=self.dtype, copy=False)

    def take(self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None) -> VectorArray:
        """Select rows by position.

        Parameters
        ----------
        indices
            Row positions.  With ``allow_fill=False`` they follow numpy's
            ``take`` semantics (negative values count from the end).  With
            ``allow_fill=True`` the value ``-1`` marks a row to be filled and
            any other negative value is rejected.
        allow_fill
            Whether ``-1`` means "missing".
        fill_value
            Vector (or scalar, broadcast to a full row) written into filled
            rows; defaults to ``self.dtype.na_value``.

        Raises
        ------
        ValueError
            ``allow_fill`` is true and an index is smaller than ``-1``.
        IndexError
            ``allow_fill`` is true, the array is empty and some index is not
            ``-1``.
        """
        indices = np.asarray(indices, dtype='int')
        if not allow_fill:
            return VectorArray(self._data.take(indices, axis=0), dtype=self.dtype, copy=False)

        if fill_value is None:
            fill = self.dtype.na_value
        else:
            fill = np.asarray(fill_value, dtype=self.dtype.subtype)
        if (indices < -1).any():
            raise ValueError('Invalid value in `indices`, must be all >= -1 for `allow_fill` is True')
        mask = indices == -1

        if len(self) == 0:
            if not mask.all():
                raise IndexError('Invalid take for empty VectorArray, must be all -1.')
            # Allocate with an explicit 2-D shape so that an empty ``indices``
            # still yields a valid (0, size) array.
            filled = np.empty((len(indices), self._data.shape[1]), dtype=self._data.dtype)
            filled[...] = fill
            return VectorArray(filled, dtype=self.dtype, copy=False)

        # ndarray.take returns a fresh array, so filling in place is safe.
        took = self._data.take(indices, axis=0)
        if mask.any():
            took[mask] = fill
        return VectorArray(took, dtype=self.dtype, copy=False)

    def __len__(self) -> int:
        """Number of elements (rows)."""
        return len(self._data)

    def __eq__(self, other) -> np.ndarray:
        """Element-wise equality against another ``VectorArray``.

        Two elements are equal when all components match and neither element
        is missing.  Returns ``NotImplemented`` for other operand types.

        Raises
        ------
        ValueError
            If the two arrays have different vector sizes.
        """
        if not isinstance(other, VectorArray):
            return NotImplemented
        if self.dtype.size != other.dtype.size:
            raise ValueError('The size of two VectorArray does not compatible')
        either_missing = self.isna() | other.isna()
        all_equal = (self._data == other._data).all(axis=1)
        return all_equal & ~either_missing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment