Skip to content

Instantly share code, notes, and snippets.

@farsil
Created November 17, 2019 21:22
Show Gist options
  • Save farsil/fddcad5bdd6749173f69ea33eee5bb4f to your computer and use it in GitHub Desktop.
Save farsil/fddcad5bdd6749173f69ea33eee5bb4f to your computer and use it in GitHub Desktop.
Records in Python: the equivalent of a C struct.
#!/usr/bin/env python3
"""Record factory. A record is the equivalent of a C struct."""
import sys
import keyword
# Source code template for the classes built by record(). The placeholders in
# braces are filled in with str.format() and the result is exec()'d, so the
# indentation inside this string is significant. The generated class is
# mutable and uses __slots__; because it defines __eq__ without __hash__,
# Python makes its instances unhashable.
class_template = """\
import collections
class {typename}:
    '{typename}({arg_list})'
    __slots__ = _fields = {field_names!r}

    def __init__(self, {init_args}):
        'Create new instance of {typename}.'
        {self_fields} = {arg_list}

    def __repr__(self):
        'Return a nicely formatted representation string.'
        return (self.__class__.__name__ +
                '({repr_fmt})'.format(*map(repr, self.__values())))

    def __iter__(self):
        'Iterates over {typename} fields and values.'
        return zip(self._fields, self.__values())

    def __eq__(self, other):
        'Returns True if *self* is equal to *other*.'
        if not isinstance(other, self.__class__):
            return NotImplemented
        return {eq_stmt}

    @classmethod
    def _make(cls, iterable):
        'Make a new {typename} object from *iterable*.'
        return cls(*iterable)

    def __values(self):
        'Iterates through values in the order fields were defined.'
        {values_stmt}

    def _asdict(self):
        'Return a new OrderedDict which maps field names to their values.'
        result = collections.OrderedDict()
        {asdict_result} = {self_fields}
        return result

    def _replace(self, **kwargs):
        'Return a new {typename} object replacing specified fields with new values'
        result = self._make(map(kwargs.pop, self._fields, self.__values()))
        if kwargs:
            raise ValueError('Got unexpected field names: ' + str(list(kwargs)))
        return result\
"""
def validate(typename, field_names):
    """Internal use: normalise and check *typename* and *field_names*.

    *field_names* may be a single string of names separated by commas and/or
    whitespace, or an iterable of names; every name is converted with str().

    Returns a ``(typename, field_names)`` tuple where *typename* is a str and
    *field_names* is a list of str.

    Raises ValueError if no field names are given, if any name is not a valid
    identifier or is a keyword, or if a field name starts with an underscore,
    is ``self``, or is repeated.
    """
    if isinstance(field_names, str):
        field_names = field_names.replace(',', ' ').split()
    # str() normalises every name, so no separate type check is needed below.
    field_names = list(map(str, field_names))
    typename = str(typename)
    if not field_names:
        # The class template cannot express an empty field list; without this
        # check the generated source would fail to compile with SyntaxError.
        raise ValueError('At least one field name is required')
    # Every name is pasted into generated source code as an identifier.
    for name in [typename] + field_names:
        if not name.isidentifier():
            raise ValueError('Type names and field names must be valid '
                             'identifiers: ' + name)
        if keyword.iskeyword(name):
            raise ValueError('Type names and field names cannot be a '
                             'keyword: ' + name)
    seen = set()
    for name in field_names:
        # Underscore-prefixed names are reserved for the record's own
        # helpers (_make, _fields, _asdict, _replace, _source).
        if name.startswith('_'):
            raise ValueError('Field names cannot start with an underscore: ' +
                             name)
        # 'self' would clash with the first parameter of the generated
        # __init__ and produce "def __init__(self, self=None)".
        if name == 'self':
            raise ValueError("Field names cannot be 'self'")
        if name in seen:
            raise ValueError('Encountered duplicate field name: ' + name)
        seen.add(name)
    return typename, field_names
def compile_(typename, source, module=None):
    """Internal use: execute record *source* and return the class it defines.

    *source* is executed in a fresh namespace whose ``__name__`` is
    ``'struct_<typename>'``; the object bound to *typename* in that namespace
    is returned, with the generated source attached as ``_source``.

    If *module* is None, the ``__name__`` of the immediate caller's globals is
    used (best effort). The resolved module name, if any, is assigned to the
    class's ``__module__``.
    """
    # pylint: disable=exec-used, protected-access
    # The key must be the *string* '__name__': the bare name would evaluate to
    # this module's own name and leave the exec namespace without __name__.
    namespace = {'__name__': 'struct_{:s}'.format(typename)}
    exec(source, namespace)
    result = namespace[typename]
    result._source = source
    if module is None:
        try:
            module = sys._getframe(1).f_globals.get('__name__', '__main__')
        except (AttributeError, ValueError):
            # Frame introspection is unavailable on this interpreter
            # (AttributeError) or the stack is too shallow (ValueError);
            # keep whatever __module__ the class already has.
            pass
    if module is not None:
        result.__module__ = module
    return result
def record(typename, field_names, *, module=None):
    """Factory function that creates a record class.

    *typename* names the new class. *field_names* is either a string of
    names separated by commas and/or whitespace, or an iterable of names;
    both are checked by validate(), which raises ValueError on bad input.

    Every field becomes an __init__ parameter defaulting to None, in the
    order given. *module* sets the class's ``__module__``; when omitted it
    is taken from the caller's globals where frame introspection allows.

    Returns the newly compiled class.
    """
    typename, field_names = validate(typename, field_names)
    if module is None:
        # Resolve the caller here: inside compile_() the calling frame would
        # be this function, attributing every record to this module instead
        # of the code that called record().
        try:
            module = sys._getframe(1).f_globals.get('__name__', '__main__')
        except (AttributeError, ValueError):
            pass
    source = class_template.format(
        typename=typename,
        field_names=tuple(field_names),
        arg_list=', '.join(field_names),
        init_args=', '.join(f + '=None' for f in field_names),
        eq_stmt=' and '.join('self.' + f + ' == other.' + f
                             for f in field_names),
        repr_fmt=', '.join(f + '={}' for f in field_names),
        self_fields=', '.join('self.' + f for f in field_names),
        values_stmt='; '.join('yield self.' + f for f in field_names),
        asdict_result=', '.join("result['" + f + "']" for f in field_names),
    )
    return compile_(typename, source, module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment