Skip to content

Instantly share code, notes, and snippets.

@MischaPanch
Last active August 23, 2021 10:46
Show Gist options
  • Save MischaPanch/335f075aa00f36fdb0e35afe4d7161de to your computer and use it in GitHub Desktop.
Save MischaPanch/335f075aa00f36fdb0e35afe4d7161de to your computer and use it in GitHub Desktop.
Base class with attribute and dict-like access, with IDE support for attributes and keys
import json
from abc import ABC, abstractmethod
from copy import copy
from enum import Enum
from typing import Generic, List, Type, TypeVar, get_args, get_origin

import numpy as np
class KeyEnum(str, Enum):
    """
    Base class for enums that enumerate the field names of an ``AttrAndDictAccess`` subclass.

    Because of the ``str`` mixin, members compare equal to their string values and can be
    used anywhere a plain field name is accepted.
    """
# Unconstrained type variable; annotates alternative constructors (e.g. ``from_array``)
# so that they return an instance of the class they are called on.
T = TypeVar("T")
# The KeyEnum subclass passed as generic parameter to ``AttrAndDictAccess``.
TKeyEnum = TypeVar("TKeyEnum", bound=KeyEnum)
class AttrAndDictAccess(ABC, Generic[TKeyEnum]):
    """
    Allows both attribute and dict-like access. Useful for dataclasses where attributes need to be mutated
    after instantiation (e.g. in loops over keys) while still having IDE support for autocompletion and type
    inspection with "dotted" access. A second class enumerating the attribute names
    has to be introduced and passed as the generic parameter (this enables dict-like access).

    Subclasses should be annotated with '@dataclass(order=True)';
    we recommend using pydantic for better support of serialization.
    The dataclass fields must coincide, in value and in order, with the values of the KeyEnum
    passed as generic parameter; this is checked in ``__post_init__``.

    A basic implementation looks as follows:

    >>> from datastruct import AttrAndDictAccess, KeyEnum
    >>> from pydantic.dataclasses import dataclass
    >>> class CustomKeys(KeyEnum):
    ...     key1 = "key1"
    ...     key2 = "key2"
    ...
    >>> @dataclass(order=True)
    ... class SemanticContainer(AttrAndDictAccess[CustomKeys]):
    ...     key1: float = 3
    ...     key2: str = "second_key"
    ...
    ...     def _check_input(self, key, value):
    ...         pass
    ...
    >>> container = SemanticContainer(key2="custom")
    >>> SemanticContainer.keys()[0] == CustomKeys.key1 == "key1"
    True
    >>> container.key1 == container["key1"] == container[CustomKeys.key1] == 3
    True
    >>> container.key2
    'custom'
    >>> list(container) == ["key1", "key2"] and "key2" in container
    True
    """

    # Number of decimal places floats are rounded to in ``__repr__``.
    _FLOAT_REPRESENTATION_ACCURACY = 6

    @classmethod
    def key_enum(cls) -> Type[TKeyEnum]:
        """
        Return the KeyEnum class passed as generic parameter, e.g. ``CustomKeys`` for
        ``class SemanticContainer(AttrAndDictAccess[CustomKeys])``.

        The whole MRO is searched, so the parameter is found when it was bound on an ancestor
        class and when the parametrized base is not the first base class (e.g. after a mixin).

        :raises TypeError: if no concrete Enum type is bound to ``AttrAndDictAccess``.
        """
        for klass in cls.__mro__:
            # Only consider the bases declared on ``klass`` itself; an inherited
            # ``__orig_bases__`` is visited when the loop reaches the declaring class.
            for base in vars(klass).get("__orig_bases__", ()):
                origin = get_origin(base)
                if not (isinstance(origin, type) and issubclass(origin, AttrAndDictAccess)):
                    continue
                for arg in get_args(base):
                    # Skip unbound TypeVars (e.g. ``AttrAndDictAccess[TKeyEnum]``).
                    if isinstance(arg, type) and issubclass(arg, Enum):
                        return arg
        raise TypeError(
            f"{cls.__name__} does not bind a KeyEnum; declare it as "
            f"class {cls.__name__}(AttrAndDictAccess[YourKeyEnum])"
        )

    @classmethod
    def keys(cls) -> List[str]:
        """Return the values of the KeyEnum members (i.e. the field names) in definition order."""
        return [member.value for member in cls.key_enum()]

    def __setattr__(self, key: str, value):
        # Every assignment routed through ``__setattr__`` (``obj.x = v``, ``setattr``,
        # ``obj["x"] = v``) is validated by the subclass before it is stored.
        self._check_input(key, value)
        super().__setattr__(key, value)

    @abstractmethod
    def _check_input(self, key, value):
        """
        Raise an exception if ``value`` is not a valid input for the attribute ``key``.

        Called by ``__setattr__`` before the attribute is set.
        """
        pass

    def __setitem__(self, key, value):
        """
        Set the field ``key`` (a str or KeyEnum member) to ``value``.

        :raises KeyError: if ``key`` is not one of ``keys()``.
        """
        valid_keys = self.keys()
        if key not in valid_keys:
            raise KeyError(f"Invalid key {key}. Valid keys are: {valid_keys}")
        setattr(self, key, value)

    def __getitem__(self, item):
        """
        Return the value of the field ``item`` (a str or KeyEnum member).

        :raises KeyError: if ``item`` is not one of ``keys()``.
        """
        valid_keys = self.keys()
        if item not in valid_keys:
            raise KeyError(f"Invalid key {item}. Valid keys are: {valid_keys}")
        return getattr(self, item)

    def __contains__(self, key) -> bool:
        """Return whether ``key`` is a valid field name (dict-like membership test)."""
        return key in self.keys()

    def __iter__(self):
        """
        Iterate over the field names, like iterating over a dict.

        Without this, ``iter()`` and ``in`` would fall back to the legacy sequence protocol.
        That protocol calls ``self[0]``, ``self[1]``, ... and so would raise KeyError here.
        """
        return iter(self.keys())

    def __len__(self):
        """Return the number of fields."""
        return len(self.keys())

    def __post_init__(self):
        """
        This is a sanity check of the sort-of hacky implementation of the stretch between autocompletion
        of parameter names, named access to attributes and support for vectorized operations.

        :raises AttributeError: if the instance attributes do not coincide, in value and order,
            with the values of the KeyEnum used when defining the class.
        """
        # ``__initialised__`` is excluded for the same reason it is dropped in ``to_dict``.
        field_names = [name for name in vars(self) if name != "__initialised__"]
        if self.keys() != field_names:
            raise AttributeError(
                f"Wrong Implementation of {self.__class__.__name__}: "
                f"fields need to coincide in value and order with the KeyEnum used "
                f"when defining the class"
            )

    def __repr__(self):
        """
        Return ``"<key>_<value>"`` pairs joined with underscores, floats rounded to
        ``_FLOAT_REPRESENTATION_ACCURACY`` decimal places.

        NOTE: ``@dataclass`` with the default ``repr=True`` generates a ``__repr__`` on the
        subclass, which takes precedence over this one; pass ``repr=False`` to use it.
        """

        def maybe_round(val):
            if isinstance(val, float):
                val = round(val, self._FLOAT_REPRESENTATION_ACCURACY)
            return val

        return "_".join(f"{k}_{maybe_round(v)}" for k, v in self.to_dict().items())

    def to_array(self) -> np.ndarray:
        """Return the field values as a numpy array, ordered like ``keys()``."""
        return np.array([self.__dict__[key] for key in self.keys()])

    @classmethod
    def from_array(cls: Type[T], arr: np.ndarray) -> T:
        """Inverse of ``to_array``: pass the array entries positionally, in field order, to the constructor."""
        return cls(*arr)

    def print(self):
        """Print ``to_dict()`` as indented JSON; all values must be JSON-serializable."""
        print(json.dumps(self.to_dict(), indent=4))

    def to_dict(self) -> dict:
        """Return a shallow copy of the instance attributes, keyed by field name."""
        key_dict = copy(self.__dict__)
        key_dict.pop("__initialised__", None)  # needed b/c of dataclass specifics
        return key_dict
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment