Created
February 29, 2020 20:25
-
-
Save kcarnold/1908bb78ae6660fd983a7907685b73d8 to your computer and use it in GitHub Desktop.
VecPile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VecPile:
    """An attribute-accessed collection that asserts that all of its elements have the same length.

    Useful for keeping several collections together, such as vectors with
    labels, or several different representations of the same data.

    Every attribute assigned on a VecPile, whether through the constructor's
    keyword arguments or plain attribute assignment, must have the same length
    as measured by ``get_len``. Assigning a value whose length disagrees with
    the other attributes raises ``ValueError``, and the value is not stored.
    """

    def __init__(self, **kw):
        """Create a pile from keyword arguments, validating each length as it is added."""
        for k, v in kw.items():
            setattr(self, k, v)

    @staticmethod
    def get_len(x):
        """Return the length of *x* along its first axis.

        Objects with a ``shape`` attribute (e.g. NumPy arrays) report
        ``x.shape[0]``. Anything else falls back to ``len(x)``.
        """
        try:
            return x.shape[0]
        except AttributeError:
            return len(x)

    def __setattr__(self, key, value):
        """Store *value* under *key* after checking that its length matches the pile.

        The value currently stored under *key* (if any) is excluded from the
        check, because it is about to be replaced.

        Raises:
            ValueError: if ``get_len(value)`` differs from the length of any
                other attribute already in the pile.
        """
        new_len = self.get_len(value)
        for name, existing in self.__dict__.items():
            if name == key:
                # Being overwritten, so its old length must not block the new value.
                continue
            existing_len = self.get_len(existing)
            if existing_len != new_len:
                raise ValueError(
                    f"Dimension mismatch: vecpile has dimension {existing_len} but trying to add a {new_len}"
                )
        # Write through __dict__ directly; calling setattr here would recurse.
        self.__dict__[key] = value

    def __len__(self):
        """Return the common length of the pile's attributes, or 0 if the pile is empty."""
        for existing in self.__dict__.values():
            # All attributes share one length, so the first one is representative.
            return self.get_len(existing)
        return 0
def test_vecpile():
    """Smoke-test VecPile: attribute storage, length checking, and ``len``."""
    # Local import: numpy is not imported at module level in this file.
    import numpy as np

    vp = VecPile()
    x = np.zeros(10)
    vp.x = x
    assert vp.x is x

    # Adding an attribute whose length differs from the existing ones must be rejected.
    # Use try/except/else rather than `assert False`, which is stripped under `python -O`.
    try:
        vp.y = np.zeros(2)
    except ValueError:
        pass
    else:
        raise AssertionError("Should have failed.")

    vp = VecPile(x=x)
    assert vp.x is x
    assert len(vp) == len(x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment