Skip to content

Instantly share code, notes, and snippets.

@ssweber
Last active September 23, 2023 01:23
Show Gist options
  • Save ssweber/14f7127d62253a3061e4217c80c53622 to your computer and use it in GitHub Desktop.
Save ssweber/14f7127d62253a3061e4217c80c53622 to your computer and use it in GitHub Desktop.
import inspect
from typing import Any, Callable

import msgspec
def _record_asdict(self):
    """Return the record's fields as a ``{name: value}`` dict, in field order."""
    return {name: getattr(self, name) for name in self.__struct_fields__}


def _record_replace(self, **new_data):
    """Return a new instance with the fields named in ``new_data`` replaced.

    Fields not mentioned keep their current values; ``self`` is not modified.

    Raises:
        TypeError: if ``new_data`` names a field the record does not have.
    """
    current_data = self.asdict()
    # dict key views are set-like, so ``-`` yields the unexpected names directly.
    unexpected = new_data.keys() - current_data.keys()
    if unexpected:
        raise TypeError(
            f"{type(self).__qualname__}.replace() called with unexpected arguments: "
            f"{', '.join(sorted(unexpected))}"
        )
    return type(self)(**(current_data | new_data))


def _record_fields(func: Callable) -> list:
    """Translate ``func``'s parameters into ``msgspec.defstruct`` field specs.

    Parameters are visited in declaration order. Each becomes
    ``(name, annotation)``, or ``(name, annotation, default)`` when the
    parameter declares a default; an unannotated parameter is typed ``Any``.
    The return annotation is ignored because it does not describe a field.

    Raises:
        TypeError: if ``func`` takes ``*args`` or ``**kwargs``, which have no
            fixed field name to map onto.
    """
    empty = inspect.Parameter.empty
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    fields = []
    for name, param in inspect.signature(func).parameters.items():
        if param.kind in variadic:
            raise TypeError(
                f"record {func.__name__!r} cannot have variadic parameter {param}"
            )
        annotation = Any if param.annotation is empty else param.annotation
        if param.default is empty:
            fields.append((name, annotation))
        else:
            fields.append((name, annotation, param.default))
    return fields


def record(func: Callable):
    """Turn a function signature into a frozen ``msgspec`` Struct class.

    The decorated function is only inspected, never called: its name becomes
    the class name and its parameters become the fields, in order, keeping
    their annotations and defaults. The class is created with ``frozen=True``
    and ``forbid_unknown_fields=True`` and gains two methods:

    * ``asdict()`` -- the fields as a ``{name: value}`` dict;
    * ``replace(**changes)`` -- a new instance with some fields changed,
      raising ``TypeError`` for names that are not fields.

    Example::

        @record
        def Point(x: float, y: float = 0.0): ...

        Point(1.0).replace(y=2.0).asdict()  # {'x': 1.0, 'y': 2.0}

    Args:
        func: function whose signature describes the record.

    Returns:
        The newly created Struct subclass.

    Raises:
        TypeError: if ``func`` takes ``*args`` or ``**kwargs``. Errors raised
            by ``msgspec.defstruct`` for an invalid field layout propagate.
    """
    return msgspec.defstruct(
        func.__name__,
        _record_fields(func),
        forbid_unknown_fields=True,
        frozen=True,
        # Fresh namespace dict per class; the method functions are shared.
        namespace={"asdict": _record_asdict, "replace": _record_replace},
    )
# Example: declare a record type just by writing its constructor signature.
@record
def Point(x: float, y: float):
    """Two-coordinate point record; only the signature is used by ``record``."""
# Build one record, then derive a modified copy via replace(); `p` itself is unchanged.
p = Point(x=1.0, y=1.0)
print(p.asdict())  # -> {'x': 1.0, 'y': 1.0}

new_point = p.replace(y=2.0)
print(new_point.asdict())  # -> {'x': 1.0, 'y': 2.0}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment