Skip to content

Instantly share code, notes, and snippets.

@malcolmgreaves
Created September 5, 2019 18:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malcolmgreaves/37e3172b7942d7e2010059410b2586a8 to your computer and use it in GitHub Desktop.
Save malcolmgreaves/37e3172b7942d7e2010059410b2586a8 to your computer and use it in GitHub Desktop.
Class for registering transformation functions against the members of an enum, and applying them by enum member or by member name.
from collections.abc import Mapping
from enum import Enum
from functools import lru_cache
from typing import Dict, Sequence, Any, Callable, Iterable, Tuple, Type, Generic, TypeVar, Union
E = TypeVar("E", bound=Enum)
# Signature of a registered transformation: one argument in, one value out.
Transformer = Callable[[Any], Any]


def name_of(enum_or_name: Union[E, str]) -> str:
    """Return the member name of an enum member, or a string unchanged.

    Enum members are checked before strings so that members of ``str``-mixin
    enums (e.g. ``class C(str, Enum)``), which are also ``str`` instances,
    resolve to their *name* rather than being returned as-is (i.e. as their
    value).

    :param enum_or_name: An enum member, or the name of a member.
    :return: The member's ``name``, or the given string itself.
    """
    if isinstance(enum_or_name, Enum):
        return enum_or_name.name
    if isinstance(enum_or_name, str):
        return enum_or_name
    # Any other object: fall back to its ``name`` attribute, as the original
    # implementation did.
    return enum_or_name.name
class EnumCallable(Generic[E]):
    """Registry mapping members of an enum to transformation functions.

    Functions may be registered under either an enum member or a member's
    name (a ``str``). Every registration is keyed internally by member name,
    so lookups accept either form as well.

    Example::

        ET = Enum("ET", ["a", "b"])
        ec = EnumCallable(ET, {ET.a: lambda x: x + 1, "b": str})
        ec(ET.a, 1)   # -> 2
        ec("b", 1)    # -> "1"
    """

    def __init__(
        self,
        enum_type: Type[E],
        enum_func_registry: Union[
            Iterable[Tuple[Union[E, str], Transformer]],
            Dict[Union[E, str], Transformer],
        ],
        validate: bool = True,
    ) -> None:
        """Build the registry.

        :param enum_type: The enum class whose members functions are
            registered for.
        :param enum_func_registry: Either a mapping (any ``Mapping``, not only
            ``dict``) or an iterable of ``(member_or_name, function)`` pairs.
            A later entry for the same member name overwrites an earlier one.
        :param validate: When true, check that ``enum_type`` is an ``Enum``
            subclass and that every registered key names one of its members.
        :raises TypeError: If ``validate`` is true and ``enum_type`` is not an
            ``Enum`` subclass.
        :raises ValueError: If ``validate`` is true and a registered key does
            not name a member of ``enum_type``.
        """
        self._enum_type = enum_type
        if validate and not issubclass(self._enum_type, Enum):
            raise TypeError(f"Expecting an enum type, not '{self._enum_type}'")
        # Any Mapping is unpacked via .items(); anything else is treated as an
        # iterable of (key, function) pairs.
        pairs = (
            enum_func_registry.items()
            if isinstance(enum_func_registry, Mapping)
            else enum_func_registry
        )
        self._enum2func: Dict[str, Transformer] = {
            name_of(enum_val): func for enum_val, func in pairs
        }
        if validate:
            # Only canonical member names are accepted: lookups resolve members
            # through ``member.name``, so a function registered under an alias
            # name could never be reached.
            member_names = {member.name for member in self.values}
            for enum_name in self._enum2func:
                if enum_name not in member_names:
                    raise ValueError(
                        f"Function registered for unknown enum value named: "
                        f"'{enum_name}' (not a member of {self._enum_type!r})"
                    )

    @property
    def values(self) -> tuple:
        """Obtain all distinct enum members, in definition order.

        Iterating an ``Enum`` class yields each member once and skips aliases.
        This is computed on access rather than with ``lru_cache`` on the
        method, which would key the cache on ``self`` and keep instances
        alive.
        """
        return tuple(self._enum_type)

    def get_transformer(self, e: Union[E, str]) -> Transformer:
        """Obtain the function registered for an enum member or member name.

        :param e: An enum member, or the name of one.
        :raises KeyError: If no function is registered for that name.
        """
        return self._enum2func[name_of(e)]

    def transform(self, enum_val_or_name: Union[E, str], value: Any) -> Any:
        """Apply the function registered for the enum member (or name) to ``value``.

        :param enum_val_or_name: An enum member, or the name of one.
        :param value: The argument passed to the registered function.
        :return: Whatever the registered function returns.
        :raises KeyError: If no function is registered for that name.
        """
        return self.get_transformer(enum_val_or_name)(value)

    def __call__(self, *args, **kwargs):
        """Alias for :func:`transform`."""
        return self.transform(*args, **kwargs)
def test_enum_callable():
    """Exercise EnumCallable registration, lookup, aliasing and validation."""
    ET = Enum("ET", ["a", "b", "c"])
    ec = EnumCallable(
        enum_type=ET,
        enum_func_registry={
            ET["a"]: lambda x: x + 1,
            "b": lambda x: x % 2 == 0,
            ET["c"]: lambda x: x * 10,
        },
        validate=True,
    )
    # Lookup by enum member; pin exact results, not just truthiness.
    assert ec.transform(ET["a"], 0) == 1
    assert ec.transform(ET["b"], 6) is True
    assert ec.transform(ET["c"], -10) == -100
    # Lookup by member name.
    assert ec.transform("a", -1) == 0
    assert ec.transform("b", 2) is True
    assert ec.transform("c", 99) == 990
    # __call__ delegates to transform.
    assert ec(ET["a"], 41) == 42
    assert ec("b", 3) is False
    # get_transformer hands back the registered function itself.
    assert ec.get_transformer(ET["c"])(3) == 30
    # values lists every member once, in definition order.
    assert ec.values == tuple(ET)

    # Registrations may also be supplied as (key, function) pairs.
    ec_pairs = EnumCallable(ET, [(ET["a"], str), ("b", str), ("c", str)])
    assert ec_pairs("c", 7) == "7"

    # A key that names no member is rejected.
    try:
        EnumCallable(ET, {"z": str})
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError for an unknown enum name")

    # A non-enum type is rejected.
    try:
        EnumCallable(int, {})
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError for a non-enum type")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment