Skip to content

Instantly share code, notes, and snippets.

@shoyer
Last active March 27, 2019 20:04
Show Gist options
  • Save shoyer/36b84ab064f027df318c0b823558de24 to your computer and use it in GitHub Desktop.
Save shoyer/36b84ab064f027df318c0b823558de24 to your computer and use it in GitHub Desktop.
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import numpy as np
def _binary_method(ufunc):
    """Build a forward binary special method (e.g. ``__add__``) for ``ufunc``.

    The generated method returns NotImplemented when ``other`` has opted out
    of ufuncs by setting ``__array_ufunc__ = None``, letting Python try
    ``other``'s reflected method. Otherwise it forwards the operands, in the
    order ``(self, other)``, to ``self.__array_ufunc__`` with method
    ``'__call__'``.
    """
    missing = object()  # sentinel: ``other`` has no __array_ufunc__ at all

    def forward(self, other):
        opted_out = getattr(other, '__array_ufunc__', missing) is None
        if opted_out:
            return NotImplemented
        return self.__array_ufunc__(ufunc, '__call__', self, other)

    return forward
def _reflected_binary_method(ufunc):
    """Build a reflected binary special method (e.g. ``__radd__``) for ``ufunc``.

    Python evaluates ``other + self`` as ``self.__radd__(other)`` when the
    forward method is unavailable. The operands are therefore forwarded to
    ``self.__array_ufunc__`` in the order ``(other, self)``, with method
    ``'__call__'``.

    Like the forward method, this returns NotImplemented when ``other`` has
    opted out of ufuncs by setting ``__array_ufunc__ = None`` (NEP 13). The
    reflected operation is then refused instead of being pushed through
    ``__array_ufunc__`` anyway.
    """
    def func(self, other):
        try:
            opted_out = other.__array_ufunc__ is None
        except AttributeError:
            # ``other`` knows nothing about ufuncs; it has not opted out.
            opted_out = False
        if opted_out:
            return NotImplemented
        return self.__array_ufunc__(ufunc, '__call__', other, self)
    return func
def _inplace_binary_method(ufunc):
    """Build an in-place special method (e.g. ``__iadd__``) for ``ufunc``.

    The generated method calls ``self.__array_ufunc__`` with the operands
    ``(self, other)`` and ``out=(self,)``, and returns whatever it produces.
    A NotImplemented answer is not passed back to Python. Instead it becomes
    a TypeError naming both operand types, so Python does not fall back to
    ``__add__``/``__radd__`` for an in-place operation.
    """
    def inplace(self, other):
        outcome = self.__array_ufunc__(ufunc, '__call__', self, other,
                                       out=(self,))
        if outcome is not NotImplemented:
            return outcome
        template = ('unsupported operand types for in-place '
                    'arithmetic: %s and %s')
        raise TypeError(template % (type(self).__name__,
                                    type(other).__name__))
    return inplace
def _numeric_methods(ufunc):
    """Return the ``(forward, reflected, in-place)`` special methods for ``ufunc``.

    Meant for class bodies, e.g.
    ``__add__, __radd__, __iadd__ = _numeric_methods(np.add)``.
    """
    factories = (_binary_method,
                 _reflected_binary_method,
                 _inplace_binary_method)
    return tuple(make(ufunc) for make in factories)
def _unary_method(ufunc):
    """Build a unary special method (e.g. ``__neg__``) that applies ``ufunc``.

    The generated method forwards its single operand to
    ``self.__array_ufunc__`` with method ``'__call__'`` and returns the
    result unchanged.
    """
    def apply(self):
        dispatch = self.__array_ufunc__
        return dispatch(ufunc, '__call__', self)
    return apply
class UFuncSpecialMethodMixin(object):
    """Implements all special methods using __array_ufunc__.

    A subclass only needs to define ``__array_ufunc__``. Each comparison,
    arithmetic, bitwise and unary operator below forwards to it together with
    the NumPy ufunc that ``np.ndarray`` uses for the same Python operator.
    Forward/reflected methods come from ``_binary_method`` and
    ``_reflected_binary_method``; in-place ones from
    ``_inplace_binary_method``.
    """

    # comparisons
    __lt__ = _binary_method(np.less)
    __le__ = _binary_method(np.less_equal)
    __eq__ = _binary_method(np.equal)
    __ne__ = _binary_method(np.not_equal)
    __gt__ = _binary_method(np.greater)
    __ge__ = _binary_method(np.greater_equal)

    # numeric methods
    __add__, __radd__, __iadd__ = _numeric_methods(np.add)
    __sub__, __rsub__, __isub__ = _numeric_methods(np.subtract)
    __mul__, __rmul__, __imul__ = _numeric_methods(np.multiply)
    __matmul__, __rmatmul__, __imatmul__ = _numeric_methods(np.matmul)
    __div__, __rdiv__, __idiv__ = _numeric_methods(np.divide)  # Python 2 only
    __truediv__, __rtruediv__, __itruediv__ = _numeric_methods(np.true_divide)
    __floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods(
        np.floor_divide)
    __mod__, __rmod__, __imod__ = _numeric_methods(np.mod)
    # np.divmod (available since NumPy 1.13, like __array_ufunc__ itself)
    # returns a (quotient, remainder) pair; Python has no in-place divmod.
    __divmod__ = _binary_method(np.divmod)
    __rdivmod__ = _reflected_binary_method(np.divmod)
    # TODO: handle the optional third argument for __pow__?
    __pow__, __rpow__, __ipow__ = _numeric_methods(np.power)
    __lshift__, __rlshift__, __ilshift__ = _numeric_methods(np.left_shift)
    __rshift__, __rrshift__, __irshift__ = _numeric_methods(np.right_shift)
    # &, ^ and | are *bitwise* operators on np.ndarray (e.g. 6 & 3 == 2), so
    # they map to the bitwise ufuncs; np.logical_* would return booleans.
    __and__, __rand__, __iand__ = _numeric_methods(np.bitwise_and)
    __xor__, __rxor__, __ixor__ = _numeric_methods(np.bitwise_xor)
    __or__, __ror__, __ior__ = _numeric_methods(np.bitwise_or)

    # unary methods
    __neg__ = _unary_method(np.negative)
    __pos__ = _unary_method(np.positive)  # np.positive: NumPy >= 1.13
    __abs__ = _unary_method(np.absolute)
    __invert__ = _unary_method(np.invert)
class ArrayLike(UFuncSpecialMethodMixin):
    """An array-like class that wraps NumPy arrays.

    Operators (inherited from UFuncSpecialMethodMixin) and NumPy ufuncs are
    implemented by unwrapping operands to their ``.value`` arrays, applying
    the ufunc, and wrapping the result back in ``type(self)``.

    Example usage:
    >>> x = ArrayLike([1, 2, 3])
    >>> x - 1
    ArrayLike(array([0, 1, 2]))
    >>> 1 - x
    ArrayLike(array([ 0, -1, -2]))
    >>> np.arange(3) - x
    ArrayLike(array([-1, -1, -1]))
    >>> x - np.arange(3)
    ArrayLike(array([1, 1, 1]))
    """

    __array_priority__ = 1000  # for legacy reasons
    # Operand types (besides unwrappable ArrayLikes) that __array_ufunc__
    # passes through to the ufunc unchanged.
    _handled_types = (np.ndarray, numbers.Number)

    def __init__(self, value):
        # np.asarray does not copy when ``value`` is already an ndarray.
        self.value = np.asarray(value)

    def _can_unwrap(self, x):
        """Return True if ``x`` is an ArrayLike whose ``.value`` we may use.

        ``x`` qualifies when it is an ArrayLike and ``self`` is an instance
        of ``type(x)``, i.e. ``x`` is of this class or an ArrayLike
        superclass. A subclass therefore unwraps plain ArrayLike operands,
        whereas a plain ArrayLike declines subclass operands, leaving them to
        the subclass. The ``isinstance(x, ArrayLike)`` guard keeps unrelated
        base classes (``object``, the operator mixin) from qualifying; they
        have no ``.value`` attribute.
        """
        return isinstance(x, ArrayLike) and isinstance(self, type(x))

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``ufunc`` to the unwrapped operands and wrap the result.

        Arguments follow NumPy's ``__array_ufunc__`` protocol: ``ufunc`` is
        the ufunc object and ``method`` names the ufunc method being invoked
        (``'__call__'``, ``'reduce'``, ``'at'``, ...). ``inputs`` are its
        positional arguments and ``kwargs`` its keyword arguments.

        Returns NotImplemented if any input or ``out`` entry is neither one
        of ``_handled_types`` nor an ArrayLike this instance may unwrap.
        Otherwise returns ``type(self)`` wrapping the result (a tuple of them
        when the ufunc returns a tuple), or None for ``ufunc.at``, which
        operates in place.
        """
        out = kwargs.get('out')
        out = () if out is None else tuple(out)

        # ArrayLike implements arithmetic and ufuncs by deferring to the
        # wrapped array; refuse any operand we do not know how to unwrap.
        # NOTE(review): index arguments to ufunc.at given as Python lists are
        # not in _handled_types and are therefore refused as well.
        for x in inputs + out:
            if not (isinstance(x, self._handled_types)
                    or self._can_unwrap(x)):
                return NotImplemented

        inputs = tuple(x.value if self._can_unwrap(x) else x
                       for x in inputs)
        if out:
            kwargs['out'] = tuple(x.value if self._can_unwrap(x) else x
                                  for x in out)
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if method == 'at':
            # ufunc.at modifies its first operand in place and returns None;
            # wrapping that None would yield a meaningless 0-d object array.
            return None
        if isinstance(result, tuple):
            # ufuncs with several outputs (e.g. np.divmod) return a tuple.
            return tuple(type(self)(x) for x in result)
        return type(self)(result)

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.value)
class OptOut(object):
    """Operand type that opts out of NumPy's ufunc machinery.

    With ``__array_ufunc__`` set to None, ufunc-aware binary operators on the
    other operand return NotImplemented, so Python falls back to this class's
    own ``__add__``/``__radd__``. Each of those returns its own name as a
    marker, presumably so tests can see which method Python dispatched to.
    """

    __array_ufunc__ = None

    def __radd__(self, other):
        """Reflected ``+``: return the marker string for this method."""
        return '__radd__'

    def __add__(self, other):
        """Forward ``+``: return the marker string for this method."""
        return '__add__'
class SubArrayLike(ArrayLike):
    """ArrayLike subclass that adds no behavior of its own.

    Presumably exists to exercise how ArrayLike operations behave when one
    operand is a subclass instance.
    """
# TODO: test that all combinations of arithmetic between number, np.ndarray,
# ArrayLike, SubArrayLike and OptOut yield the correct result
@mhvk
Copy link

mhvk commented Apr 3, 2017

@shoyer - I like this very much! I also think we should get it in together with __array_ufunc__ -- maybe the easiest way is to make a PR against @charris's branch?

One thing to reconsider is the name: NDArrayOperators might be another option. At some level, I wonder whether it is possible to actually use this in ndarray itself (i.e., ndarray would become a class that inherits from both this mixin and a base class). Obviously that is for later, though!!

@mhvk
Copy link

mhvk commented Apr 3, 2017

Given the discussion with @pv about what to handle, etc., I wonder whether we should expose the functions in binop_override.h...

In this respect, in earlier discussions our example array class had a specific _can_handle method -- I think this is better, as it is easier to override (and it is somewhat analogous to a ShapedLikeNDArray mixin class I wrote for astropy, where I pass all reshape, transpose, etc., methods through a single _apply method; see https://github.com/astropy/astropy/blob/master/astropy/utils/misc.py#L856).

@shoyer
Copy link
Author

shoyer commented Apr 3, 2017

I like NDArrayOperatorsMixin for the name. Let me know if you have any suggestions about where these files should end up in numpy... I'll look into putting together a PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment