# VecPile
#
# From a GitHub Gist by @kcarnold, created February 29, 2020 20:25.
# Gist id: 1908bb78ae6660fd983a7907685b73d8
# (https://gist.github.com/kcarnold/1908bb78ae6660fd983a7907685b73d8)
class VecPile:
    """An attribute-accessed collection that asserts that all of its elements have the same length.

    Useful for keeping several collections together, such as vectors with labels,
    or several different representations of the same data.

    Every assignment ``pile.name = value`` is checked against the values already
    stored; a length mismatch raises ``ValueError`` and nothing is stored.
    """

    def __init__(self, **kw):
        # Route through __setattr__ so constructor arguments get the same length check.
        for k, v in kw.items():
            setattr(self, k, v)

    @staticmethod
    def get_len(x):
        """Return the size of *x* along its first axis.

        Uses ``x.shape[0]`` when *x* has a ``shape`` attribute (array-like
        objects), otherwise falls back to ``len(x)``.
        """
        try:
            return x.shape[0]
        except AttributeError:
            return len(x)

    def __setattr__(self, key, value):
        """Store *value* under *key* after checking its length against the other stored values.

        Raises:
            ValueError: if the length of *value* differs from that of any other stored value.
        """
        new_len = self.get_len(value)
        for other_key, existing in self.__dict__.items():
            # The value being replaced is not compared: overwriting it cannot by
            # itself break the "all lengths equal" invariant.
            if other_key == key:
                continue
            existing_len = self.get_len(existing)
            if existing_len != new_len:
                raise ValueError(
                    f"Dimension mismatch: vecpile has dimension {existing_len} but trying to add a {new_len}"
                )
        self.__dict__[key] = value

    def __len__(self):
        """Return the length shared by the stored values, or 0 if nothing is stored."""
        for existing in self.__dict__.values():
            return self.get_len(existing)
        # Without this, an empty pile returned None and len() raised TypeError.
        return 0

    def __repr__(self):
        return f"<{type(self).__name__} len={len(self)} fields={list(self.__dict__)}>"
def test_vecpile():
    """Exercise VecPile: identity storage, length-mismatch rejection, kwargs construction, and len()."""
    # numpy is not imported elsewhere in this file; import it here so the test is runnable.
    import numpy as np

    vp = VecPile()
    x = np.zeros(10)
    vp.x = x
    assert vp.x is x

    # try/except/else instead of `assert False`, so the check still fails under `python -O`.
    try:
        vp.y = np.zeros(2)
    except ValueError:
        pass
    else:
        raise AssertionError("Should have failed.")
    # A rejected assignment must not leave the attribute behind.
    assert not hasattr(vp, "y")

    # Mismatched lengths passed to the constructor are rejected the same way.
    try:
        VecPile(x=x, y=np.zeros(2))
    except ValueError:
        pass
    else:
        raise AssertionError("Should have failed.")

    vp = VecPile(x=x)
    assert vp.x is x
    assert len(vp) == len(x)
# (Gist page footer: sign in on GitHub to comment on this gist.)