Last active
March 27, 2019 20:04
-
-
Save shoyer/36b84ab064f027df318c0b823558de24 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers | |
import numpy as np | |
def _binary_method(ufunc):
    """Build a forward binary special method (e.g. ``__add__``) for ``ufunc``.

    The generated method returns NotImplemented when ``other`` opts out of
    ufunc dispatch by setting ``__array_ufunc__ = None``, letting Python try
    ``other``'s reflected method instead. Otherwise it forwards
    ``(self, other)`` to ``self.__array_ufunc__`` with method ``'__call__'``.
    """
    def func(self, other):
        # getattr with a default only suppresses AttributeError, so this is
        # equivalent to probing the attribute inside try/except AttributeError.
        opted_out = getattr(other, '__array_ufunc__', NotImplemented) is None
        if opted_out:
            return NotImplemented
        return self.__array_ufunc__(ufunc, '__call__', self, other)
    return func
def _reflected_binary_method(ufunc):
    """Build a reflected binary special method (e.g. ``__radd__``) for ``ufunc``.

    Python invokes ``self.__rop__(other)`` for ``other <op> self``, so the
    operands are handed to ``self.__array_ufunc__`` as ``(other, self)`` to
    keep their original left/right positions.
    """
    def func(self, other):
        operands = (other, self)
        return self.__array_ufunc__(ufunc, '__call__', *operands)
    return func
def _inplace_binary_method(ufunc):
    """Build an in-place special method (e.g. ``__iadd__``) for ``ufunc``.

    The generated method asks ``self.__array_ufunc__`` to write into ``self``
    by passing ``out=(self,)``, and returns whatever that call returns.

    Raises:
        TypeError: if ``self.__array_ufunc__`` returns NotImplemented.
    """
    def func(self, other):
        outcome = self.__array_ufunc__(ufunc, '__call__', self, other,
                                       out=(self,))
        if outcome is not NotImplemented:
            return outcome
        # Returning NotImplemented from __iop__ would let Python silently fall
        # back to the out-of-place __op__/__rop__; raise explicitly instead.
        raise TypeError(
            'unsupported operand types for in-place arithmetic: %s and %s'
            % (type(self).__name__, type(other).__name__))
    return func
def _numeric_methods(ufunc):
    """Return the (forward, reflected, in-place) special methods for ``ufunc``.

    Meant for tuple-unpacking inside a class body, e.g.
    ``__add__, __radd__, __iadd__ = _numeric_methods(np.add)``.
    """
    builders = (_binary_method, _reflected_binary_method,
                _inplace_binary_method)
    return tuple(build(ufunc) for build in builders)
def _unary_method(ufunc):
    """Build a unary special method (e.g. ``__neg__``) for ``ufunc``.

    The generated method forwards its single operand to
    ``self.__array_ufunc__`` with method ``'__call__'``.
    """
    def func(self):
        dispatch = self.__array_ufunc__
        return dispatch(ufunc, '__call__', self)
    return func
class UFuncSpecialMethodMixin(object):
    """Implements all special methods using __array_ufunc__.

    Each operator is forwarded to ``self.__array_ufunc__`` together with the
    ufunc that ``np.ndarray`` uses for the same operator, so a subclass only
    has to define ``__array_ufunc__`` to get comparison, arithmetic, bitwise
    and unary operators. Requires NumPy >= 1.13 (``np.positive``,
    ``np.divmod`` and the ``__array_ufunc__`` protocol itself).
    """
    # comparisons
    __lt__ = _binary_method(np.less)
    __le__ = _binary_method(np.less_equal)
    __eq__ = _binary_method(np.equal)
    __ne__ = _binary_method(np.not_equal)
    __gt__ = _binary_method(np.greater)
    __ge__ = _binary_method(np.greater_equal)

    # numeric methods
    __add__, __radd__, __iadd__ = _numeric_methods(np.add)
    __sub__, __rsub__, __isub__ = _numeric_methods(np.subtract)
    __mul__, __rmul__, __imul__ = _numeric_methods(np.multiply)
    __matmul__, __rmatmul__, __imatmul__ = _numeric_methods(np.matmul)
    __div__, __rdiv__, __idiv__ = _numeric_methods(np.divide)  # Python 2 only
    __truediv__, __rtruediv__, __itruediv__ = _numeric_methods(np.true_divide)
    __floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods(
        np.floor_divide)
    __mod__, __rmod__, __imod__ = _numeric_methods(np.mod)
    # np.divmod has two outputs; Python has no in-place divmod operator.
    __divmod__ = _binary_method(np.divmod)
    __rdivmod__ = _reflected_binary_method(np.divmod)
    # TODO: handle the optional third argument for __pow__?
    __pow__, __rpow__, __ipow__ = _numeric_methods(np.power)
    __lshift__, __rlshift__, __ilshift__ = _numeric_methods(np.left_shift)
    __rshift__, __rrshift__, __irshift__ = _numeric_methods(np.right_shift)
    # &, ^ and | are *bitwise* on ndarray. The logical_* ufuncs would collapse
    # integer operands to booleans (np.logical_and(6, 3) is True, whereas
    # 6 & 3 == np.bitwise_and(6, 3) == 2).
    __and__, __rand__, __iand__ = _numeric_methods(np.bitwise_and)
    __xor__, __rxor__, __ixor__ = _numeric_methods(np.bitwise_xor)
    __or__, __ror__, __ior__ = _numeric_methods(np.bitwise_or)

    # unary methods
    __neg__ = _unary_method(np.negative)
    __pos__ = _unary_method(np.positive)
    __abs__ = _unary_method(np.absolute)
    __invert__ = _unary_method(np.invert)
class ArrayLike(UFuncSpecialMethodMixin):
    """An array-like class that wraps NumPy arrays.

    Example usage:

    >>> x = ArrayLike([1, 2, 3])
    >>> x - 1
    ArrayLike(array([0, 1, 2]))
    >>> 1 - x
    ArrayLike(array([ 0, -1, -2]))
    >>> np.arange(3) - x
    ArrayLike(array([-1, -1, -1]))
    >>> x - np.arange(3)
    ArrayLike(array([1, 1, 1]))
    """

    def __init__(self, value):
        # Store a plain ndarray so ufuncs can be applied to it directly.
        self.value = np.asarray(value)

    __array_priority__ = 1000  # for legacy reasons
    # Operand types (besides ArrayLike instances) this class knows how to
    # combine with.
    _handled_types = (np.ndarray, numbers.Number)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``ufunc`` to the unwrapped operands and re-wrap the result.

        ArrayLike implements arithmetic and ufuncs by deferring to the
        wrapped array. Returns NotImplemented when any input or ``out``
        operand is neither an instance of ``_handled_types`` nor an ArrayLike
        of ``self``'s own class or one of its superclasses. This lets other
        types, including ArrayLike subclasses, take over the operation.
        """
        out = kwargs.get('out')
        # NumPy (and the operator mixin) always pass ``out`` as a tuple.
        outputs = () if out is None else tuple(out)

        def is_peer(x):
            # ``isinstance(self, type(x))`` alone is also true for unrelated
            # base classes such as ``object``, whose instances have no
            # ``.value``; require an ArrayLike before treating x as wrapped.
            return isinstance(x, ArrayLike) and isinstance(self, type(x))

        for x in inputs + outputs:
            if not (isinstance(x, self._handled_types) or is_peer(x)):
                return NotImplemented

        def unwrap(x):
            return x.value if is_peer(x) else x

        inputs = tuple(unwrap(x) for x in inputs)
        if out is not None:
            kwargs['out'] = tuple(unwrap(x) for x in outputs)

        result = getattr(ufunc, method)(*inputs, **kwargs)
        if method == 'at':
            # ufunc.at operates in place and returns None: nothing to wrap.
            return None
        if isinstance(result, tuple):
            # Multiple-output ufuncs (e.g. np.divmod) return a tuple.
            return tuple(type(self)(x) for x in result)
        return type(self)(result)

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.value)
class OptOut(object):
    """Operand that opts out of ufunc-based operator overriding.

    With ``__array_ufunc__ = None`` the mixin's forward binary methods return
    NotImplemented for this operand, so Python ends up calling this class's
    own ``__add__``/``__radd__``. Each one simply reports its own name.
    """
    __array_ufunc__ = None

    def __radd__(self, other):
        return '__radd__'

    def __add__(self, other):
        return '__add__'
class SubArrayLike(ArrayLike):
    """ArrayLike subclass that adds no behaviour of its own.

    Presumably here to exercise operations that mix ArrayLike and a
    subclass of it.
    """
# TODO: test that all combinations of arithmetic between numbers, np.ndarray,
# ArrayLike, SubArrayLike and OptOut yield the correct result.
I like `NDArrayOperatorsMixin` for the name. Let us know if you have any suggestions about where these files should end up in numpy... I'll look into putting together a PR.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Given the discussion with @pv about what to handle, etc., I wonder whether we should expose the functions in `binop_override.h`.

...In this respect, in earlier discussions we had our example array class have a specific `_can_handle` method -- I think this is better as it is easier to override (and it is somewhat analogous to a `ShapedLikeNDArray` mixin class I wrote for `astropy`, where I pass all `reshape`, `transpose`, etc., methods through a single `_apply` method; see https://github.com/astropy/astropy/blob/master/astropy/utils/misc.py#L856).