Skip to content

Instantly share code, notes, and snippets.

@jamescasbon
Created March 31, 2019 11:20
Show Gist options
  • Save jamescasbon/b0e1f2113a28e523ff3326d7b93eda19 to your computer and use it in GitHub Desktop.
Evil monkeypatch: make attrs-generated equality handle numpy array attributes (compared with np.array_equal)
import attr
import numpy as np
import attr._make
# Keep a handle on attrs' own comparison-method factory so the patched
# factory can delegate to it for the non-array attributes.
# NOTE(review): ``_make_cmp`` is a private attrs internal. Later attrs
# releases reworked comparison generation (separate eq/order handling) and
# may not provide it; fail with an actionable message instead of a bare
# AttributeError in that case.
try:
    original_make_cmp = attr._make._make_cmp
except AttributeError as err:
    raise AttributeError(
        "attr._make._make_cmp is not available in this attrs version; this "
        "monkeypatch only works with older attrs releases. On attrs >= 21.1 "
        "use attr.ib(eq=attr.cmp_using(eq=np.array_equal)) instead."
    ) from err
def _is_np_attr(x):
    """Return True if the attrs attribute *x* is declared as a NumPy array.

    Matches a bare ``np.ndarray`` annotation and ndarray subclasses. It also
    matches parametrized generics whose ``__origin__`` is ndarray, e.g.
    ``np.ndarray[Any, np.dtype[np.float64]]`` or ``numpy.typing.NDArray[...]``.
    Unannotated attributes (``type`` is ``None``), string annotations and
    typing special forms such as ``Optional[np.ndarray]`` are not matched.

    Args:
        x: an attrs attribute; only its ``.type`` is read.

    Returns:
        bool: whether the attribute should be compared with ``np.array_equal``.
    """
    declared = x.type
    # Parametrized generics are not classes themselves; unwrap to the class
    # they parametrize so issubclass() can be applied.
    origin = getattr(declared, "__origin__", None)
    if origin is not None:
        declared = origin
    # issubclass() only accepts classes, so reject everything else first.
    return isinstance(declared, type) and issubclass(declared, np.ndarray)
def numpy_make_cmp(attrs):
    """Build comparison methods for an attrs class, comparing ndarray fields by value.

    Replacement for attrs' comparison-method factory, which is saved as
    ``original_make_cmp``. If no attribute is detected as a NumPy array, the
    original factory is used unchanged. Otherwise:

    * the non-array attributes are compared with the ``__eq__`` that the
      original factory builds for them;
    * each array attribute is compared with ``np.array_equal``.

    Args:
        attrs: the class's attrs attributes (each exposes ``.name`` and
            ``.type``).

    Returns:
        A 6-tuple ``(__eq__, __ne__, __lt__, __le__, __gt__, __ge__)``. When
        array attributes are present, the four ordering slots are ``None``:
        no ordering methods are generated for classes holding arrays.
    """
    np_attrs = [a for a in attrs if _is_np_attr(a)]
    if not np_attrs:
        return original_make_cmp(attrs)

    other_attrs = [a for a in attrs if not _is_np_attr(a)]
    # Only the generated __eq__ is reused; its ordering methods are discarded.
    eq = original_make_cmp(other_attrs)[0]
    # Capture the names eagerly. A lambda closing over the loop variable
    # would late-bind and compare only the LAST array attribute, N times.
    np_names = [a.name for a in np_attrs]

    def __eq__(self, other):
        result = eq(self, other)
        # Propagate NotImplemented (e.g. `other` is an unrelated type) rather
        # than truth-testing it and then reading attributes `other` may lack.
        if result is NotImplemented or not result:
            return result
        return all(
            np.array_equal(getattr(self, name), getattr(other, name))
            for name in np_names
        )

    def __ne__(self, other):
        result = __eq__(self, other)
        if result is NotImplemented:
            return result
        return not result

    return __eq__, __ne__, None, None, None, None
# Install the patch by swapping attrs' private comparison factory for the
# numpy-aware one. This rebinds a module attribute, so it is process-wide.
# NOTE(review): presumably only classes decorated with @attr.s after this
# point are affected, since the factory appears to run at decoration time.
# Confirm against the installed attrs version.
setattr(attr._make, "_make_cmp", numpy_make_cmp)
# Example attrs class mixing an ndarray field with a plain scalar field.
# NOTE(review): it is declared after the patch is installed, so presumably its
# generated __eq__/__ne__ come from numpy_make_cmp. Confirm.
@attr.s(auto_attribs=True)
class C:
    x: np.ndarray  # detected by _is_np_attr -> compared with np.array_equal
    y: int  # compared by the original attrs-generated equality
# Smoke test for the patched comparisons, with fields passed positionally
# in declaration order (x, y). Plain asserts are skipped under `python -O`.
c1 = C(np.array([1, 2]), 1)

# Identical array contents and identical scalar -> equal.
c2 = C(np.array([1, 2]), 1)
assert c1 == c2

# Only the scalar field differs -> not equal.
c2 = C(np.array([1, 2]), 2)
assert c1 != c2

# Only the array contents differ -> not equal.
c2 = C(np.array([1, 3]), 1)
assert c1 != c2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment