Created
August 20, 2020 17:25
-
-
Save jcrist/33aeab0564871868ca2e1bbec6289f35 to your computer and use it in GitHub Desktop.
Python struct class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
import inspect | |
def _extract_attributes(bases, attrs):
    """Compute the field layout for a struct class that is being created.

    Fields and defaults inherited from ``bases`` are merged with the fields
    annotated in ``attrs``. Required fields always come before fields with
    defaults. Redeclaring a field with a default moves it to the defaulted
    group; redeclaring it without one moves it back to the required group.

    Parameters
    ----------
    bases : tuple of type
        Direct base classes of the new class. Each must be a ``Struct``
        subclass.
    attrs : dict
        Namespace of the new class. Default values of annotated fields are
        popped from it, so they are not left behind as class attributes.

    Returns
    -------
    fields : tuple of str
        All field names, required fields first.
    defaults : tuple
        Default values, aligned with the trailing entries of ``fields``.
    new_slots : tuple of str
        Sorted names of the fields that no base already provides.

    Raises
    ------
    ValueError
        If any base is not a ``Struct`` subclass.
    """
    arg_fields = {}
    kwarg_fields = {}
    existing_slots = set()

    # Walk the bases in reverse so that the first listed base is merged last
    # and its declarations win over those of later bases.
    for base in reversed(bases):
        if not issubclass(base, Struct):
            raise ValueError("Structs cannot subclass from non-structs")
        fields = base.__struct_fields__
        defaults = base.__struct_defaults__
        existing_slots.update(base.__slots__)
        # ``__slots__`` only lists what ``base`` itself added; fields it
        # inherited are slots on its ancestors and must not be redeclared.
        existing_slots.update(fields)
        # Use an explicit split index: ``fields[:-len(defaults)]`` would be
        # ``fields[:0]`` (empty) for a base without defaults and would drop
        # every one of its required fields.
        npos = len(fields) - len(defaults)
        for f in fields[:npos]:
            arg_fields[f] = None
            kwarg_fields.pop(f, None)
        for f, d in zip(fields[npos:], defaults):
            kwarg_fields[f] = d
            arg_fields.pop(f, None)

    # Add the fields declared on the new class itself. An annotated name that
    # also has a value in the namespace is a field with a default.
    for f in attrs.get("__annotations__", {}):
        if f in attrs:
            arg_fields.pop(f, None)
            kwarg_fields[f] = attrs.pop(f)
        else:
            kwarg_fields.pop(f, None)
            arg_fields[f] = None

    fields = (*arg_fields, *kwarg_fields)
    defaults = tuple(kwarg_fields.values())
    new_slots = tuple(sorted(set(fields).difference(existing_slots)))
    return fields, defaults, new_slots
class _StructMeta(type):
    """Metaclass that attaches struct metadata and slots to new classes.

    Every class created through it receives ``__struct_fields__``,
    ``__struct_defaults__`` and ``__slots__`` attributes. Classes also expose
    a ``__signature__`` derived from that metadata.
    """

    def __new__(metacls, name, bases, attrs):
        """Create the class after computing its fields, defaults and slots.

        Raises
        ------
        ValueError
            If a class with bases defines ``__init__``, ``__new__`` or
            ``__slots__`` in its body.
        """
        if not bases:
            # A class without bases has no fields; it only gets the
            # ``__weakref__`` slot.
            fields = defaults = ()
            slots = ("__weakref__",)
        else:
            reserved = ("__init__", "__new__", "__slots__")
            clash = next((key for key in reserved if key in attrs), None)
            if clash is not None:
                raise ValueError(f"Cannot override {clash}")
            fields, defaults, slots = _extract_attributes(bases, attrs)

        attrs.update(
            __struct_fields__=fields,
            __struct_defaults__=defaults,
            __slots__=slots,
        )
        return type.__new__(metacls, name, bases, attrs)

    @property
    def __signature__(self):
        """Return an ``inspect.Signature`` built from the struct metadata.

        Every field is a positional-or-keyword parameter. Fields without a
        default have an empty default; annotations come from the resolved
        type hints of the class when available.
        """
        hints = typing.get_type_hints(self)
        names = self.__struct_fields__
        values = self.__struct_defaults__
        required = len(names) - len(values)
        empty = inspect.Parameter.empty
        params = [
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=hints.get(field, empty),
                default=empty if index < required else values[index - required],
            )
            for index, field in enumerate(names)
        ]
        return inspect.Signature(parameters=params)
class Struct(metaclass=_StructMeta):
    """Base class for defining struct types.

    Given a type-annotated class, generates appropriate `__init__`, `__eq__`
    and `__repr__` methods. Default values can also be provided.

    The constructor accepts each field by position or by keyword and raises
    ``ValueError`` for a missing required field, a field given both by name
    and by position, or surplus positional or keyword arguments.

    Examples:
    ---------
    >>> class Dog(Struct):
    ...     name: str
    ...     age: int
    ...     is_good: bool = True
    >>> Dog("snickers", age=13)
    Dog(name='snickers', age=13, is_good=True)

    Subclassing works as well. New parameters are appended to the end of the
    signature, and any default values updated.

    >>> class Labrador(Dog):
    ...     color: str = "black"
    >>> Labrador("chip", 12, color="chocolate")
    Labrador(name='chip', age=12, is_good=True, color='chocolate')
    """

    def __init__(self, *args, **kwargs):
        fields = self.__struct_fields__
        defaults = self.__struct_defaults__
        nargs = len(args)
        # Reject surplus positional arguments up front; the loop below only
        # visits declared fields and would otherwise drop them silently.
        if nargs > len(fields):
            raise ValueError("Extra positional arguments provided")
        nkwargs = len(kwargs)
        npos = len(fields) - len(defaults)
        for n, field in enumerate(fields):
            if field in kwargs:
                if n < nargs:
                    raise ValueError(f"Argument {field} given by name and position")
                nkwargs -= 1
                val = kwargs[field]
            elif n < nargs:
                val = args[n]
            elif n < npos:
                raise ValueError(f"Missing required argument {field}")
            else:
                val = defaults[n - npos]
            setattr(self, field, val)
        # Any keyword not consumed above does not name a field.
        if nkwargs:
            raise ValueError("Extra keyword arguments provided")

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` with fields in declared order."""
        body = ", ".join(
            f"{f}={getattr(self, f)!r}" for f in self.__struct_fields__
        )
        return f"{type(self).__name__}({body})"

    def __eq__(self, other):
        """Compare field by field; only instances of the exact same type match."""
        if type(self) is not type(other):
            # Defer to the other operand's comparison instead of forcing
            # False; ``==`` still evaluates to False if nothing handles it.
            return NotImplemented
        return not any(
            getattr(self, f) != getattr(other, f) for f in self.__struct_fields__
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment