Last active
September 9, 2021 12:40
-
-
Save earonesty/81e6c29fa4c54e9b67d9979ddbd8489d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""TypedEnum: a type-preserving enumeration metaclass."""
class _FakeEnum:  # pylint: disable=too-few-public-methods
    """Object that looks a bit like an Enum member: it only has name and value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class TypedEnum(type):
    """Metaclass creating an enumeration that preserves isinstance(element, type).

    Every public attribute in the class body becomes a member.  All members
    must share one exact type; that type is added as a base class, so each
    member is an instance of both the enum class and its underlying type.

    Members are kept in definition order, inherited members first.  A member
    defined in a subclass shadows an inherited member with the same name.
    """

    def __new__(mcs, cls, bases, classdict):
        """Discover the enum members by removing all intrinsics and specials.

        Raises:
            SyntaxError: if the members defined in the class body do not all
                have the same type, or if an inherited member is not an
                instance of that type.
        """
        # Attributes every plain class has anyway; these are never members.
        object_attrs = set(dir(type(cls, (object,), {})))
        # classdict preserves definition order, so a list keeps iteration
        # over the finished enum deterministic.
        own_names = [
            name
            for name in classdict
            if name not in object_attrs
            and not (name.startswith("_") and name.endswith("_"))
        ]

        # The base class for all members is the exact type of the first value.
        base = None
        if own_names:
            base = type(classdict[own_names[0]])
            for name in own_names[1:]:
                if type(classdict[name]) is not base:
                    raise SyntaxError("Cannot mix types in TypedEnum")

        # Collect inherited members before the class is created so that a
        # type clash is reported as the intended SyntaxError rather than as a
        # base-class layout error coming out of type.__new__.
        inherited = []
        claimed = set(own_names)
        for parent in bases:
            for name in getattr(parent, "__member_names__", ()):
                if name in claimed:
                    # Redefined in this class, or already taken from an
                    # earlier parent.
                    continue
                value = getattr(parent, name)
                if base is not None and not isinstance(value, base):
                    raise SyntaxError(
                        "Cannot mix inherited types in TypedEnum: %s from %s" % (name, parent)
                    )
                claimed.add(name)
                inherited.append((name, value))

        # With no own members there is no value type to add.  When the value
        # type is already listed explicitly, adding it again would be a
        # duplicate base class.
        if base is None or any(existing is base for existing in bases):
            ext_bases = tuple(bases)
        else:
            ext_bases = (*bases, base)

        new_class = super().__new__(mcs, cls, ext_bases, classdict)
        new_class.__member_names__ = [name for name, _ in inherited] + own_names

        # Convert all inherited values to the new class.
        for name, value in inherited:
            setattr(new_class, name, type.__call__(new_class, value))
        # Replace the raw values from the class body with typed members.
        for name in own_names:
            setattr(new_class, name, new_class.__new__(new_class, classdict[name]))
        return new_class

    def __call__(cls, arg):
        """Look up the member equal to *arg*.

        Returns the existing member object, so ``Enum(value) is Enum.name``.

        Raises:
            ValueError: if no member compares equal to *arg*.
        """
        for name in cls.__member_names__:
            member = getattr(cls, name)
            if arg == member:
                return member
        raise ValueError("Invalid value '%s' for %s" % (arg, cls.__name__))

    @property
    def __members__(cls):
        """Sufficient to make the @unique decorator work.

        Maps each member name to an object with ``name`` and ``value``.  For a
        member whose value duplicates an earlier one, ``name`` is the name of
        the first member with that value, which is what ``enum.unique``
        compares against the key to detect aliases.
        """
        members = {}
        first_seen = []  # (value, name) pairs; linear scan avoids hashing values
        for name in cls.__member_names__:
            value = getattr(cls, name)
            canonical = next((known for seen, known in first_seen if seen == value), None)
            if canonical is None:
                canonical = name
                first_seen.append((value, name))
            members[name] = _FakeEnum(canonical, value)
        return members

    def __iter__(cls):
        """List all enum values, in definition order."""
        return (getattr(cls, name) for name in cls.__member_names__)

    def __len__(cls):
        """Get number of enum values."""
        return len(cls.__member_names__)
def test_meta():
    """Basic inline tests for the TypedEnum metaclass.

    Covers type preservation for int and str members, iteration and
    membership, rejection of mixed member types, and lookup by value.
    """
    # pylint: disable=too-few-public-methods, unused-variable

    # ints and strs work
    class IntEnum(metaclass=TypedEnum):
        x = 3
        y = 4
        z = 5

    assert isinstance(IntEnum.x, int)
    assert isinstance(IntEnum.x, IntEnum)
    assert len(IntEnum) == 3

    class StrEnum(metaclass=TypedEnum):
        x = "1"
        y = "2"
        z = "3"

    assert isinstance(StrEnum.x, str)
    assert isinstance(StrEnum.y, StrEnum)

    # iteration and membership work
    for ent in StrEnum:
        assert isinstance(ent, StrEnum)
    assert max(StrEnum) is StrEnum.z
    assert "3" in StrEnum
    assert StrEnum.z in StrEnum

    # Mismatched types is a syntax error.  The metaclass raises it while the
    # class statement executes, so a plain class body inside try is enough and
    # no exec of a source string is needed.
    try:
        class MixedEnum(metaclass=TypedEnum):
            y = "1"
            x = 2
            z = "3"
    except SyntaxError:
        pass
    else:
        # An explicit raise is used because assert is stripped under -O.
        raise AssertionError("mixed member types must raise SyntaxError")

    # lookup by value
    assert StrEnum("3") == "3"
    assert StrEnum("3") in StrEnum
    assert IntEnum(4) == 4

    try:
        StrEnum("4")
    except ValueError:
        pass
    else:
        raise AssertionError("unknown value must raise ValueError")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment