Created
January 15, 2020 15:52
-
-
Save espdev/44a26cbf82f131c79b81123c5e40c147 to your computer and use it in GitHub Desktop.
pydantic_algo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name='laplace' params=LaplaceAlgoParams(gradient_sigma=1.0, laplacian_sigma=2.0)
name='gauss' params=GaussAlgoParams(gaussian_sigma=1.5)
1 validation error for AlgoSettings
name
  Cannot find algorithm 'lol'. Existing algorithms: ['laplace', 'gauss'] (type=value_error)
1 validation error for AlgoSettings
params
  <class '__main__.GaussAlgoParams'> parameters type is invalid for 'laplace' algorithm. Parameters type must be a subclass of <class '__main__.LaplaceAlgoParams'> (type=type_error)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import abc
from typing import Type, NamedTuple, Dict, Optional

from pydantic import BaseModel, validator, confloat, ValidationError
class AlgoRegistryInfo(NamedTuple):
    """Registry entry pairing an algorithm class with its parameters class.

    The annotations refer to classes defined further down the module; this
    works because ``from __future__ import annotations`` keeps them lazy.
    """

    # The algorithm implementation class (annotated as a subclass of AlgoBase).
    cls: Type[AlgoBase]
    # The parameters model class that belongs to that algorithm.
    params_cls: Type[AlgoParamsBase]
# Module-level registry mapping an algorithm name to its registry entry.
# It starts empty and is filled in by the ``register_algo`` decorator.
algo_registry: Dict[str, AlgoRegistryInfo] = {}
class AlgoParamsBase(BaseModel, abc.ABC):
    """Common base class for the parameters models of all algorithms.

    It declares no fields of its own; concrete parameter models subclass it
    and add their own validated fields.
    """
class AlgoBase(abc.ABC):
    """Abstract base class for callable algorithms that carry parameters."""

    def __init__(self, params: AlgoParamsBase) -> None:
        """Store ``params`` on the instance as ``self.params``."""
        self.params = params

    @abc.abstractmethod
    def __call__(self, data):
        """Apply the algorithm to ``data``; concrete subclasses must override."""
def register_algo(name: str, params_cls: Type[AlgoParamsBase]):
    """Return a class decorator that registers an algorithm under ``name``.

    Args:
        name: Key under which the algorithm is stored in ``algo_registry``.
        params_cls: Parameters model class associated with the algorithm.

    Returns:
        A decorator that records ``(cls, params_cls)`` in ``algo_registry``
        and returns the decorated class unchanged.
    """
    def decorator(cls: Type[AlgoBase]) -> Type[AlgoBase]:
        algo_registry[name] = AlgoRegistryInfo(cls, params_cls)
        # Return the class: a decorator's return value is bound to the
        # decorated name, so returning nothing would rebind it to None.
        return cls

    return decorator
class LaplaceAlgoParams(AlgoParamsBase):
    """Parameters model used by the 'laplace' algorithm."""

    # Constrained float in the closed range [0.2, 100.0]; defaults to 1.0.
    gradient_sigma: confloat(ge=0.2, le=100.0) = 1.0
    # Constrained float in the closed range [0.2, 50.0]; defaults to 1.0.
    laplacian_sigma: confloat(ge=0.2, le=50.0) = 1.0
@register_algo('laplace', LaplaceAlgoParams)
class Laplace(AlgoBase):
    """Algorithm registered under the name 'laplace'.

    The computation itself is a placeholder: calling an instance raises
    ``NotImplementedError``.
    """

    def __call__(self, data):
        raise NotImplementedError
class GaussAlgoParams(AlgoParamsBase):
    """Parameters model used by the 'gauss' algorithm."""

    # Constrained float in the closed range [0.0, 50.0]; defaults to 1.5.
    gaussian_sigma: confloat(ge=0.0, le=50.0) = 1.5
@register_algo('gauss', GaussAlgoParams)
class Gauss(AlgoBase):
    """Algorithm registered under the name 'gauss'.

    The computation itself is a placeholder: calling an instance raises
    ``NotImplementedError``.
    """

    def __call__(self, data):
        raise NotImplementedError
class AlgoSettings(BaseModel):
    """Settings that select a registered algorithm and carry its parameters.

    ``name`` must be a key of ``algo_registry``. ``params`` is coerced to the
    parameters class registered for that name: omitted or empty parameters
    become that class's defaults, a dict is used to build an instance, and an
    instance of any other type is rejected.
    """

    name: str
    params: Optional[AlgoParamsBase]

    @validator('name')
    def algo_exists(cls, name: str):
        """Ensure ``name`` refers to a registered algorithm.

        Raises:
            ValueError: If ``name`` is not a key of ``algo_registry``.
        """
        if name not in algo_registry:
            raise ValueError((
                f"Cannot find algorithm '{name}'. "
                f"Existing algorithms: {list(algo_registry)}"))
        return name

    # pre=True: runs on the raw input before the field's own type validation.
    # always=True: also runs when ``params`` was not supplied at all.
    @validator('params', pre=True, always=True)
    def algo_params(cls, params, values):
        """Coerce ``params`` to the parameters class of the chosen algorithm.

        Returns:
            ``params`` untouched when no valid name is available, otherwise
            an instance of the registered parameters class.

        Raises:
            TypeError: If ``params`` is neither falsy, a dict, nor an
                instance of the registered parameters class.
        """
        # ``values`` holds the previously validated fields; 'name' is absent
        # when its own validation failed, so there is nothing to look up.
        name = values.get('name')
        if not name:
            return params

        params_cls = algo_registry[name].params_cls

        if not params:
            # Default parameters for the algorithm
            return params_cls()

        if isinstance(params, params_cls):
            return params

        if isinstance(params, dict):
            # When deserialization from dict
            return params_cls(**params)

        raise TypeError((
            f"{type(params)} parameters type is invalid for '{name}' algorithm. "
            f"Parameters type must be a subclass of {params_cls}"
        ))
if __name__ == '__main__':
    # Explicit parameters: the settings must survive a dict round trip.
    s_laplace = AlgoSettings(name='laplace', params=LaplaceAlgoParams(laplacian_sigma=2.0))
    d_laplace = s_laplace.dict()
    print(s_laplace)
    assert s_laplace == AlgoSettings(**d_laplace)

    # No parameters given: the validator fills in the algorithm's defaults.
    s_gauss = AlgoSettings(name='gauss')
    d_gauss = s_gauss.dict()
    print(s_gauss)
    assert s_gauss == AlgoSettings(**d_gauss)

    # An unregistered algorithm name is reported as a validation error.
    try:
        AlgoSettings(name='lol')
    except ValidationError as err:
        print(err)

    # Parameters belonging to a different algorithm are rejected as well.
    try:
        AlgoSettings(name='laplace', params=GaussAlgoParams())
    except ValidationError as err:
        print(err)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment