Skip to content

Instantly share code, notes, and snippets.

@espdev
Created January 15, 2020 15:52
Show Gist options
  • Save espdev/44a26cbf82f131c79b81123c5e40c147 to your computer and use it in GitHub Desktop.
Save espdev/44a26cbf82f131c79b81123c5e40c147 to your computer and use it in GitHub Desktop.
pydantic_algo
name='laplace' params=LaplaceAlgoParams(gradient_sigma=1.0, laplacian_sigma=2.0)
name='gauss' params=GaussAlgoParams(gaussian_sigma=1.5)
1 validation error for AlgoSettings
name
Cannot find algorithm 'lol'. Existing algorithms: ['laplace', 'gauss'] (type=value_error)
1 validation error for AlgoSettings
params
<class '__main__.GaussAlgoParams'> parameters type is invalid for 'laplace' algorithm. Parameters type must be a subclass of <class '__main__.LaplaceAlgoParams'> (type=type_error)
from __future__ import annotations
import abc
from typing import Type, NamedTuple, Dict, Optional
from pydantic import BaseModel, validator, confloat, ValidationError
class AlgoRegistryInfo(NamedTuple):
    """Registry record pairing an algorithm class with its parameter model class."""

    # Algorithm implementation class stored in the registry.
    cls: Type[AlgoBase]
    # Parameter model class associated with that algorithm.
    params_cls: Type[AlgoParamsBase]
# Module-level lookup table: algorithm name -> AlgoRegistryInfo record.
# It starts empty and is filled in by the ``register_algo`` decorator.
algo_registry: Dict[str, AlgoRegistryInfo] = {}
class AlgoParamsBase(BaseModel, abc.ABC):
    """Common base class for algorithm parameter models.

    Declares no fields of its own; concrete parameter models subclass it.
    """
class AlgoBase(abc.ABC):
    """Abstract base class for algorithms.

    An instance stores its parameter object and is invoked as a callable.
    """

    def __init__(self, params: AlgoParamsBase) -> None:
        """Store ``params`` on the instance as ``self.params``."""
        self.params = params

    @abc.abstractmethod
    def __call__(self, data):
        """Run the algorithm on ``data``; concrete subclasses must override this."""
def register_algo(name: str, params_cls: Type[AlgoParamsBase]):
    """Create a class decorator that registers an algorithm under ``name``.

    Args:
        name: Key under which the algorithm is stored in ``algo_registry``.
            An existing entry with the same name is overwritten.
        params_cls: Parameter model class associated with the algorithm.

    Returns:
        A decorator that stores ``AlgoRegistryInfo(cls, params_cls)`` in
        ``algo_registry`` and returns the decorated class unchanged.
    """
    def decorator(cls: Type[AlgoBase]):
        algo_registry[name] = AlgoRegistryInfo(cls, params_cls)
        # Hand the class back: a decorator that returns nothing rebinds the
        # decorated class name to None at module level.
        return cls

    return decorator
class LaplaceAlgoParams(AlgoParamsBase):
    """Parameter model registered for the ``'laplace'`` algorithm."""

    # Float constrained to [0.2, 100.0]; defaults to 1.0.
    gradient_sigma: confloat(ge=0.2, le=100.0) = 1.0
    # Float constrained to [0.2, 50.0]; defaults to 1.0.
    laplacian_sigma: confloat(ge=0.2, le=50.0) = 1.0
@register_algo('laplace', LaplaceAlgoParams)
class Laplace(AlgoBase):
    """Algorithm registered as ``'laplace'`` with ``LaplaceAlgoParams`` parameters."""

    def __call__(self, data):
        """Placeholder: no computation is implemented, so this always raises."""
        raise NotImplementedError
class GaussAlgoParams(AlgoParamsBase):
    """Parameter model registered for the ``'gauss'`` algorithm."""

    # Float constrained to [0.0, 50.0]; defaults to 1.5.
    gaussian_sigma: confloat(ge=0.0, le=50.0) = 1.5
@register_algo('gauss', GaussAlgoParams)
class Gauss(AlgoBase):
    """Algorithm registered as ``'gauss'`` with ``GaussAlgoParams`` parameters."""

    def __call__(self, data):
        """Placeholder: no computation is implemented, so this always raises."""
        raise NotImplementedError
class AlgoSettings(BaseModel):
    """Settings that select a registered algorithm and carry its parameters.

    ``name`` must be a key of ``algo_registry``. ``params`` is resolved against
    the parameter model registered for that name: it may be omitted (defaults
    are used), given as an instance of that model, or given as a dict of
    field values.
    """

    # ``name`` is declared before ``params`` so that the ``params`` validator
    # can read the already-validated name from ``values``.
    name: str
    params: Optional[AlgoParamsBase]

    @validator('name')
    def algo_exists(cls, name: str):
        """Return ``name`` unchanged if it is registered.

        Raises:
            ValueError: If ``name`` is not a key of ``algo_registry``.
        """
        if name not in algo_registry:
            raise ValueError(
                f"Cannot find algorithm '{name}'. "
                f"Existing algorithms: {list(algo_registry)}"
            )
        return name

    @validator('params', pre=True, always=True)
    def algo_params(cls, params, values):
        """Coerce ``params`` to the parameter model registered for ``name``.

        Args:
            params: Raw input value; ``None`` when the field was omitted.
            values: Previously validated fields; ``name`` is read from here.

        Returns:
            An instance of the registered parameter model, or ``params``
            unchanged when no validated ``name`` is available.

        Raises:
            TypeError: If ``params`` is neither ``None``, an instance of the
                registered parameter model, nor a dict.
        """
        name = values.get('name')
        if not name:
            # No usable name in ``values`` (for example, it failed its own
            # validation), so there is no registry entry to resolve against.
            return params

        params_cls = algo_registry[name].params_cls

        # Only a missing value falls back to defaults. Other falsy inputs such
        # as 0, '' or [] are not valid parameters and must reach the type
        # error below. An empty dict still yields defaults via the dict branch.
        if params is None:
            return params_cls()

        if isinstance(params, params_cls):
            return params

        if isinstance(params, dict):
            # Deserialization from a plain dict, e.g. the output of ``.dict()``.
            return params_cls(**params)

        raise TypeError(
            f"{type(params)} parameters type is invalid for '{name}' algorithm. "
            f"Parameters type must be a subclass of {params_cls}"
        )
if __name__ == '__main__':
    # Round trip: settings built from a parameter-model instance must equal
    # the settings rebuilt from their own dict dump.
    laplace_settings = AlgoSettings(
        name='laplace',
        params=LaplaceAlgoParams(laplacian_sigma=2.0),
    )
    print(laplace_settings)
    assert laplace_settings == AlgoSettings(**laplace_settings.dict())

    # Same round trip when ``params`` is omitted and defaults are used.
    gauss_settings = AlgoSettings(name='gauss')
    print(gauss_settings)
    assert gauss_settings == AlgoSettings(**gauss_settings.dict())

    # Both inputs below are expected to fail validation: an unknown algorithm
    # name, then a parameter model that belongs to a different algorithm.
    invalid_inputs = (
        {'name': 'lol'},
        {'name': 'laplace', 'params': GaussAlgoParams()},
    )
    for kwargs in invalid_inputs:
        try:
            AlgoSettings(**kwargs)
        except ValidationError as err:
            print(err)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment