Skip to content

Instantly share code, notes, and snippets.

@huzecong
Last active Feb 22, 2022
Embed
What would you like to do?
A super-enhanced version of namedtuple that supports multiple inheritance and arbitrary field orders.
# Copyright (c) 2021 Zecong Hu
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
# OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
import collections
import typing
__all__ = [
"Options",
]
class OptionsMeta(typing.NamedTupleMeta):
    """Metaclass that builds keyword-only named tuples supporting inheritance.

    For every class other than the root `Options` class, the fields of all
    concrete `Options` base classes are merged with the annotations of the
    class being created, a `namedtuple` is generated for the merged fields,
    and its `__new__` is replaced by a keyword-only constructor.

    NOTE(review): this relies on the private `typing.NamedTupleMeta`.  The
    gist author reports that it is not compatible with Python 3.9, where
    `NamedTupleMeta` gained an assertion about `bases` -- confirm the target
    interpreter version before using it.
    """

    def __new__(mcs, typename, bases, namespace):
        """Create the options class `typename`.

        Args:
            mcs: This metaclass.
            typename: Name of the class being created.
            bases: Base classes listed in the class statement.
            namespace: Class body namespace.  It is mutated: inherited
                defaults are added and `__annotations__` is replaced.

        Returns:
            The new class, whose declared bases are those of the generated
            `namedtuple`; the user-declared bases are kept in `_bases`.

        Raises:
            TypeError: If two base classes define the same field and the
                new class does not redefine it.
            ValueError: If the merged class has no fields at all.
        """
        if namespace.get('_root', False):
            # The created class is `Options`, skip.
            return super().__new__(mcs, typename, bases, namespace)

        # Gather fields from annotations of current class and base classes.
        cur_fields = namespace.get('__annotations__', {})
        fields = {}
        field_sources = {}  # which base class each inherited name came from
        field_defaults = {}
        for base in bases:
            if issubclass(base, Options) and hasattr(base, '_fields'):
                # Base class is a concrete subclass of `Options`.
                for name in base._fields:
                    if name in cur_fields:
                        # Make sure not to overwrite redefined fields.
                        continue
                    if name in fields:
                        # Overlapping field that is not redefined.
                        raise TypeError(
                            f"Base class {base} contains field {name}, which "
                            f"is defined in other base class "
                            f"{field_sources[name]}")
                    fields[name] = base.__annotations__[name]
                    field_sources[name] = base
                    if name in base._field_defaults:
                        field_defaults[name] = base._field_defaults[name]
        fields.update(cur_fields)
        if not fields:
            raise ValueError("Options class must contain at least one field")
        # Inherited defaults never override a value set in the class body.
        for name, value in field_defaults.items():
            namespace.setdefault(name, value)

        # Reorder fields to put those without default values in front.
        # Each group is sorted by name, so the tuple order is alphabetical
        # within the group rather than the declaration order.
        fields_with_default = [name for name in fields if name in namespace]
        reordered_fields = (
            sorted(set(fields).difference(fields_with_default)) +
            sorted(fields_with_default))
        namespace['__annotations__'] = collections.OrderedDict(
            [(name, fields[name]) for name in reordered_fields])

        # Let `NamedTupleMeta` create an annotated `namedtuple` for us.
        # Note that `bases` is not used here so we just set it to `None`.
        nm_tpl = super().__new__(mcs, typename, None, namespace)

        # Rewrite `__new__` method to make all arguments keyword-only.
        # This is very hacky code. Do not try this at home.
        arg_list = ''.join(name + ', '  # watch out for singleton tuples
                           for name in reordered_fields)
        # The template is assembled line by line so that the indentation of
        # the generated function body is explicit.
        source = (
            f'def __new__(_cls, *args, {arg_list}):\n'
            '    if len(args) > 0:\n'
            '        raise TypeError("Instances of Options class must be created "\n'
            '                        "with keyword arguments.")\n'
            f'    return _tuple_new(_cls, ({arg_list}))'
        )
        new_method_namespace = {'_tuple_new': tuple.__new__,
                                '__name__': f'namedtuple_{typename}'}
        exec(source, new_method_namespace)
        __new__ = new_method_namespace['__new__']
        __new__.__qualname__ = f"{typename}.__new__"
        __new__.__doc__ = nm_tpl.__new__.__doc__
        __new__.__annotations__ = nm_tpl.__new__.__annotations__
        # Defaults of keyword-only parameters live in `__kwdefaults__`.
        __new__.__kwdefaults__ = {name: namespace[name]
                                  for name in fields_with_default}
        nm_tpl.__new__ = __new__

        # Wrap the return type in `OptionsMeta` so it can be subclassed.
        new_namespace = nm_tpl.__dict__.copy()
        new_namespace['_bases'] = bases
        # Also keep base classes of the `namedtuple` (i.e., the `tuple` class),
        # so we can call `tuple.__new__`.
        options_type = type.__new__(
            mcs, typename, nm_tpl.__bases__, new_namespace)
        # NOTE(review): re-assigning `__bases__` presumably forces the MRO to
        # be recomputed through `mro()` now that `_bases` is visible on the
        # class -- confirm before removing this seemingly redundant line.
        options_type.__bases__ = tuple(options_type.__bases__)
        return options_type

    def mro(cls):
        """Return the MRO, merging in the MROs of the user-declared bases.

        Classes without a `_bases` attribute get the default MRO unchanged.
        """
        default_mro = super().mro()
        # `Options` does not define `_bases`, so we don't do anything about it.
        if hasattr(cls, '_bases'):
            # `default_mro` should be `[cls, tuple, object]`.
            # `c3merge` and `c3mro` are implementations of the C3 linearization
            # algorithm, which unluckily aren't provided as APIs.
            return c3merge([
                default_mro[:1],
                *[base.__mro__ for base in cls._bases],
                default_mro[1:]])
        return default_mro
class Options(metaclass=OptionsMeta):
    """Root base class for option tuples; it cannot be instantiated itself."""

    # Read by `OptionsMeta.__new__` to skip field processing for this class.
    _root = True

    def __new__(cls, *args, **kwargs):
        """Create an instance, refusing to instantiate `Options` directly.

        Raises:
            TypeError: If `cls` is `Options` itself.
        """
        # Copied from typing.Generic.
        if cls is Options:
            # Prevent instantiation of `Options` class.
            raise TypeError("Type Options cannot be instantiated; "
                            "it can be used only as a base class")
        if (super().__new__ is object.__new__ and
                cls.__init__ is not object.__init__):
            # `object.__new__` takes no extra arguments when `__init__` is
            # overridden, so they are not forwarded.
            obj = super().__new__(cls)
        else:
            obj = super().__new__(cls, *args, **kwargs)
        return obj
def c3merge(sequences):
    r"""Merge several linearizations into one with the C3 algorithm.

    Adapted from https://www.python.org/download/releases/2.3/mro/

    Args:
        sequences: Iterable of sequences to merge.  The inputs are copied
            and never mutated.

    Returns:
        A list containing every element of the inputs exactly once, in an
        order consistent with each input sequence.

    Raises:
        TypeError: If the sequences impose contradictory orderings, so no
            consistent merge exists.
    """
    # Make sure we don't actually mutate anything we are getting as input.
    sequences = [list(x) for x in sequences]
    result = []
    while True:
        # Clear out blank sequences.
        sequences = [x for x in sequences if x]
        if not sequences:
            return result
        # Find the first clean head.
        for seq in sequences:
            head = seq[0]
            # If this is not a bad head (i.e., not in the tail of any
            # sequence), it is the next element of the result.
            if not any(head in s[1:] for s in sequences):
                break
        else:
            # Every candidate head appears in some tail: no valid order.
            raise TypeError("inconsistent hierarchy")
        # Move the head from the front of all sequences to the end of results.
        result.append(head)
        for seq in sequences:
            if seq[0] == head:
                del seq[0]
@huzecong
Copy link
Author

huzecong commented Feb 5, 2022

It seems like Python 3.9 changed the implementation details for typing.NamedTupleMeta; it added this new assertion but in our case the Options class has no base classes. So this thing here is not compatible with Python 3.9.

Looking back at this piece of code two years later, I would say that it's too complicated for what it's trying to solve, and involves too much magic. There are much better alternatives that more or less achieve the same thing, e.g. attrs or the built-in dataclasses. If you care about performance and memory, attrs also offers a slots class.

@MatthiasLohr
Copy link

MatthiasLohr commented Feb 6, 2022

Thanks for your reply! Actually, I stumbled over this code by searching for NamedTuple with inheritance support. Any recommendation for that?

@huzecong
Copy link
Author

huzecong commented Feb 6, 2022

There's no built-in way to support inheritance for NamedTuples per se, mostly because it's still a tuple that supports __getitem__ via an index, and there's no well-defined behavior for adding fields to a tuple. But, as I said in the previous comments, you could consider using attrs or dataclasses for more or less the same functionality. The attrs docs even have a page that compares it with namedtuples.

@cdce8p
Copy link

cdce8p commented Feb 22, 2022

Thanks for your reply! Actually, I stumbled over this code by searching for NamedTuple with inheritance support. Any recommendation for that?

@MatthiasLohr A while back I experimented with that too. As far as I remember, I got it to work on Python 3.7+. Ultimately, I ended up going with dataclasses since support for them is simply much better compared to a custom NamedTuple implementation. If it helps you, here's what I did. No guarantees though 🙂

"""Requires Python 3.7 -> preserve dict insertion order"""
from __future__ import annotations

import sys
import typing

# attributes prohibited to set in NamedTuple class syntax
# (`NamedTupleMeta.__new__` raises `AttributeError` when the class body
# defines any of these names)
_prohibited = frozenset({'__new__', '__init__', '__slots__', '__getnewargs__',
                         '_fields', '_field_defaults',
                         '_make', '_replace', '_asdict', '_source'})
# namespace keys that `NamedTupleMeta.__new__` does not copy from the class
# body onto the generated namedtuple
_special = frozenset({'__module__', '__name__', '__annotations__'})


class NamedTupleMeta(type):
    """Metaclass that turns an annotated class body into a namedtuple.

    The annotations of the class body become the tuple fields, and names
    that are both annotated and assigned become field defaults.
    """

    def __new__(cls, typename, bases, ns):
        """Build the namedtuple class described by the class body `ns`.

        Args:
            cls: This metaclass (unused).
            typename: Name of the namedtuple class to create.
            bases: Declared base classes (unused).
            ns: Class body namespace.

        Returns:
            The class produced by `typing._make_nmtuple`, with the remaining
            class-body attributes copied onto it.

        Raises:
            TypeError: If a field without a default follows one with a default.
            AttributeError: If the class body defines a reserved attribute.
        """
        types = ns.get('__annotations__', {})

        # Fields with defaults must form a suffix of the field list.
        default_names = []
        for field_name in types:
            if field_name in ns:
                default_names.append(field_name)
                continue
            if default_names:
                raise TypeError(f"Non-default namedtuple field {field_name} "
                                f"cannot follow default field"
                                f"{'s' if len(default_names) > 1 else ''} "
                                f"{', '.join(default_names)}")
        default_values = tuple(map(ns.__getitem__, default_names))

        field_items = types.items()
        if sys.version_info < (3, 9):
            # This call form does not pass defaults, so the annotations and
            # defaults are attached by hand afterwards.
            tuple_cls = typing._make_nmtuple(typename, field_items)
            tuple_cls.__new__.__annotations__ = dict(types)
            tuple_cls.__new__.__defaults__ = default_values
            tuple_cls._field_defaults = dict(zip(default_names,
                                                 default_values))
        else:
            tuple_cls = typing._make_nmtuple(typename, field_items,
                                             defaults=default_values,
                                             module=ns['__module__'])

        # Copy the rest of the class body (methods, constants, ...) without
        # clobbering fields or the attributes namedtuple itself provides.
        for attr, value in ns.items():
            if attr in _prohibited:
                raise AttributeError("Cannot overwrite NamedTuple attribute " + attr)
            if attr in _special or attr in tuple_cls._fields:
                continue
            setattr(tuple_cls, attr, value)
        return tuple_cls


class OptionsMeta(NamedTupleMeta):
    """Metaclass for namedtuple-like classes whose fields can be inherited.

    Fields of every base class exposing `_fields` are merged with the
    annotations of the class being created.  Fields without a default are
    placed before fields with a default; within each group the order of
    first appearance is kept.
    """

    def __new__(cls, typename, bases, ns):
        """Create class `typename` with inherited and own fields merged.

        Args:
            cls: This metaclass.
            typename: Name of the class being created.
            bases: Base classes listed in the class statement.
            ns: Class body namespace.  It is mutated: inherited defaults are
                added and `__annotations__` is replaced.

        Returns:
            A new class whose bases are `bases` followed by the bases of the
            generated namedtuple.

        Raises:
            TypeError: If two base classes define the same field and the new
                class does not redefine it.
            ValueError: If the merged class has no fields at all.
        """
        own_fields = ns.get("__annotations__", {})
        merged = {}
        field_sources = {}
        inherited_defaults = {}

        for base in bases:
            if not hasattr(base, "_fields"):
                continue
            for name in base._fields:
                if name in own_fields:
                    # Redefined in the new class; its own definition wins.
                    continue
                if name in merged:
                    # Overlapping field that is not redefined.
                    raise TypeError(
                        f"Base class {base} contains field {name}, which "
                        f"is defined in other base class "
                        f"{field_sources[name]}")
                merged[name] = base.__annotations__[name]
                field_sources[name] = base
                if name in base._field_defaults:
                    inherited_defaults[name] = base._field_defaults[name]

        merged.update(own_fields)
        if not merged:
            raise ValueError("Options class must contain at least one field")

        # Inherited defaults never override a value set in the class body.
        for name, value in inherited_defaults.items():
            ns.setdefault(name, value)

        # Partition the fields so that those with defaults come last.
        required = {}
        defaulted = {}
        for name, annotation in merged.items():
            bucket = defaulted if name in ns else required
            bucket[name] = annotation
        ns["__annotations__"] = {**required, **defaulted}

        tuple_cls = super().__new__(cls, typename, None, ns)
        all_bases = bases + tuple_cls.__bases__
        return type.__new__(cls, typename, all_bases,
                            tuple_cls.__dict__.copy())


def Options():
    """Reject direct calls; `Options` is only meant to appear as a base class.

    Raises:
        TypeError: Always.
    """
    message = "Options can only be used as base class"
    raise TypeError(message)

# The real root class.  It is built with `type.__new__` directly, which
# bypasses `OptionsMeta.__new__` (that would raise ValueError for a class
# without fields).
_Options = type.__new__(OptionsMeta, 'Options', (), {})
# `Options` is a function, not a class; `__mro_entries__` (PEP 560) makes
# `class X(Options)` use `_Options` as the actual base class.
Options.__mro_entries__ = lambda bases: (_Options,)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment