Created
September 5, 2019 18:52
-
-
Save malcolmgreaves/37e3172b7942d7e2010059410b2586a8 to your computer and use it in GitHub Desktop.
Class for registering transformation functions alongside the values of an enum.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum | |
from functools import lru_cache | |
from typing import Dict, Sequence, Any, Callable, Iterable, Tuple, Type, Generic, TypeVar, Union | |
# Type variable restricted to Enum subclasses; it parameterizes EnumCallable below.
E = TypeVar("E", bound=Enum)
# Shape of every registered transformation function: one argument in, one result out.
Transformer = Callable[[Any], Any]
def name_of(enum_or_name: Union[Enum, str]) -> str:
    """Return the member name for an enum value, or the string itself.

    Args:
        enum_or_name: An enum member, or a string that already holds a
            member name.

    Returns:
        ``enum_or_name.name`` for an enum member; the string unchanged
        otherwise.

    Raises:
        AttributeError: If the argument is neither an enum member nor a
            string and has no ``name`` attribute.
    """
    # Check for Enum *before* str: members of str-mixin enums
    # (``class X(str, Enum)``, ``StrEnum``) are also ``str`` instances, and
    # returning such a member as-is would behave like its value, not its name.
    if isinstance(enum_or_name, Enum):
        return enum_or_name.name
    if isinstance(enum_or_name, str):
        return enum_or_name
    # Anything else falls back to attribute access, as before.
    return enum_or_name.name
class EnumCallable(Generic[E]):
    """Dispatch to a transformation function registered per enum value.

    Functions are stored under the enum value's *name*, so every lookup
    accepts either the enum value itself or its name as a string.
    """

    def __init__(
        self,
        enum_type: Type[E],
        enum_func_registry: Union[
            Iterable[Tuple[Union[E, str], Transformer]],
            Dict[Union[E, str], Transformer],
        ],
        validate: bool = True,
    ) -> None:
        """Build the name-to-function registry.

        Args:
            enum_type: The enum class whose values are being dispatched on.
            enum_func_registry: Either a dict or an iterable of
                ``(enum value or name, function)`` pairs. If the same name
                appears more than once, the last function wins.
            validate: When true, check that ``enum_type`` is an enum class
                and that every registered name belongs to it.

        Raises:
            TypeError: If ``validate`` is true and ``enum_type`` is not an
                ``Enum`` subclass.
            ValueError: If ``validate`` is true and a function is registered
                under a name that is not a value name of ``enum_type``.
        """
        self._enum_type = enum_type
        # Per-instance cache for the `values` property, filled on first access.
        self._values = None
        # Guard with isinstance first: issubclass() itself raises on non-classes,
        # which would hide the intended message below.
        is_enum_class = isinstance(enum_type, type) and issubclass(enum_type, Enum)
        if validate and not is_enum_class:
            raise TypeError(f"Expecting an enum type, not '{self._enum_type}'")

        if isinstance(enum_func_registry, dict):
            pairs = enum_func_registry.items()
        else:
            pairs = enum_func_registry
        # Keys are normalized to names so members and strings are interchangeable.
        self._enum2func = {name_of(key): func for key, func in pairs}

        if validate:
            known_names = {member.name for member in self.values}
            # Only registered names are checked; enum values without a
            # function stay allowed and fail later with KeyError on lookup.
            for enum_name in self._enum2func:
                if enum_name not in known_names:
                    raise ValueError(
                        f"Function registered for '{enum_name}', which is not "
                        f"a value name of enum '{self._enum_type.__name__}'"
                    )

    @property
    def values(self) -> tuple:
        """Obtain all distinct enum values, in definition order.

        The tuple is computed once per instance and then reused.
        """
        if self._values is None:
            # __members__ also lists aliases (extra names bound to an existing
            # member); keep only entries stored under the member's own name.
            self._values = tuple(
                member
                for name, member in self._enum_type.__members__.items()
                if name == member.name
            )
        return self._values

    def get_transformer(self, e: Union[E, str]) -> Callable[[Any], Any]:
        """Obtain the registered transformation function for an enum value.

        Args:
            e: The enum value, or its name.

        Raises:
            KeyError: If no function is registered under that name.
        """
        return self._enum2func[name_of(e)]

    def transform(self, enum_val_or_name: Union[E, str], value: Any) -> Any:
        """Apply the function registered for the enum value to ``value``.

        Args:
            enum_val_or_name: The enum value, or its name.
            value: The single argument passed to the registered function.

        Returns:
            Whatever the registered function returns.

        Raises:
            KeyError: If no function is registered under that name.
        """
        return self.get_transformer(enum_val_or_name)(value)

    def __call__(self, *args, **kwargs):
        """Alias for :func:`transform`.
        """
        return self.transform(*args, **kwargs)
def test_enum_callable():
    """Check that lookups by enum value and by name reach the same function."""
    ET = Enum("ET", ["a", "b", "c"])

    def add_one(x):
        return x + 1

    def is_even(x):
        return x % 2 == 0

    def times_ten(x):
        return x * 10

    # Mixed keys on purpose: two enum values and one plain name.
    registry = {ET.a: add_one, "b": is_even, ET.c: times_ten}
    ec = EnumCallable(enum_type=ET, enum_func_registry=registry, validate=True)

    # Dispatch using the enum values themselves.
    assert ec.transform(ET.a, 0) == 1
    assert ec.transform(ET.b, 6)
    assert ec.transform(ET.c, -10) == -100

    # Dispatch using the value names.
    assert ec.transform("a", -1) == 0
    assert ec.transform("b", 2)
    assert ec.transform("c", 99) == 990
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment