Skip to content

Instantly share code, notes, and snippets.

@Finndersen
Last active January 17, 2023 17:46
Show Gist options
  • Save Finndersen/353747470f375c7c17147214f357af7a to your computer and use it in GitHub Desktop.
Save Finndersen/353747470f375c7c17147214f357af7a to your computer and use it in GitHub Desktop.
Definition of VectorArray for Pandas Extension Types example
class VectorArray(ExtensionScalarOpsMixin, ExtensionArray):
"""
Custom Extension Array type for an array of Vectors
Needs to define:
- Associated Dtype it is used with
- How to construct array from sequence of scalars
- How data is stored and accessed
- Any custom array methods
"""
def __init__(self, x_values, y_values, copy=False):
"""
Initialise array of vectors from component X and Y values
(Allows efficient initialisation from existing lists/arrays)
:param x_values: Sequence/array of vector x-component values
:param y_values: Sequence/array of vector y-component values
"""
self.x_values = np.array(x_values, dtype=np.float64, copy=copy)
self.y_values = np.array(y_values, dtype=np.float64, copy=copy)
@classmethod
def _from_sequence(cls, scalars, *, dtype=None, copy=False):
"""
Construct a new ExtensionArray from a sequence of scalars.
Each element will be an instance of the scalar type for this array,
or be converted into this type in this method.
"""
# Construct new array from sequence of values (Unzip vectors into x and y components)
x_values, y_values = zip(*[create_vector(val).as_tuple() for val in scalars])
return VectorArray(x_values, y_values, copy=copy)
@classmethod
def from_vectors(cls, vectors):
"""
Construct array from sequence of values (vectors)
Can be provided as Vector instances or list/tuple like (x, y) pairs
"""
return cls._from_sequence(vectors)
@classmethod
def _concat_same_type(cls, to_concat):
"""
Concatenate multiple arrays of this dtype
"""
return VectorArray(
np.concatenate(arr.x_values for arr in to_concat),
np.concatenate(arr.y_values for arr in to_concat),
)
@property
def dtype(self):
"""
Return Dtype instance (not class) associated with this Array
"""
return VectorDtype()
@property
def nbytes(self):
"""
The number of bytes needed to store this object in memory.
"""
return self.x_values.nbytes + self.y_values.nbytes
def __getitem__(self, item):
"""
Retrieve single item or slice
"""
if isinstance(item, int):
# Get single vector
return Vector(self.x_values[item], self.y_values[item])
else:
# Get subset from slice or boolean array
return VectorArray(self.x_values[item], self.y_values[item])
def __eq__(self, other):
"""
Perform element-wise equality with a given vector value
"""
if isinstance(other, (pd.Index, pd.Series, pd.DataFrame)):
return NotImplemented
return (self.x_values == other[0]) & (self.y_values == other[1])
def __len__(self):
return self.x_values.size
def isna(self):
"""
Returns a 1-D array indicating if each value is missing
"""
return np.isnan(self.x_values)
def take(self, indices, *, allow_fill=False, fill_value=None):
"""
Take element from array using positional indexing
"""
from pandas.core.algorithms import take
if allow_fill and fill_value is None:
fill_value = self.dtype.na_value
x_result = take(self.x_values, indices, fill_value=fill_value, allow_fill=allow_fill)
y_result = take(self.y_values, indices, fill_value=fill_value, allow_fill=allow_fill)
return VectorArray(x_result, y_result)
def copy(self):
"""
Return copy of array
"""
return VectorArray(np.copy(self.x_values), np.copy(self.y_values))
def magnitude(self):
"""
Return array of magnitude values for vectors.
"""
# Implement using NumPy vectorised functions for efficiency
return np.sqrt(np.square(self.x_values) + np.square(self.y_values))
def dot(self, other):
"""
Calculate dot product with single Vector or VectorArray of same length
"""
if isinstance(other, Vector):
# Dot product with single Vector
return self.x_values*other.x + self.y_values*other.y
elif isinstance(other, VectorArray) and self.size == other.size:
# Element-wise dot product with other VectorArray
return self.x_values*other.x_values + self.y_values*other.y_values
else:
raise TypeError('Cannot perform dot product with {}'.format(other))
# Register operator overloads using logic defined in Vector class.
# These are classmethods provided by ExtensionScalarOpsMixin; they must run at
# module level after the VectorArray class body has been executed.
# NOTE(review): _add_comparison_ops() presumably installs its own __eq__ and
# other comparison methods on the class, superseding the __eq__ defined in the
# class body - confirm this is intended.
VectorArray._add_arithmetic_ops()
VectorArray._add_comparison_ops()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment