Last active
May 15, 2024 12:40
-
-
Save pirate/0d539b50d2789898f2c34899390fd910 to your computer and use it in GitHub Desktop.
Pydantic config loader and dumper with dynamic defaults based on previous values and support for TOML, INI, JSON, Env, and ArchiveBox schema formats.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import os | |
import sys | |
import toml | |
import json | |
import orjson | |
import platform | |
import inspect | |
import tomllib | |
import ini2toml | |
from typing import Callable, Any, Optional, Pattern, Type, Tuple, Dict, List | |
from pathlib import Path | |
from pydantic import BaseModel, Field, FieldValidationInfo, AliasChoices, model_validator, FilePath, DirectoryPath, computed_field, TypeAdapter | |
from pydantic.fields import FieldInfo | |
from pydantic_settings import BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource | |
from pydantic_settings.sources import InitSettingsSource, ConfigFileSourceMixin, TomlConfigSettingsSource | |
from pydantic.json_schema import GenerateJsonSchema | |
from pydantic_core import PydanticOmit, core_schema, to_jsonable_python, ValidationError | |
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue | |
import ini_to_toml # ini_to_toml.py source: https://gist.github.com/pirate/7b4a07e77dc24ae829546ccf4c72d115 | |
class JSONSchemaWithLambdas(GenerateJsonSchema):
    """JSON schema generator that renders callable field defaults as text.

    Callable defaults are emitted as a ``{{lambda ...}}`` placeholder built
    from their source code instead of being passed to the JSON encoder.
    """

    def encode_default(self, default: Any) -> Any:
        """Encode lambda functions in default values properly.

        Args:
            default: the field default to encode; may be any value or a callable.

        Returns:
            A ``{{...}}`` placeholder string for callables, otherwise the
            result of ``to_jsonable_python`` for the value.
        """
        if callable(default):
            try:
                source = inspect.getsource(default)
            except (OSError, TypeError):
                # inspect.getsource cannot retrieve source for builtins, C
                # callables, or functions defined interactively; fall back to
                # the callable's name rather than aborting schema generation.
                return '{{' + getattr(default, '__qualname__', repr(default)) + '}}'
            # Keep only the text after the last '=lambda ' and drop the final
            # character (the ')' closing `Field(default=lambda ...)` when the
            # lambda is the last argument on its line).
            return '{{lambda ' + source.split('=lambda ')[-1].strip()[:-1] + '}}'

        config = self._config
        return to_jsonable_python(
            default,
            timedelta_mode=config.ser_json_timedelta,
            bytes_mode=config.ser_json_bytes,
            serialize_unknown=True,
        )
def load_config_file(ini_path: Path) -> Dict[str, str]:
    """Load an INI-formatted config file into one flat, upper-cased namespace.

    Args:
        ini_path: path of the INI file (e.g. OUTPUT_DIR/ArchiveBox.conf).

    Returns:
        A dict mapping every option name (upper-cased) to its string value,
        with all sections merged together; when the same key appears in more
        than one section the later section wins. Returns ``{}`` when
        ``ini_path`` does not exist.
    """
    # Imported here because the module-level imports never bring ConfigParser
    # into scope; without this the function raised NameError on first use.
    from configparser import ConfigParser

    if not ini_path.exists():
        return {}

    config_file = ConfigParser()
    # Disable ConfigParser's default lower-casing of option names while parsing.
    config_file.optionxform = str
    config_file.read(ini_path)

    # flatten into one namespace
    config_file_vars = {
        key.upper(): val
        for options in config_file.values()
        for key, val in options.items()
    }
    print('[i] Loaded config file', os.path.abspath(ini_path))
    # NOTE(review): this prints every loaded value, secrets included.
    print(config_file_vars)
    return config_file_vars
class IniConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """
    A source class that loads variables from the collection ArchiveBox.conf INI file
    """

    def __init__(self, settings_cls: type[BaseSettings], ini_file: Optional[Path] = None):
        """Read the INI file and hand its values to ``InitSettingsSource``.

        Args:
            settings_cls: the settings class this source feeds.
            ini_file: path of the INI file; when omitted (or falsy) the
                ``ini_file`` entry of ``settings_cls.model_config`` is used.
        """
        self.ini_file_path = ini_file or settings_cls.model_config.get('ini_file')
        # _read_files comes from ConfigFileSourceMixin; presumably it calls
        # _read_file() below for each path it is given -- TODO confirm.
        self.ini_data = self._read_files(self.ini_file_path)
        super().__init__(settings_cls, self.ini_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse one INI file into a flat dict via ``load_config_file``."""
        return load_config_file(file_path)
class ModelWithDefaults(BaseSettings):
    """Settings base class whose field defaults may be functions of other values.

    A field may declare ``Field(default=<callable>)``; after validation the
    callable is invoked with a dict of the values set so far and its result
    replaces the callable. The ``as_*`` helpers dump the final values grouped
    by "section", i.e. by the ALL-CAPS-named base classes of the instance.
    """

    model_config = SettingsConfigDict(validate_default=False, case_sensitive=False, extra='ignore')

    @model_validator(mode='after')
    def fill_defaults(self):
        """Populate any unset values using function provided as their default"""
        for key, field in self.__class__.model_fields.items():
            value = getattr(self, key)
            if callable(value):
                # if value is a function, execute it to get the actual value;
                # it receives only the values that have been set so far
                fallback_value = field.default(self.model_dump(exclude_unset=True))
                # check to make sure default factory return value matches type annotation
                TypeAdapter(field.annotation).validate_python(fallback_value)
                # set generated default value as final validated value
                setattr(self, key, fallback_value)
        return self

    def _dump_sections(self, model_fields=True, computed_fields=True):
        """Yield ``(section_name, values)`` for each leading ALL-CAPS base class.

        Walks the MRO after the instance's own class and stops at the first
        class whose name is not all upper-case. ``values`` is the JSON-decoded
        dump of this instance restricted to the names that section declares.
        """
        for section in self.__class__.__mro__[1:]:
            if not section.__name__.isupper():
                break
            include = set()
            if model_fields:
                include.update(section.model_fields)
            if computed_fields:
                include.update(section.model_computed_fields)
            # Round-trip through JSON so every value is a plain JSON type.
            yield section.__name__, json.loads(self.model_dump_json(include=include))

    def as_json(self, model_fields=True, computed_fields=True):
        """Return ``{section_name: {key: value}}`` holding JSON-compatible values.

        Args:
            model_fields: include the regular fields of each section.
            computed_fields: include the computed fields of each section.
        """
        output_dict = {}
        for name, values in self._dump_sections(model_fields, computed_fields):
            output_dict.setdefault(name, {}).update(values)
        return output_dict

    def as_toml(self, model_fields=True, computed_fields=True):
        """Return the config as TOML text with one ``[SECTION]`` table per section.

        Args:
            model_fields: include the regular fields of each section.
            computed_fields: include the computed fields of each section.
        """
        return ''.join(
            f'[{name}]\n' + toml.dumps(values) + '\n'
            for name, values in self._dump_sections(model_fields, computed_fields)
        )

    def as_legacy_schema(self, model_fields=True, computed_fields=True):
        """Convert a newer Pydantic Settings BaseModel into the old-style archivebox.config CONFIG_SCHEMA format

        Returns:
            ``{section_name: {KEY: {'value', 'type', 'default', 'aliases'}}}``
            where ``section_name`` is the first base class (in MRO order) that
            lists the field, or the instance's own class if no base does.
        """
        model_cls = self.__class__
        include = {}
        if model_fields:
            include.update(model_cls.model_fields)
        if computed_fields:
            include.update(model_cls.model_computed_fields)

        def declares(cls, name):
            # Non-pydantic bases (e.g. object) have neither attribute.
            return (
                name in (getattr(cls, 'model_fields', None) or {})
                or name in (getattr(cls, 'model_computed_fields', None) or {})
            )

        schemas = {}
        for name, field in include.items():
            # Resolved afresh for every field so a previous iteration's match
            # is never reused by accident.
            defining_class = next(
                (cls for cls in model_cls.__mro__[1:] if declares(cls, name)),
                model_cls,
            )

            # Regular fields carry `annotation`; computed fields carry `return_type`.
            if hasattr(field, 'annotation'):
                annotation = field.annotation
                # Not every annotation has __name__ (e.g. `int | None`).
                type_name = str(getattr(annotation, '__name__', annotation)).lower()
            else:
                type_name = str(field.return_type).lower()

            # Aliases are stored as an object with a `.choices` list under
            # json_schema_extra['aliases'].
            extra = getattr(field, 'json_schema_extra', None)
            alias_info = extra.get('aliases') if isinstance(extra, dict) else None

            schemas.setdefault(defining_class.__name__, {})[name.upper()] = {
                'value': getattr(self, name),
                'type': type_name,
                'default': field.default if hasattr(field, 'default') else field.wrapped_property.fget,
                'aliases': getattr(alias_info, 'choices', None) or [],
            }
        return schemas

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources: init kwargs, env vars, then the config file.

        The config file is first read as TOML; if that fails with a TOML
        decode error it is converted with ``ini_to_toml`` into a sibling
        ``.ArchiveBox.toml`` file, which is loaded instead. The dotenv and
        secrets-file sources are not used.
        """
        # NOTE(review): hard-coded developer-machine path; should come from
        # configuration before this is used anywhere else.
        ARCHIVEBOX_CONFIG_FILE = Path('/Users/squash/Local/Code/archiveboxes/ArchiveBox/data/ArchiveBox.conf')
        ARCHIVEBOX_CONFIG_FILE_TOML = ARCHIVEBOX_CONFIG_FILE.parent / '.ArchiveBox.toml'

        try:
            # Relies on the source parsing the file when it is constructed.
            file_settings = TomlConfigSettingsSource(settings_cls, toml_file=ARCHIVEBOX_CONFIG_FILE)
        except tomllib.TOMLDecodeError:
            toml_str = ini_to_toml.convert(ARCHIVEBOX_CONFIG_FILE.read_text())
            ARCHIVEBOX_CONFIG_FILE_TOML.write_text(toml_str)
            file_settings = TomlConfigSettingsSource(settings_cls, toml_file=ARCHIVEBOX_CONFIG_FILE_TOML)

        return (
            init_settings,
            env_settings,
            file_settings,
        )
# Config section describing the terminal / process environment.
#
# Callable defaults are invoked by ModelWithDefaults.fill_defaults with a dict
# `c` of the values set so far, so a default may read a field declared above it
# (e.g. USE_COLOR reads c['IS_TTY']).
#
# Each Field(...) call stays on a single line with the lambda as its last
# argument: JSONSchemaWithLambdas derives the schema text from that source line.
class SHELL_CONFIG(ModelWithDefaults):
    # whether stdout is attached to a terminal
    IS_TTY: bool = Field(default=lambda c: sys.stdout.isatty())
    # colour output follows IS_TTY unless set explicitly
    USE_COLOR: bool = Field(default=lambda c: c['IS_TTY'])
    # progress display defaults to on for TTYs, except on macOS (Darwin)
    SHOW_PROGRESS: bool = Field(default=lambda c: c['IS_TTY'] and (platform.system() != 'Darwin'))

    IN_DOCKER: bool = Field(default=False)
    IN_QEMU: bool = Field(default=False)

    # default to the uid / gid of the current process
    PUID: int = Field(default=lambda c: os.getuid())
    PGID: int = Field(default=lambda c: os.getgid())
# General config section.
#
# Legacy option names are recorded under json_schema_extra['aliases'], which is
# where ModelWithDefaults.as_legacy_schema looks them up. They are passed
# explicitly because `aliases=` is not a Field() parameter and only worked
# through pydantic's deprecated extra-keyword-argument fallback.
class GENERAL_CONFIG(ModelWithDefaults):
    # OUTPUT_DIR is provided by the computed field at the bottom of this class.
    # Keep this Field(...) call on one line: JSONSchemaWithLambdas renders the
    # lambda from its source line.
    CONFIG_FILE: FilePath = Field(default=lambda c: c['OUTPUT_DIR'] / 'ArchiveBox.conf')

    ONLY_NEW: bool = Field(default=True)
    TIMEOUT: int = Field(default=60)
    MEDIA_TIMEOUT: int = Field(default=3600)

    ENFORCE_ATOMIC_WRITES: bool = Field(default=True)
    OUTPUT_PERMISSIONS: str = Field(default='644')
    RESTRICT_FILE_NAMES: str = Field(default='windows')

    URL_DENYLIST: Pattern = Field(
        default=re.compile(r'\.(css|js|otf|ttf|woff|woff2|gstatic\.com|googleapis\.com/css)(\?.*)?$'),
        json_schema_extra={'aliases': AliasChoices('URL_BLACKLIST')},
    )
    URL_ALLOWLIST: Pattern = Field(
        default=re.compile(r''),
        json_schema_extra={'aliases': AliasChoices('URL_WHITELIST')},
    )

    ADMIN_USERNAME: Optional[str] = Field(default=None, min_length=1, max_length=63, pattern=r'^[\S]+$')
    ADMIN_PASSWORD: Optional[str] = Field(default=None, min_length=1, max_length=63)

    TAG_SEPARATOR_PATTERN: Pattern = Field(default=re.compile(r'[,]'))

    @computed_field
    @property
    def OUTPUT_DIR(self) -> DirectoryPath:
        # resolved current working directory
        return Path('.').resolve()
# Section classes in MRO order; ModelWithDefaults.as_json / as_toml walk the
# MRO of USER_CONFIG and emit one group per ALL-CAPS base class found here.
CONFIG_SECTIONS = (
    GENERAL_CONFIG,
    SHELL_CONFIG,
)


# The combined settings model: inherits every field of every section.
class USER_CONFIG(*CONFIG_SECTIONS):
    ...
if __name__ == '__main__':
    # Demo: build a USER_CONFIG from an environment variable plus explicit
    # keyword overrides, then print it in each supported output format.
    print()

    os.environ['URL_ALLOWLIST'] = r'worked!!!!!\\.com'

    overrides = {'SHOW_PROGRESS': False, 'ADMIN_USERNAME': 'kip', 'PGID': 911}
    config = USER_CONFIG(**overrides)

    print('==========archivebox.config.CONFIG_SCHEMA======================')
    legacy_schema = config.as_legacy_schema()
    # default=str stringifies values json cannot encode natively
    print(json.dumps(legacy_schema, indent=4, default=str))

    print('==========JSON=================================================')
    sections = config.as_json()
    print(json.dumps(sections, indent=4))

    print('==========TOML=================================================')
    toml_text = config.as_toml()
    print(toml_text)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment