# Code examples for the blog post:
# http://benno.id.au/blog/2014/11/30/a-better-namedtuple
# (originally published as GitHub gist bennoleslie/27aeb9065e81199f8af1)
from collections import OrderedDict, namedtuple
from operator import itemgetter
from timeit import timeit
RUNS = 1000000  # iterations per timeit() measurement below

# A plain collections.namedtuple: fields are readable by name.
Foo = namedtuple('Foo', ['bar', 'baz'])
foo = Foo(bar=5, baz=10)
assert (foo.bar, foo.baz) == (5, 10)


def foo_bar_times_baz(foo):
    """Return ``foo.bar * foo.baz`` (a free function rather than a method)."""
    product = foo.bar * foo.baz
    return product


assert foo_bar_times_baz(foo) == 50
class Foo(namedtuple('Foo', 'bar baz')):
    """namedtuple subclass that adds a ``bar_times_baz`` method.

    ``__slots__ = ()`` keeps instances as lightweight as the base namedtuple;
    without it every instance of the subclass would also carry a ``__dict__``
    (and accept arbitrary new attributes).
    """

    __slots__ = ()

    def bar_times_baz(self):
        """Return the product of the ``bar`` and ``baz`` fields."""
        return self.bar * self.baz


foo = Foo(5, 10)
assert foo.bar_times_baz() == 50
# namedtuple() returns an ordinary class, so a method can also be attached
# after the class has been created.
Foo = namedtuple('Foo', ['bar', 'baz'])
setattr(Foo, 'bar_times_baz', lambda self: self.bar * self.baz)
foo = Foo(bar=5, baz=10)
assert foo.bar_times_baz() == 50
class NamedTuple(tuple):
    """Base class for hand-written, namedtuple-like record types.

    Subclasses set ``_fields`` to a tuple of field names (and should also
    declare ``__slots__ = ()`` so instances stay free of a ``__dict__``).
    Instances are real tuples: fields are readable by index, or by name
    through ``__getattr__``.
    """

    __slots__ = ()
    _fields = None  # Subclass must provide this: a tuple of field-name strings.

    def __new__(_cls, *args, **kwargs):
        """Create an instance from positional and/or keyword field values.

        Positional values fill the leading fields in order; keyword values
        fill the remaining ones.

        Raises:
            TypeError: if ``_fields`` is unset, too many positional values are
                given, a field is given both positionally and by keyword, a
                field is missing, or an unknown keyword is passed.
        """
        fields = _cls._fields
        if fields is None:
            # Fail clearly instead of with "object of type 'NoneType' has no len()".
            raise TypeError("{} must define _fields".format(_cls.__name__))
        if len(args) > len(fields):
            raise TypeError("__new__ takes {} positional arguments but {} were given".format(len(fields) + 1, len(args) + 1))
        # A keyword naming a field already filled positionally is a duplicate,
        # not an unknown keyword.
        for fld in fields[:len(args)]:
            if fld in kwargs:
                raise TypeError("__new__ got multiple values for argument '{}'".format(fld))
        remaining = fields[len(args):]
        missing_args = tuple(fld for fld in remaining if fld not in kwargs)
        if missing_args:
            raise TypeError("__new__ missing {} required positional arguments".format(len(missing_args)))
        # Every remaining field is now known to be present in kwargs.
        extra_args = tuple(kwargs.pop(fld) for fld in remaining)
        if kwargs:
            raise TypeError("__new__ got an unexpected keyword argument '{}'".format(next(iter(kwargs))))
        return tuple.__new__(_cls, args + extra_args)

    def _make(self, iterable, new=tuple.__new__, len=len):
        """Make a new instance of this object's class from a sequence or iterable.

        Unlike ``collections.namedtuple._make`` this is an instance method
        (it reads ``self.__class__``), so call it as ``inst._make(iterable)``.
        ``new``/``len`` are bound as defaults purely for faster local lookup.

        Raises:
            TypeError: if the iterable does not yield exactly one value per field.
        """
        cls = self.__class__
        # tuple.__new__ bypasses the argument checking done in __new__ above.
        result = new(cls, iterable)
        if len(result) != len(cls._fields):
            raise TypeError('Expected {} arguments, got {}'.format(len(cls._fields), len(result)))
        return result

    def __repr__(self):
        'Return a nicely formatted representation string'
        fmt = '(' + ', '.join('%s=%%r' % x for x in self._fields) + ')'
        return self.__class__.__name__ + fmt % self

    def _asdict(self):
        'Return a new OrderedDict which maps field names to their values'
        return OrderedDict(zip(self._fields, self))

    # Expose the field mapping as ``inst.__dict__`` / ``vars(inst)``.
    __dict__ = property(_asdict)

    def _replace(_self, **kwds):
        """Return a new instance with the given fields replaced by new values.

        Raises:
            ValueError: if a keyword does not name a field.
        """
        # Pop each field's replacement, defaulting to the current value.
        result = _self._make(map(kwds.pop, _self._fields, _self))
        if kwds:
            raise ValueError('Got unexpected field names: %r' % list(kwds))
        return result

    def __getnewargs__(self):
        'Return self as a plain tuple. Used by copy and pickle.'
        return tuple(self)

    def __getstate__(self):
        'Exclude the OrderedDict from pickling'
        return None

    def __getattr__(self, field):
        """Resolve a field name to its value (called only when normal lookup fails).

        Raises:
            AttributeError: if ``field`` is not one of ``_fields``.
        """
        fields = self._fields or ()
        try:
            idx = fields.index(field)
        except ValueError:
            # ``from None``: the ValueError is an implementation detail.
            raise AttributeError("'{}' NamedTuple has no attribute '{}'".format(self.__class__.__name__, field)) from None
        return self[idx]
# First attempt: subclass the hand-written NamedTuple base without declaring
# ``__slots__`` in the subclass.
class FooSubClass(NamedTuple):
    _fields = ('bar', 'baz')

    def bar_times_baz(self):
        """Return the product of the ``bar`` and ``baz`` fields."""
        bar, baz = self.bar, self.baz
        return bar * baz


foo_sub_class = FooSubClass(5, 10)
assert foo_sub_class.bar_times_baz() == 50
assert foo_sub_class == (5, 10)
# This isn't great: without ``__slots__ = ()`` in the subclass, instances get
# a per-instance ``__dict__``, so arbitrary attributes can be attached.
foo_sub_class.blah = 29
# Second attempt: declaring ``__slots__ = ()`` in the subclass removes the
# per-instance ``__dict__``, so new attributes can no longer be attached.
class FooSubClass(NamedTuple):
    __slots__ = ()
    _fields = ('bar', 'baz')

    def bar_times_baz(self):
        """Return the product of the ``bar`` and ``baz`` fields."""
        bar, baz = self.bar, self.baz
        return bar * baz


foo_sub_class = FooSubClass(5, 10)
assert foo_sub_class.bar_times_baz() == 50
# Check we can't assign other attributes.
try:
    foo_sub_class.blah = 29
except AttributeError:
    pass
else:
    raise Exception("Should fail with attribute error.")
assert foo_sub_class == (5, 10)
# Micro-benchmark: index vs. attribute access for ``foo`` (a plain namedtuple
# instance) and ``foo_sub_class`` (a NamedTuple subclass whose attribute
# access falls back to NamedTuple.__getattr__).
# The setup strings import from __main__, so this only works when the file is
# run as a script.
direct_idx_time = timeit('foo[0]', setup='from __main__ import foo', number=RUNS)
direct_attr_time = timeit('foo.bar', setup='from __main__ import foo', number=RUNS)
sub_class_idx_time = timeit('foo[0]', setup='from __main__ import foo_sub_class as foo', number=RUNS)
sub_class_attr_time = timeit('foo.bar', setup='from __main__ import foo_sub_class as foo', number=RUNS)
print("namedtuple (idx time) ", direct_idx_time)
print("namedtuple (attr time)", direct_attr_time)
print("subclass (idx time) ", sub_class_idx_time)
print("subclass (attr time)", sub_class_attr_time)
def optimize(cls):
    """Class decorator: add a read-only property for each name in ``cls._fields``.

    Each property wraps ``operator.itemgetter(index)``, so for NamedTuple
    subclasses attribute access is satisfied by normal lookup and no longer
    falls back to ``__getattr__``.

    Args:
        cls: a tuple subclass whose ``_fields`` lists the field names in
            index order.

    Returns:
        ``cls`` itself, modified in place.
    """
    for idx, fld in enumerate(cls._fields):
        setattr(cls, fld, property(itemgetter(idx), doc='Alias for field number {}'.format(idx)))
    return cls
@optimize
class FooSubClassOptimized(NamedTuple):
    """NamedTuple subclass whose fields are itemgetter properties (via ``optimize``)."""

    __slots__ = ()
    _fields = ('bar', 'baz')

    def bar_times_baz(self):
        """Return the product of the ``bar`` and ``baz`` fields."""
        return self.bar * self.baz


foo_sub_class_optimized = FooSubClassOptimized(5, 10)
assert foo_sub_class_optimized.bar_times_baz() == 50
# The optimized class must still refuse new instance attributes.
try:
    foo_sub_class_optimized.blah = 29
except AttributeError:
    pass
else:
    raise Exception("Should fail with attribute error.")
assert foo_sub_class_optimized == (5, 10)
# Same micro-benchmark for the optimize()-decorated subclass (requires the file
# to be run as __main__, since the setup strings import from there).
sub_class_optimized_idx_time = timeit(
    stmt='foo[0]',
    setup='from __main__ import foo_sub_class_optimized as foo',
    number=RUNS,
)
sub_class_optimized_attr_time = timeit(
    stmt='foo.bar',
    setup='from __main__ import foo_sub_class_optimized as foo',
    number=RUNS,
)
print("subclass-opt (idx time) ", sub_class_optimized_idx_time)
print("subclass-opt (attr time)", sub_class_optimized_attr_time)
def namedfields(*fields):
    """Return a class decorator that turns a tuple subclass into a named tuple.

    The decorated class is rebuilt with ``type()`` under the same name and
    bases. It gets ``__slots__ = ()``, ``_fields``, one itemgetter property
    per field, and the helper methods of ``NamedTuple``. Attributes defined
    in the decorated class body take precedence over all of these.

    Args:
        *fields: the field names, in tuple-index order.

    Returns:
        A decorator; applying it returns a new class, not the original one.

    Raises:
        TypeError: (when decorating) if the class is not a subclass of ``tuple``.
    """
    def inner(cls):
        if not issubclass(cls, tuple):
            raise TypeError("namedfields decorated classes must be subclass of tuple")
        attrs = {
            '__slots__': (),
        }
        methods = ['__new__', '_make', '__repr__', '_asdict',
                   '__dict__', '_replace', '__getnewargs__',
                   '__getstate__']
        # Read NamedTuple's raw namespace instead of using getattr(): on a
        # class, getattr(NamedTuple, '__dict__') returns the class's own
        # mappingproxy (supplied by the metaclass), not the ``__dict__``
        # property NamedTuple defines for its instances.
        namespace = vars(NamedTuple)
        attrs.update({attr: namespace[attr] for attr in methods})
        attrs['_fields'] = fields
        attrs.update({fld: property(itemgetter(idx), doc='Alias for field number {}'.format(idx))
                      for idx, fld in enumerate(fields)})
        # Copy the decorated class body last (so it wins), skipping the
        # per-instance __dict__/__weakref__ descriptors of the slot-less original.
        attrs.update({key: val for key, val in cls.__dict__.items()
                      if key not in ('__weakref__', '__dict__')})
        return type(cls.__name__, cls.__bases__, attrs)
    return inner
@namedfields('bar', 'baz')
class FooDecorate(tuple):
    """Tuple subclass given named fields by the ``namedfields`` decorator."""

    def bar_times_baz(self):
        """Return the product of the ``bar`` and ``baz`` fields."""
        bar, baz = self.bar, self.baz
        return bar * baz


foo_decorate = FooDecorate(5, 10)
assert foo_decorate.bar_times_baz() == 50
# Micro-benchmark for the namedfields-decorated class (requires running the
# file as __main__, since the setup strings import from there).
decorate_idx_time = timeit(
    stmt='foo[0]',
    setup='from __main__ import foo_decorate as foo',
    number=RUNS,
)
decorate_attr_time = timeit(
    stmt='foo.bar',
    setup='from __main__ import foo_decorate as foo',
    number=RUNS,
)
print("decorator (idx time) ", decorate_idx_time)
print("decorator (attr time)", decorate_attr_time)
# namedfields can also be applied without decorator syntax, here to an empty
# tuple subclass built on the fly with type().
Foo = type('Foo', (tuple,), {})
Foo = namedfields('bar', 'baz')(Foo)
foo = Foo(5, 10)
assert foo == (5, 10)
print(timeit(stmt='foo.bar',
             setup='from __main__ import foo',
             number=RUNS))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment