# Gist raph-amiard/e172252f92f8694bf449eb32871fe2b6, created November 21, 2017 16:42.
# Note: GitHub warns that this file may contain bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
"""
This module implements a record facility in Python. The aim is to simplify the
case where you want a simple data class, used mainly to store data, without
writing the usual Python boilerplate of:

* Assigning fields in the constructor
* Sanity checking the number of arguments and whether each field is optional
  (field types are declared on each Field but not yet enforced)
* Implementing repr
* Implementing hash and structural equality (not yet implemented)

Here is how you declare a basic record:

>>> class Point(Record):
...     x = Field(type=int)
...     y = Field(type=int)

And here is how you use it:

>>> p = Point(12, 15)
>>> p
<Point x=12, y=15>

You can also pass fields as keyword arguments:

>>> Point(y=15, x=12)
<Point x=12, y=15>

If you pass too many or too few arguments, pass the same field twice, or use
an unknown field name, you will get an error:

>>> Point(1, 2, 3)
Traceback (most recent call last):
    ...
RecordError: Too many args to init
>>> Point(1, x=2)
Traceback (most recent call last):
    ...
RecordError: Field already set: x
>>> Point(x=2)
Traceback (most recent call last):
    ...
RecordError: Missing field: y
>>> Point(z=2)
Traceback (most recent call last):
    ...
RecordError: Wrong field name: z

Fields are inherited, so you can create hierarchies of records. The order of
fields, which matters for positional initialization, goes from the base class
to the most specific derivation. Optional fields that are not passed are set
to None:

>>> class Node(Record):
...     name = Field(type=str, optional=True)
>>> class Add(Node):
...     left = Field(type=Node)
...     right = Field(type=Node)
>>> class Int(Node):
...     val = Field(type=int)
>>> Add("add_1", Int(val=12), Int(val=15))
<Add name=add_1, left=<Int name=None, val=12>, right=<Int name=None, val=15>>
"""
from itertools import count | |
class Field(object):
    """Declaration of one field of a record.

    Fields are meant to be assigned as class attributes of a Record
    subclass; the record metaclass fills in ``_name`` with the attribute
    name they were assigned to.

    :param type: the declared type of the field's values. It is stored on
        the field but not checked by the code in this module.
    :param optional: if true, the field may be omitted when constructing a
        record, in which case its value is set to None.
    """

    # Shared creation counter: every Field remembers the order in which it
    # was created, so a record's fields can be sorted in declaration order.
    # ``count`` is already an iterator, so no ``iter()`` wrapper is needed.
    _counter = count(0)

    def __init__(self, type=None, optional=False):
        self._index = next(self._counter)
        self._name = None
        self.optional = optional
        self.type = type

    def set_val(self, record, val):
        """Store ``val`` as this field's value on ``record``."""
        setattr(record, self._name, val)

    def get_val(self, record):
        """Return this field's value on ``record``, or None if it is unset."""
        val = getattr(record, self._name, None)
        # When the instance has no value yet, attribute lookup falls back to
        # the class attribute of the same name, which is this Field object
        # itself; report that case as "unset" rather than leaking the Field.
        return None if val is self else val
class RecordError(Exception):
    """Raised when a record is constructed with invalid arguments."""
class RecordMC(type):
    """Metaclass for records: collects the Field attributes of a class body.

    Each Field found directly in the new class's namespace has its ``_name``
    set to the attribute name. The new class then gets:

    * ``_own_fields``: those fields sorted by creation order (``_index``);
    * ``_own_fields_dict``: the same fields keyed by name.

    Only the class's own namespace is scanned; inherited fields are not
    included here.
    """

    def __new__(mcs, name, bases, dct):
        new_cls = type.__new__(mcs, name, bases, dct)
        declared = [(attr, value) for attr, value in dct.items()
                    if isinstance(value, Field)]
        for attr, value in declared:
            # A Field learns its own attribute name from the class body.
            value._name = attr
        own = sorted((value for _, value in declared),
                     key=lambda fld: fld._index)
        new_cls._own_fields = own
        new_cls._own_fields_dict = dict((fld._name, fld) for fld in own)
        return new_cls
# A ``__metaclass__`` attribute in a class body is only honoured by Python 2;
# Python 3 silently ignores it, so record subclasses would never be processed
# by RecordMC. Deriving from a class that RecordMC itself created gives Record
# (and every subclass) that metaclass on both Python 2 and Python 3.
_RecordBase = RecordMC("_RecordBase", (object,), {})


class Record(_RecordBase):
    """Base class for declarative records.

    Subclasses declare their fields as ``Field`` class attributes. An
    instance is built from positional arguments (assigned in field order,
    see ``_fields``) and/or keyword arguments naming fields. Optional fields
    that are not given are set to None.

    Construction raises RecordError when there are more positional arguments
    than fields, a field is given twice, a keyword names no field, or a
    non-optional field is missing.
    """

    def __init__(self, *args, **kwargs):
        # Compute the field list once instead of once per argument.
        fields = self._fields()
        fields_dict = self._fields_dict()
        if len(args) > len(fields):
            raise RecordError("Too many args to init")

        set_fields = set()
        for f, arg in zip(fields, args):
            f.set_val(self, arg)
            set_fields.add(f)

        for k, v in kwargs.items():
            # Look up with .get so the error is not chained to a KeyError.
            f = fields_dict.get(k)
            if f is None:
                raise RecordError("Wrong field name: {}".format(k))
            if f in set_fields:
                raise RecordError("Field already set: {}".format(f._name))
            f.set_val(self, v)
            set_fields.add(f)

        # Walk the remaining fields in declaration order so that the missing
        # field reported (when several are missing) is deterministic.
        for f in fields:
            if f in set_fields:
                continue
            if not f.optional:
                raise RecordError("Missing field: {}".format(f._name))
            f.set_val(self, None)

    def __repr__(self):
        return "<{} {}>".format(self.__class__.__name__, ", ".join(
            "{}={}".format(f._name, f.get_val(self))
            for f in self._fields()
        ))

    @classmethod
    def _fields(cls):
        """Return all fields of ``cls``, from the root class to ``cls``.

        Classes are visited in reverse MRO order, so base-class fields come
        before derived ones and non-record mixins (which have no
        ``_own_fields``) are simply skipped. A field redeclared under the same
        name by a subclass keeps the inherited position but uses the
        subclass's declaration, so it is counted only once.
        """
        fields = []
        positions = {}
        for klass in reversed(cls.__mro__):
            for f in klass.__dict__.get("_own_fields", ()):
                if f._name in positions:
                    fields[positions[f._name]] = f
                else:
                    positions[f._name] = len(fields)
                    fields.append(f)
        return fields

    @classmethod
    def _fields_dict(cls):
        """Return a dict mapping each field name of ``cls`` to its Field."""
        return {f._name: f for f in cls._fields()}

    def is_a(self, cls):
        """Return True if this record is an instance of ``cls``."""
        return isinstance(self, cls)
if __name__ == "__main__":
    # Running this file directly executes the examples in its docstrings.
    from doctest import testmod
    testmod()
# (End of the GitHub gist page: sign in to GitHub to comment on this gist.)