Skip to content

Instantly share code, notes, and snippets.

@lebrice
Last active November 23, 2023 18:12
Show Gist options
  • Save lebrice/728f0e0218dcef6f74cbebd70af5857e to your computer and use it in GitHub Desktop.
Save lebrice/728f0e0218dcef6f74cbebd70af5857e to your computer and use it in GitHub Desktop.
Conditional dataclass fields. The default_factory can now take as an argument the value of other fields on the dataclass.
from __future__ import annotations
import inspect
from dataclasses import dataclass, field, Field, fields
from typing import Any, Callable, TypeVar, overload
from logging import getLogger as get_logger
logger = get_logger("conditional_fields")
T = TypeVar("T")
_inputs_key: str = "inputs"
_default_factory_key: str = "default_factory"
# A sentinel object to detect if a parameter is waiting for its dependencies or not. Use
# a class to give it a better repr.
class _MISSING_CONDITIONAL_TYPE:
pass
MISSING_CONDITIONAL = _MISSING_CONDITIONAL_TYPE()
# IDEA: Could maybe create a ConditionalField object, that inherits from `dataclasses.Field`?
@overload
def conditional_field(
default_factory: Callable[..., T],
inputs: None = None,
) -> T:
...
# The number of specified 'inputs' must match exactly the number of arguments to the function.
@overload
def conditional_field(default_factory: Callable[[Any], T], inputs: str) -> T:
...
@overload
def conditional_field(default_factory: Callable[[Any, Any], T], inputs: tuple[str, str]) -> T:
...
@overload
def conditional_field(
default_factory: Callable[[Any, Any, Any], T], inputs: tuple[str, str, str]
) -> T:
...
def conditional_field(
default_factory: Callable[..., T], inputs: str | tuple[str, ...] | None = None
) -> T:
if inputs is None:
signature = inspect.signature(default_factory)
input_names = tuple(signature.parameters)
elif isinstance(inputs, str):
input_names = (inputs,)
else:
input_names = tuple(inputs)
return field(
default=MISSING_CONDITIONAL, # type: ignore
metadata={
_inputs_key: input_names,
_default_factory_key: default_factory,
},
)
def is_conditional(field: Field) -> bool:
return _inputs_key in field.metadata and _default_factory_key in field.metadata
def _get_input_names(f: Field) -> list[str]:
assert is_conditional(f)
return list(f.metadata[_inputs_key])
def _get_conditional_default_factory(f: Field) -> Callable:
assert is_conditional(f)
return f.metadata[_default_factory_key]
def set_conditionals(obj) -> None:
"""Sets the conditional fields on `obj` by resolving the dependencies and calling the factories."""
t_fields = fields(obj)
if not any(is_conditional(f) for f in t_fields):
return
def _is_set(f: Field) -> bool:
return getattr(obj, f.name) is not MISSING_CONDITIONAL
def _get_input_fields(f: Field) -> list[Field]:
assert is_conditional(f)
input_names = _get_input_names(f)
for input_name in input_names:
if not hasattr(obj, input_name):
raise RuntimeError(
f"Field {f.name} is conditioned on the value of '{input_name}', but there is "
f"no field with that name on type {type(obj)}!"
)
return [f for f in fields(obj) if f.name in input_names]
def _get_conditional_fields_left() -> list[Field]:
return sorted(
[f for f in t_fields if is_conditional(f) and f.init and not _is_set(f)],
key=lambda f: f.name,
)
conditional_fields_left = _get_conditional_fields_left()
while conditional_fields_left:
# Find all the fields whose dependencies are all set.
leaves = [
f
for f in conditional_fields_left
if all(_is_set(input_field) for input_field in _get_input_fields(f))
]
if not leaves:
raise RuntimeError(
f"There are conditional fields left, but no leaves, so there must be a "
f"dependency cycle between fields {[f.name for f in conditional_fields_left]}!"
)
for field in leaves:
input_names = _get_input_names(field)
default_factory = _get_conditional_default_factory(field)
logger.debug(f"Instantiating field {field.name} using its dependencies {input_names}")
factory_fn_inputs = {name: getattr(obj, name) for name in input_names}
value = default_factory(**factory_fn_inputs)
setattr(obj, field.name, value)
conditional_fields_left = _get_conditional_fields_left()
@dataclass
class HasConditionalFields:
def __post_init__(self):
set_conditionals(self)
if __name__ == "__main__":
@dataclass
class Bob(HasConditionalFields):
name: str = "Bob Jones"
age: int = 32
gamer_name: str = conditional_field(
lambda name, age: f"xXx_{name}_{age}_xXx",
)
email: str = conditional_field(lambda gamer_name: f"{gamer_name}@gmail.com")
bob = Bob()
print(bob)
# Bob(name='Bob Jones', age=32, gamer_name='xXx_Bob Jones_32_xXx', email='Bob Jones@gmail.com', gamer_email='xXx_Bob Jones_32_xXx@gmail.com')
from setuptools import setup
import os
import shutil
if __name__ == "__main__":
if not os.path.exists('conditional_fields'):
os.mkdir('conditional_fields')
shutil.copyfile('conditional_fields.py', 'conditional_fields/__init__.py')
setup(
name='conditional_fields',
version='0.0.1',
description='Conditional dataclass fields.',
author='Fabrice Normandin',
author_email='fabrice.normandin@gmail.com',
packages=['conditional_fields'], # Same as name
python_requires=">=3.7",
# NB: “state.py” requires “lmdb” and “cbor2”, but one doesn't have to use it,
# they are *optional* so to speak.
# ⌥ use square bracket dependencies hence
install_requires=[], # External packages as dependencies
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment