Skip to content

Instantly share code, notes, and snippets.

@EpicWink
Last active October 4, 2022 12:02
Show Gist options
  • Save EpicWink/195f43286e47c26eef7d5eb3263fee75 to your computer and use it in GitHub Desktop.
Save EpicWink/195f43286e47c26eef7d5eb3263fee75 to your computer and use it in GitHub Desktop.
Support non-defaulted after defaulted fields in dataclasses
"""Patch ``dataclasses`` to support optional after required fields.
Fields used in ``__init__`` without defaults are currently not allowed
after fields with defaults, due to the specification in PEP 557. This
patch allows these fields, but makes them required keyword-only
parameters to ``__init__``.
To apply this patch, simply import this module before defining any
dataclasses.
"""
# Copyright 2020 Laurie O
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this
# software and associated documentation files (the "Software"), to deal in the Software
# without restriction, including without limitation the rights to use, copy, modify,
# merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to the following
# conditions:
#
# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import sys
import dataclasses
import functools as ft
@ft.wraps(dataclasses._init_fn)
def _init_fn(fields, frozen, has_post_init, self_name, globals_=None):
    """Build ``__init__`` for a data-class.

    Unlike the stock implementation, a field without a default that follows
    a field with a default is accepted: a bare ``*`` is inserted in front of
    the first such field, so it and every later parameter become
    keyword-only.

    Two private calling conventions are supported and are told apart by the
    call itself rather than by the interpreter version: when ``globals_`` is
    omitted, the helper sentinels (and the namespace handed to
    ``dataclasses._field_init``) live in a fresh globals mapping; when
    ``globals_`` is supplied, they are merged into the locals mapping
    instead.

    Args:
        fields: Field objects of the class, including init-only
            pseudo-fields.
        frozen: Forwarded unchanged to ``dataclasses._field_init``.
        has_post_init: Whether the generated ``__init__`` should end with a
            call to the ``dataclasses._POST_INIT_NAME`` method.
        self_name: Name used for the first parameter of ``__init__``.
        globals_: Globals mapping for the generated function, or ``None``
            when called using the older four-argument convention.

    Returns:
        The ``__init__`` function built by ``dataclasses._create_fn``.
    """
    # Decide on the calling convention from the arguments actually received;
    # a version comparison mis-classifies releases that are numerically newer
    # but still use the four-argument call, and ``assert`` is stripped
    # under ``-O``.
    legacy_call = globals_ is None

    locals_ = {f"_type_{f.name}": f.type for f in fields}
    extra = {
        "MISSING": dataclasses.MISSING,
        "_HAS_DEFAULT_FACTORY": dataclasses._HAS_DEFAULT_FACTORY,
    }
    if legacy_call:
        globals_ = extra
    else:
        locals_.update(extra)
    # Namespace that ``_field_init`` receives for this convention.
    field_namespace = globals_ if legacy_call else locals_

    body_lines = []
    for f in fields:
        line = dataclasses._field_init(f, frozen, field_namespace, self_name)
        if line:
            body_lines.append(line)

    if has_post_init:
        # Only init-only pseudo-fields are forwarded to the post-init hook.
        params_str = ",".join(
            f.name for f in fields if f._field_type is dataclasses._FIELD_INITVAR
        )
        body_lines.append(
            f"{self_name}.{dataclasses._POST_INIT_NAME}({params_str})"
        )

    if not body_lines:
        body_lines = ["pass"]

    # Edit: args after defaulted args are keyword-only
    seen_default = False
    keyword_only = False
    args = [self_name]
    for f in fields:
        if not f.init:
            continue
        has_default = f.default is not dataclasses.MISSING
        has_default_factory = f.default_factory is not dataclasses.MISSING
        if has_default or has_default_factory:
            seen_default = True
        elif seen_default and not keyword_only:
            # First required field after a defaulted one: switch the rest of
            # the signature to keyword-only exactly once.
            keyword_only = True
            args.append("*")
        args.append(dataclasses._init_param(f))

    return dataclasses._create_fn(
        "__init__",
        args,
        body_lines,
        locals=locals_,
        globals=globals_,
        return_type=None,
    )
# Install the replacement only where the private ``dataclasses._init_fn``
# still takes the arguments mirrored by ``_init_fn`` above. Python 3.10
# reworked that private signature for keyword-only field support, so
# patching there would make every later ``@dataclass`` definition fail.
if sys.version_info < (3, 10):
    dataclasses._init_fn = _init_fn
"""Patch ``dataclasses`` to support optional after required fields.
Fields used in ``__init__`` without defaults are currently not allowed
after fields with defaults, due to the specification in PEP 557. This
patch allows these fields, but makes them required keyword-only
parameters to ``__init__``.
To apply this patch, simply import this module before defining any
dataclasses.
This file also supports Python 3.10+
"""
# Copyright 2020 Laurie O
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this
# software and associated documentation files (the "Software"), to deal in the Software
# without restriction, including without limitation the rights to use, copy, modify,
# merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to the following
# conditions:
#
# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import sys
import copy
import dataclasses
import functools as ft
@ft.wraps(dataclasses._init_fn)
def _init_fn(fields, frozen, has_post_init, self_name, globals_=None):
    """Build ``__init__`` for a data-class.

    Unlike the stock implementation, a field without a default that follows
    a field with a default is accepted: a bare ``*`` is inserted in front of
    the first such field, so it and every later parameter become
    keyword-only.

    Two private calling conventions are supported and are told apart by the
    call itself rather than by the interpreter version: when ``globals_`` is
    omitted, the helper sentinels (and the namespace handed to
    ``dataclasses._field_init``) live in a fresh globals mapping; when
    ``globals_`` is supplied, they are merged into the locals mapping
    instead.

    Args:
        fields: Field objects of the class, including init-only
            pseudo-fields.
        frozen: Forwarded unchanged to ``dataclasses._field_init``.
        has_post_init: Whether the generated ``__init__`` should end with a
            call to the ``dataclasses._POST_INIT_NAME`` method.
        self_name: Name used for the first parameter of ``__init__``.
        globals_: Globals mapping for the generated function, or ``None``
            when called using the older four-argument convention.

    Returns:
        The ``__init__`` function built by ``dataclasses._create_fn``.
    """
    # Decide on the calling convention from the arguments actually received;
    # a version comparison mis-classifies releases that are numerically newer
    # but still use the four-argument call, and ``assert`` is stripped
    # under ``-O``.
    legacy_call = globals_ is None

    locals_ = {f"_type_{f.name}": f.type for f in fields}
    extra = {
        "MISSING": dataclasses.MISSING,
        "_HAS_DEFAULT_FACTORY": dataclasses._HAS_DEFAULT_FACTORY,
    }
    if legacy_call:
        globals_ = extra
    else:
        locals_.update(extra)
    # Namespace that ``_field_init`` receives for this convention.
    field_namespace = globals_ if legacy_call else locals_

    body_lines = []
    for f in fields:
        line = dataclasses._field_init(f, frozen, field_namespace, self_name)
        if line:
            body_lines.append(line)

    if has_post_init:
        # Only init-only pseudo-fields are forwarded to the post-init hook.
        params_str = ",".join(
            f.name for f in fields if f._field_type is dataclasses._FIELD_INITVAR
        )
        body_lines.append(
            f"{self_name}.{dataclasses._POST_INIT_NAME}({params_str})"
        )

    if not body_lines:
        body_lines = ["pass"]

    # Edit: args after defaulted args are keyword-only
    seen_default = False
    keyword_only = False
    args = [self_name]
    for f in fields:
        if not f.init:
            continue
        has_default = f.default is not dataclasses.MISSING
        has_default_factory = f.default_factory is not dataclasses.MISSING
        if has_default or has_default_factory:
            seen_default = True
        elif seen_default and not keyword_only:
            # First required field after a defaulted one: switch the rest of
            # the signature to keyword-only exactly once.
            keyword_only = True
            args.append("*")
        args.append(dataclasses._init_param(f))

    return dataclasses._create_fn(
        "__init__",
        args,
        body_lines,
        locals=locals_,
        globals=globals_,
        return_type=None,
    )
# Before Python 3.10 the patch is applied by swapping in the replacement
# ``__init__`` builder defined above.
if sys.version_info < (3, 10):
    dataclasses._init_fn = _init_fn
@ft.wraps(dataclasses._process_class)
def _process_class(cls, *args, **kwargs):
    """Convert into a data-class.

    Rewrites ``cls.__annotations__`` before delegating to the saved
    ``_dataclasses_process_class``: a ``dataclasses.KW_ONLY`` marker (under
    the name ``_``) is inserted in front of the first non-defaulted field
    that follows a defaulted one, so the built-in keyword-only handling
    accepts the class.

    Args:
        cls: Class being decorated; its own ``__annotations__`` mapping may
            be modified in place.
        *args: Forwarded unchanged to the wrapped function.
        **kwargs: Forwarded unchanged to the wrapped function.

    Returns:
        Whatever the wrapped ``_dataclasses_process_class`` returns.
    """

    def _get_field_py310(name__, type__):
        """Return the field for ``name__`` while leaving ``cls`` untouched."""
        original = getattr(cls, name__, None)
        # A ``Field`` class attribute is temporarily replaced by a shallow
        # copy and restored afterwards -- presumably because
        # ``dataclasses._get_field`` modifies the ``Field`` it finds;
        # TODO confirm against the stdlib source.
        if isinstance(original, dataclasses.Field):
            setattr(cls, name__, copy.copy(original))
        f = dataclasses._get_field(cls, name__, type__, False)
        if isinstance(original, dataclasses.Field):
            setattr(cls, name__, original)
        return f

    # Use built-in keyword-only argument handling
    # Check to see if we're already in keyword-only mode
    all_base_fields = {}
    seen_default = False
    # Walk the MRO from the root towards ``cls`` (``cls`` itself excluded).
    for b in cls.__mro__[-1:0:-1]:
        base_fields = getattr(b, dataclasses._FIELDS, None) or {}
        all_base_fields.update(base_fields)
        if any(
            f.default is not dataclasses.MISSING or
            f.default_factory is not dataclasses.MISSING
            for f in base_fields.values()
        ):
            seen_default = True
            # NOTE(review): stopping here means fields of the remaining,
            # more-derived bases are not collected -- confirm intended.
            break

    # Only annotations declared directly on ``cls``; when the class has
    # none, the edits below act on a throw-away dict.
    cls_annotations = cls.__dict__.get("__annotations__", {})

    # See if previous fields have had defaults added
    to_convert = []
    for name, type_ in cls_annotations.items():
        if name in all_base_fields:  # all_base_fields won't have KW_ONLY
            f = _get_field_py310(name, type_)
            base_f = all_base_fields[name]
            # The base declared this field without a default, and the
            # re-declaration in ``cls`` gives it one.
            if (
                base_f.default is dataclasses.MISSING and
                base_f.default_factory is dataclasses.MISSING and (
                    f.default is not dataclasses.MISSING or
                    f.default_factory is not dataclasses.MISSING
                )
            ):
                # Every base field listed after it that is not already
                # keyword-only is queued for conversion.
                base_names = list(all_base_fields)
                index = base_names.index(name)
                for name_ in base_names[index + 1:]:
                    if not all_base_fields[name_].kw_only:
                        to_convert.append(all_base_fields[name_])

    # Add keyword-only mode if/when required
    names = list(cls_annotations)
    for i, name in enumerate(names):
        type_ = cls_annotations[name]
        if type_ == dataclasses.KW_ONLY:
            # The class already has its own marker; nothing to insert.
            break
        f = _get_field_py310(name, type_)
        if (
            seen_default and
            f.default is dataclasses.MISSING and
            f.default_factory is dataclasses.MISSING
        ):
            # Insert KW_ONLY marker
            # Popping the remaining annotations and re-adding them after
            # ``_`` moves them behind the marker in the dict's order.
            kw_only = {n: cls_annotations.pop(n) for n in names[i:]}
            cls_annotations["_"] = dataclasses.KW_ONLY
            for name_ in names[i:]:
                cls_annotations[name_] = kw_only[name_]
            break
        if f._field_type is dataclasses._FIELD and (
            f.default is not dataclasses.MISSING or
            f.default_factory is not dataclasses.MISSING
        ):
            seen_default = True

    # Make previous fields keyword-only if required
    if to_convert:
        if not any(v == dataclasses.KW_ONLY for v in cls_annotations.values()):
            cls_annotations["_"] = dataclasses.KW_ONLY
        # Declare the queued base fields in this class's annotations, using
        # the type recorded on the base field.
        for f in to_convert:
            cls_annotations[f.name] = f.type

    return _dataclasses_process_class(cls, *args, **kwargs)
# On Python 3.10+ the patch wraps ``dataclasses._process_class`` instead.
# The original must be saved before it is overwritten, because the wrapper
# above delegates to ``_dataclasses_process_class``.
if sys.version_info >= (3, 10):
    _dataclasses_process_class = dataclasses._process_class
    dataclasses._process_class = _process_class
@EpicWink
Copy link
Author

EpicWink commented Mar 11, 2021

...but I guess that mypy will still report errors, right?

Yes, if you want type-checkers to not show any errors, you may either need to disable type-checking on data-classes, or use attr.ib(kw_only=True).

Edit: or use Python 3.10+ with kw_only=True or KW_ONLY

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment