Last active
December 21, 2020 05:56
-
-
Save Eastsun/a59fb0438f65e8643cd61d8c98ec4c08 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
from __future__ import annotations | |
from pandas.api.extensions import ExtensionDtype, ExtensionArray | |
from typing import Optional, Union, Dict, Tuple, Any, Sequence | |
import numpy as np | |
import re | |
class VectorDtype(ExtensionDtype):
    """Extension dtype for arrays whose elements are fixed-length vectors.

    A dtype is identified by the NumPy dtype of the vector components
    (``subtype``) and the number of components per vector (``size``).
    Instances are interned in ``_cache``: constructing the same
    ``(subtype, size)`` pair twice returns the same object.
    """

    _metadata = 'subtype', 'size'
    # Interning table: one shared instance per (subtype, size) pair.
    _cache: Dict[Tuple[np.dtype, int], 'VectorDtype'] = {}
    # Accepts e.g. "vector[float64, 3]" or "Vector[int32,4]".
    _pattern = re.compile(r'[Vv]ector\[(?P<subtype>[^\s,]+),\s*(?P<size>\d+)\]')

    def __new__(cls, subtype: Union['VectorDtype', np.dtype], size: Optional[int] = None) -> 'VectorDtype':
        """Return the interned dtype for ``(subtype, size)``.

        Parameters
        ----------
        subtype
            An existing ``VectorDtype`` (returned unchanged) or anything
            ``np.dtype`` accepts (``np.dtype`` instance, scalar type, name).
        size
            Number of components per vector.
        """
        if isinstance(subtype, VectorDtype):
            return subtype
        # Normalise so 'float64', np.float64 and np.dtype('float64') all map
        # to the same cache entry and ``subtype`` is always a real np.dtype.
        subtype = np.dtype(subtype)
        key = (subtype, size)
        try:
            return cls._cache[key]
        except KeyError:
            instance = super().__new__(cls)
            instance._subtype = subtype
            instance._size = size
            cls._cache[key] = instance
            return instance

    def __reduce__(self):
        """Rebuild through the constructor when pickling or copying.

        The default protocol would call ``__new__`` without the required
        ``subtype`` argument; going through the constructor also keeps the
        restored object interned.
        """
        return type(self), (self._subtype, self._size)

    @classmethod
    def construct_from_string(cls, string: str) -> 'VectorDtype':
        """Parse a name such as ``'vector[float64, 3]'`` into a dtype.

        Raises
        ------
        TypeError
            If ``string`` is not a ``str``, does not match the expected
            pattern, or names a subtype NumPy cannot interpret.
        """
        match = cls._pattern.fullmatch(string) if isinstance(string, str) else None
        if match is None:
            raise TypeError(f'Could not construct VectorDtype from: {string}')
        try:
            subtype = np.dtype(match.group('subtype'))
        except (TypeError, ValueError) as err:
            # Always surface parsing failures as TypeError for callers that
            # probe candidate dtypes by catching it.
            raise TypeError(f'Could not construct VectorDtype from: {string}') from err
        return cls(subtype=subtype, size=int(match.group('size')))

    @classmethod
    def construct_array_type(cls):
        """Return the array class associated with this dtype."""
        return VectorArray

    @property
    def type(self) -> type:
        """Scalar type of one element: a NumPy array holding one vector."""
        return np.ndarray

    @property
    def kind(self) -> str:
        """NumPy-style kind code; reported as object (``'O'``)."""
        return 'O'

    @property
    def name(self) -> str:
        """String form of the dtype, parseable by ``construct_from_string``."""
        return f'vector[{self.subtype}, {self.size}]'

    @property
    def subtype(self) -> np.dtype:
        """NumPy dtype of the individual vector components."""
        return self._subtype

    @property
    def size(self) -> int:
        """Number of components in each vector."""
        return self._size

    @property
    def na_value(self) -> np.ndarray:
        """Vector used to represent a missing element.

        All-NaN for floating subtypes, all ``-1`` for every other subtype.
        """
        if np.issubdtype(self.subtype, np.floating):
            return np.full(self.size, fill_value=np.nan, dtype=self.subtype)
        return np.full(self.size, fill_value=-1, dtype=self.subtype)
class VectorArray(ExtensionArray):
    """Extension array storing one fixed-length vector per element.

    The data is held as a 2-D NumPy array: one row per element, one column
    per vector component. The array is read-only through the pandas
    interface (``__setitem__`` is not implemented).
    """

    def __init__(self, data: np.ndarray, dtype: Optional[Union[str, 'VectorDtype']] = None, copy: bool = False):
        """Wrap a 2-D array.

        Parameters
        ----------
        data
            2-D ``np.ndarray``; rows are elements, columns are components.
        dtype
            ``None`` (inferred from ``data``), a dtype name string, or a
            ``VectorDtype``. Must agree with ``data``'s dtype and width.
        copy
            Store a copy of ``data`` instead of a reference.

        Raises
        ------
        ValueError
            If ``data`` is not a 2-D array, ``dtype`` has an unsupported
            type, or ``dtype`` disagrees with ``data``.
        """
        if not (isinstance(data, np.ndarray) and data.ndim == 2):
            raise ValueError('Need a 2D array')
        if dtype is None:
            self._dtype = VectorDtype(data.dtype, data.shape[1])
        elif isinstance(dtype, str):
            self._dtype = VectorDtype.construct_from_string(dtype)
        elif isinstance(dtype, VectorDtype):
            self._dtype = dtype
        else:
            raise ValueError(f'Un-compatible dtype of VectorArray: {dtype}')
        if not (data.dtype == self.dtype.subtype and data.shape[1] == self.dtype.size):
            raise ValueError(f'Un-compatible dtype: {self.dtype} vs ({data.dtype}, {data.shape})')
        self._data = data.copy() if copy else data

    def copy(self, deep: bool = False) -> 'VectorArray':
        """Return a new array backed by a copy of the data.

        ``deep`` is accepted for call compatibility and ignored; the data is
        always copied.
        """
        return VectorArray(self._data, dtype=self.dtype, copy=True)

    @classmethod
    def _from_sequence(cls, scalars, dtype: Optional['VectorDtype'] = None, copy: bool = False) -> 'VectorArray':
        """Build an array from a sequence of equal-length vectors.

        ``dtype`` may be ``None``, a ``VectorDtype`` or its string name.
        """
        if isinstance(dtype, str):
            dtype = VectorDtype.construct_from_string(dtype)
        subtype = None if dtype is None else dtype.subtype
        # asarray converts without forcing a copy; np.array(..., copy=False)
        # raises on NumPy >= 2 whenever a copy cannot be avoided (e.g. lists).
        # The constructor performs the single copy when ``copy`` is requested.
        data = np.asarray(scalars, dtype=subtype)
        return cls(data, dtype=dtype, copy=copy)

    @classmethod
    def _concat_same_type(cls, to_concat: Sequence['VectorArray']) -> 'VectorArray':
        """Concatenate several arrays row-wise into a new array."""
        data = np.concatenate([va._data for va in to_concat])
        return cls(data, copy=False)

    @property
    def dtype(self) -> 'VectorDtype':
        """The ``VectorDtype`` describing this array."""
        return self._dtype

    @property
    def nbytes(self) -> int:
        """Bytes consumed by the underlying 2-D array."""
        return self._data.nbytes

    def isna(self) -> np.ndarray:
        """Return a boolean mask, one entry per element, marking missing rows.

        A row is missing when any component is NaN (floating subtypes) or
        negative (all other subtypes).
        """
        if np.issubdtype(self.dtype.subtype, np.floating):
            return np.isnan(self._data).any(axis=1)
        return (self._data < 0).any(axis=1)

    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any) -> None:
        """Item assignment is not supported."""
        raise NotImplementedError()

    def __getitem__(self, item: Union[int, slice, np.ndarray]) -> Union[np.ndarray, 'VectorArray']:
        """Return one vector for an integer index, else a sub-array."""
        # NumPy integer scalars (e.g. np.int64) are not ``int`` instances but
        # must be treated as scalar positions too; otherwise the row lookup
        # below yields a 1-D array and the constructor rejects it.
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        return VectorArray(self._data[item, :], dtype=self.dtype, copy=False)

    def take(self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None) -> 'VectorArray':
        """Select rows by position.

        Parameters
        ----------
        indices
            Integer positions. With ``allow_fill`` the value ``-1`` marks a
            slot to be filled; otherwise negatives index from the end.
        allow_fill
            Enable the ``-1`` fill convention.
        fill_value
            Vector (or scalar, broadcast over the components) written into
            filled slots; defaults to the dtype's ``na_value``.

        Raises
        ------
        ValueError
            If ``allow_fill`` is set and an index is below ``-1``.
        IndexError
            If ``allow_fill`` is set, the array is empty and some index is
            not ``-1``; or if an index is out of bounds.
        """
        indices = np.asarray(indices, dtype='int')
        if not allow_fill:
            return VectorArray(self._data.take(indices, axis=0), dtype=self.dtype, copy=False)

        if fill_value is None:
            fill = self.dtype.na_value
        else:
            fill = np.asarray(fill_value, dtype=self.dtype.subtype)
        if (indices < -1).any():
            raise ValueError('Invalid value in `indices`, must be all >= -1 for `allow_fill` is True')
        mask = indices == -1

        if len(self) == 0:
            if not mask.all():
                raise IndexError('Invalid take for empty VectorArray, must be all -1.')
            # Allocate the 2-D shape explicitly so the result stays valid
            # even for zero indices or a scalar fill value.
            data = np.empty((len(indices), self.dtype.size), dtype=self.dtype.subtype)
            data[...] = fill
            return VectorArray(data, dtype=self.dtype, copy=False)

        # ndarray.take returns a fresh array, so overwriting the slots that
        # were requested as -1 does not touch the original data.
        took = self._data.take(indices, axis=0)
        took[mask] = fill
        return VectorArray(took, dtype=self.dtype, copy=False)

    def __len__(self) -> int:
        """Number of elements (rows)."""
        return len(self._data)

    def __eq__(self, other) -> np.ndarray:
        """Element-wise equality against another ``VectorArray``.

        Two elements are equal when all their components match and neither
        is missing. Returns ``NotImplemented`` for other operand types.

        Raises
        ------
        ValueError
            If the two arrays have different vector sizes.
        """
        if not isinstance(other, VectorArray):
            return NotImplemented
        if self.dtype.size != other.dtype.size:
            raise ValueError('The size of two VectorArray does not compatible')
        isna = self.isna() | other.isna()
        iseq = (self._data == other._data).all(axis=1)
        return iseq & ~ isna
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment