Skip to content

Instantly share code, notes, and snippets.

Created March 13, 2023 08:12
Show Gist options
  • Save drew2a/3168571b2051cd49cd49032498c64c00 to your computer and use it in GitHub Desktop.
Save drew2a/3168571b2051cd49cd49032498c64c00 to your computer and use it in GitHub Desktop.
"""
Pydantic configuration BaseModels for IPv8.

If you want to see the available configuration options, print the schema as follows:

    from json import dumps
    print(dumps(format_schema_recursive(IPv8Configuration), indent=4))

You can use the ``IPv8Configuration`` by simply passing it to an ``IPv8`` constructor:

    ipv8_instance = IPv8(IPv8Configuration())
"""
import base64
import sys
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Type, Union
from pydantic import BaseModel, Field, validator
from pydantic.validators import dict_validator
from .configuration import DISPERSY_BOOTSTRAPPER, default
from .keyvault.crypto import default_eccrypto
class BaseConfigurationClass(BaseModel):
    """
    Augment the ``BaseModel`` class with 1) dict cast unwrapping and 2) inherited validators.
    """

    def __iter__(self, *args, **kwargs):
        """
        Make ``dict(model)`` equal to ``model.dict()``.

        Normally this function only gives the child model references instead of unwrapping them.
        """
        return iter(self.dict().items())

    @classmethod
    def get_validators(cls):
        """
        Yield the validator used when this model is the declared type of another model's field.
        """
        # NOTE(review): pydantic v1 looks up ``__get_validators__``; the inherited BaseModel version already
        # yields ``cls.validate``, so the override below takes effect either way. Confirm the intended name.
        yield cls.validate

    @classmethod
    def validate(cls, value):
        """
        Validate ``value`` as an instance of ``cls``.

        Instances of ``cls`` (including instances of its subclasses) are returned unchanged, so a field declared
        with a base class keeps the concrete subclass that was assigned to it. Any other value is converted with
        ``dict_validator`` and used to construct a new ``cls``.

        :param value: a ``cls`` instance or anything ``dict_validator`` accepts
        :returns: an instance of ``cls`` (or of one of its subclasses)
        """
        if isinstance(value, cls):
            return value
        return cls(**dict_validator(value))

    def dict(self, *args, **kwargs):
        """
        Export the model as a dict, using field aliases (e.g. ``class`` instead of ``cls``) unless the caller
        explicitly passes ``by_alias``.
        """
        kwargs.setdefault('by_alias', True)
        return super().dict(*args, **kwargs)

    class Config:
        # Allow construction with either the field name (e.g. ``cls=...``) or its alias (e.g. ``class``).
        allow_population_by_field_name = True
class BootstrapperInit(BaseConfigurationClass):
    """
    Arguments for the ``init`` section of a ``Bootstrapper``.
    """

    # Defaults to 30.0 (presumably seconds; confirm against the bootstrapper implementation).
    bootstrap_timeout: float = 30.0
class DispersyBootstrapperInit(BootstrapperInit):
    """
    ``init`` arguments for the Dispersy bootstrapper.

    Address lists left unset (``None``) are filled with copies of ``DISPERSY_BOOTSTRAPPER['init']`` entries.
    """

    ip_addresses: Optional[List[tuple]] = None
    dns_addresses: Optional[List[tuple]] = None
    bootstrap_timeout: float = 30.0

    @validator('ip_addresses', always=True)
    def _default_ip_addresses(cls, v):
        """
        Replace a missing ``ip_addresses`` value with a copy of the default list.
        """
        # Copy, so mutating this model's list cannot alter the module-wide DISPERSY_BOOTSTRAPPER default.
        return list(DISPERSY_BOOTSTRAPPER['init']['ip_addresses']) if v is None else v

    @validator('dns_addresses', always=True)
    def _default_dns_addresses(cls, v):
        """
        Replace a missing ``dns_addresses`` value with a copy of the default list.
        """
        # Copy, so mutating this model's list cannot alter the module-wide DISPERSY_BOOTSTRAPPER default.
        return list(DISPERSY_BOOTSTRAPPER['init']['dns_addresses']) if v is None else v
class Bootstrapper(BaseConfigurationClass):
    """
    Specification of a bootstrapper: the name of its class and its ``init`` arguments.
    """

    # Exported and accepted under the alias ``class`` (``class`` is a reserved word in Python).
    cls: str = Field(alias='class')
    # Required. Declared as the base ``BootstrapperInit``; instances of its subclasses are kept as-is because
    # ``BaseConfigurationClass.validate`` returns existing instances unchanged.
    init: BootstrapperInit = Field(...)
class DispersyBootstrapper(Bootstrapper):
    """
    A ``Bootstrapper`` whose ``class`` is always ``"DispersyBootstrapper"``.
    """

    def __init__(self, **kwargs):
        # The class name is fixed: a ``class`` argument supplied by the caller is overwritten.
        kwargs["class"] = "DispersyBootstrapper"
        super().__init__(**kwargs)
class UDPBroadcastBootstrapper(Bootstrapper):
    """
    A ``Bootstrapper`` whose ``class`` is always ``"UDPBroadcastBootstrapper"``.
    """

    def __init__(self, **kwargs):
        # The class name is fixed: a ``class`` argument supplied by the caller is overwritten.
        kwargs["class"] = "UDPBroadcastBootstrapper"
        super().__init__(**kwargs)
class Interface(BaseConfigurationClass):
    """
    A network interface specification.

    Every default is taken from the first entry of ``default["interfaces"]``.
    """

    interface: str = default["interfaces"][0]["interface"]
    ip: str = default["interfaces"][0]["ip"]
    port: int = default["interfaces"][0]["port"]
class Logger(BaseConfigurationClass):
    """
    Logger settings.
    """

    level: str = default["logger"]["level"]

    @validator("level", always=True)
    def validate_level(cls, v):
        """
        Check that ``level`` is one of the standard ``logging`` level names.

        :raises ValueError: if ``v`` is not an allowed level name (surfaced by pydantic as a validation error)
        """
        allowed = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]
        # Raise instead of assert: assertions are stripped when Python runs with -O.
        if v not in allowed:
            raise ValueError(f"Illegal log level {v} specified! Should be one of {allowed}!")
        return v
class Key(BaseConfigurationClass):
    """
    Base key specification: only the alias under which the key is known.
    """

    # Same default alias that ``Overlay.key`` uses.
    alias: str = default["keys"][0]["alias"]
class PreloadedKey(Key):
    """
    A key whose serialized form is supplied directly.
    """

    # Required. Presumably the same base64 text format that ``generate_ephemeral_key`` produces (confirm).
    bin: str
def generate_ephemeral_key():
    """
    Generate a fresh curve25519 key.

    :returns: the key's binary serialization, base64-encoded and decoded to ``str``
    """
    key_bytes = default_eccrypto.generate_key("curve25519").key_to_bin()
    encoded = base64.b64encode(key_bytes)
    return encoded.decode()
class EphemeralKey(Key):
    """
    A key that is freshly generated for every instance (via ``generate_ephemeral_key``).
    """

    # NOTE(review): ``frozen`` is not a pydantic v1 ``Field`` option (v1 uses ``allow_mutation`` together with
    # ``validate_assignment``); confirm immutability is actually enforced with the pydantic version in use.
    bin: str = Field(default_factory=generate_ephemeral_key, frozen=True)
class FileKey(Key):
    """
    A key backed by a file.

    Defaults are taken from the first entry of ``default["keys"]``.
    """

    generation: str = default["keys"][0]["generation"]
    file: str = default["keys"][0]["file"]
class Walker(BaseConfigurationClass):
    """
    Specification of a walker: its strategy name, peer count and ``init`` arguments.
    """

    strategy: str = Field(...)
    # Must be >= 0, or -1 for "infinite".
    peers: int = 20
    init: dict = Field(default_factory=dict)

    @validator("peers", always=True)
    def validate_peer_count(cls, v):
        """
        Check that ``peers`` is non-negative or exactly ``-1`` (infinite).

        :raises ValueError: for any other negative value
        """
        # Raise instead of assert: assertions are stripped when Python runs with -O.
        if v < 0 and v != -1:
            raise ValueError(f"A walker's peer count must be >= 0 or set to -1 (infinite), got: {v}!")
        return v
class RandomWalk(Walker):
    """
    A ``Walker`` with strategy ``"RandomWalk"``; ``init`` defaults to ``{"timeout": 3.0}``.
    """

    def __init__(self, **kwargs):
        # The strategy name is fixed; other arguments only get defaults when the caller omits them.
        kwargs["strategy"] = "RandomWalk"
        kwargs.setdefault("init", {"timeout": 3.0})
        super().__init__(**kwargs)
class EdgeWalk(Walker):
    """
    A ``Walker`` with strategy ``"EdgeWalk"``; all other arguments use the ``Walker`` defaults.
    """

    def __init__(self, **kwargs):
        # The strategy name is fixed: a caller-supplied ``strategy`` is overwritten.
        kwargs["strategy"] = "EdgeWalk"
        super().__init__(**kwargs)
class RandomChurn(Walker):
    """
    A ``Walker`` with strategy ``"RandomChurn"``.

    ``peers`` defaults to -1 (infinite) and ``init`` to the churn settings below, unless given by the caller.
    """

    def __init__(self, **kwargs):
        # The strategy name is fixed; other arguments only get defaults when the caller omits them.
        kwargs["strategy"] = "RandomChurn"
        kwargs.setdefault("peers", -1)
        kwargs.setdefault("init", {'sample_size': 8, 'ping_interval': 10.0, 'inactive_time': 27.5,
                                   'drop_time': 57.5})
        super().__init__(**kwargs)
class PeriodicSimilarity(Walker):
    """
    A ``Walker`` with strategy ``"PeriodicSimilarity"``; ``peers`` defaults to -1 (infinite).
    """

    def __init__(self, **kwargs):
        # The strategy name is fixed; ``peers`` only gets a default when the caller omits it.
        kwargs["strategy"] = "PeriodicSimilarity"
        kwargs.setdefault("peers", -1)
        super().__init__(**kwargs)
class PingChurn(Walker):
    """
    A ``Walker`` with strategy ``"PingChurn"``; ``peers`` defaults to -1 (infinite).
    """

    def __init__(self, **kwargs):
        # The strategy name is fixed; ``peers`` only gets a default when the caller omits it.
        kwargs["strategy"] = "PingChurn"
        kwargs.setdefault("peers", -1)
        super().__init__(**kwargs)
class Overlay(BaseConfigurationClass):
    """
    Specification of an overlay: its class name, key alias, walkers, bootstrappers and start-up arguments.
    """

    # Exported and accepted under the alias ``class``.
    cls: str = Field(alias='class')
    key: str = default["keys"][0]["alias"]
    # Default: a single RandomWalk, plus the Dispersy bootstrapper with its default address lists.
    walkers: List[Walker] = [RandomWalk()]
    bootstrappers: List[Bootstrapper] = [DispersyBootstrapper(init=DispersyBootstrapperInit())]
    initialize: dict = Field(default_factory=dict)
    # Presumably (method name, *args) tuples run after the overlay starts; confirm against the IPv8 loader.
    on_start: List[tuple] = []
class IPv8Configuration(BaseConfigurationClass):
    """
    Main pydantic IPv8 configuration model, can be fed directly into the ``IPv8`` class constructor.
    """

    interfaces: List[Interface] = [Interface()]
    # Exported and accepted under the alias ``keys``.
    key_aliases: List[Key] = Field([FileKey()], alias="keys")
    logger: Logger = Logger()
    working_directory: str = default["working_directory"]
    walker_interval: float = default["walker_interval"]
    overlays: List[Overlay] = [
        Overlay(cls="DiscoveryCommunity", walkers=[RandomWalk(), RandomChurn(), PeriodicSimilarity()]),
        # NOTE(review): confirm whether this overlay should also take other settings (e.g. ``on_start``)
        # from ``default["overlays"][1]``.
        Overlay(cls="HiddenTunnelCommunity", initialize=default["overlays"][1]["initialize"]),
        Overlay(cls="DHTDiscoveryCommunity", walkers=[RandomWalk(), PingChurn()]),
    ]

    @validator("interfaces", always=True)
    def validate_interfaces(cls, v):
        """
        Check that no interface name is used more than once.

        :raises ValueError: naming each duplicated interface name
        """
        seen = set()
        duplicates = []
        for spec in v:
            if spec.interface in seen and spec.interface not in duplicates:
                duplicates.append(spec.interface)
            seen.add(spec.interface)
        # Raise instead of assert: assertions are stripped when Python runs with -O.
        if duplicates:
            raise ValueError(f"Duplicate interface names specified: {', '.join(duplicates)}!")
        return v
def format_schema_recursive(*base_models: Type[BaseModel]) -> dict:
    """
    Create a schema (``dict``) describing a BaseModel class.

    For example, the following pretty-prints the schema of ``IPv8Configuration``:

        from json import dumps
        print(dumps(format_schema_recursive(IPv8Configuration), indent=4))

    Note that the schema shows the expected interface, which is not necessarily equal to its implementations!
    For example, only ``A`` and ``Main`` are shown when calling ``format_schema_recursive(Main)`` using this code:

        class A(BaseModel):
            property: int

        class B(A):
            property2: int

        class Main(BaseModel):
            a_implementation: A

    In the above example, you could force generation for ``B`` by calling ``format_schema_recursive(Main, B)``.

    :param base_models: the models to generate the schema for
    :returns: the mapping of types for the given collection of base models
    """
    definition_prefix = "#/definitions/"

    def format_schema_single(model_schema: dict, refs: set) -> Union[dict, list, str]:
        """
        Try to convert the types given by pydantic into actual Python objects. Turns objects into dicts, arrays into
        lists, custom definitions into their definition references, and keeps whatever other primitive is used.
        Every encountered definition reference is added to ``refs``.

        Note: this function is recursive!

        :param model_schema: the result of ``<BaseModel>.schema()`` (or one of its property schemas)
        :param refs: a reference to a set in which to add encountered references
        :returns: either a dict/list describing the schema or a str describing a single primitive type
        """
        properties: Union[str, List[Any], Dict[Any, Any]]
        if "type" in model_schema:
            is_object = model_schema["type"] in ["object", "array"]
            if "properties" in model_schema:
                # It's a dict!
                properties = {k: format_schema_single(v, refs) for k, v in model_schema["properties"].items()}
            elif "items" in model_schema:
                # It's a list!
                if '$ref' in model_schema['items']:
                    properties = [model_schema['items']['$ref']]
                    refs.add(model_schema['items']['$ref'])
                else:
                    # It's an empty list!
                    properties = []
            elif is_object:
                # It's an empty dict!
                properties = {}
            else:
                # It's a "none of the above or below" (number, string, etc.)!
                return model_schema["type"].lower()
        else:
            # It's a custom definition!
            if '$ref' in model_schema:
                properties = model_schema["$ref"]
                refs.add(properties)
            else:
                # Tolerate schemas without "allOf" (e.g. untyped fields) instead of raising KeyError.
                all_of = model_schema.get("allOf", [])
                if len(all_of) == 1:
                    properties = all_of[0]["$ref"]
                    refs.add(properties)
                else:
                    properties = [m["$ref"] for m in all_of]
                    refs.update(properties)
        return properties

    def locate_model(clsname: str):
        """
        Find the model or Enum class named ``clsname``: first in this module, then in any loaded module.

        :returns: the class, or ``None`` if it could not be found
        """
        if clsname in globals():
            # Easy: the class has already been imported.
            return globals()[clsname]
        # Hard: we need to find the class definition in the loaded modules.
        # Iterate over a snapshot, because imports triggered meanwhile could otherwise resize ``sys.modules``.
        for module in list(sys.modules.values()):
            candidate = getattr(module, clsname, None)
            # Only accept classes this function can format; skip unrelated objects that share the name.
            if isinstance(candidate, type) and issubclass(candidate, (BaseModel, Enum)):
                return candidate
        # Extra hard: it is still possible to make exotic definitions that escape this search.
        # We'll just let the caller manually work around this (e.g., by adding the class to globals).
        return None

    finished_refs: Set[str] = set()
    known_refs = {f"{definition_prefix}{base_model.__name__}" for base_model in base_models}
    out = {}
    while known_refs - finished_refs:
        # 1. Fetch the Python class "model" belonging to the (str) definition name.
        #    ``min`` makes the processing (and output) order deterministic.
        next_ref = min(known_refs - finished_refs)
        assert next_ref.startswith(definition_prefix)
        model = locate_model(next_ref[len(definition_prefix):])
        assert model is not None, f"Failed to locate class belonging to {next_ref}!"

        # 2. Format the model class into a dict (this may add newly discovered refs to ``known_refs``).
        #    Enum values are given as a list so that the result stays JSON-serializable.
        out[model.__name__] = ([i.value for i in model] if issubclass(model, Enum)
                               else format_schema_single(model.schema(), known_refs))
        finished_refs.add(next_ref)
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment