Skip to content

Instantly share code, notes, and snippets.

@pirate
Last active May 15, 2024 12:40
Show Gist options
  • Save pirate/0d539b50d2789898f2c34899390fd910 to your computer and use it in GitHub Desktop.
Save pirate/0d539b50d2789898f2c34899390fd910 to your computer and use it in GitHub Desktop.
Pydantic config loader and dumper with dynamic defaults computed from previously resolved values, supporting TOML, INI, JSON, environment-variable, and legacy ArchiveBox schema formats.
import re
import os
import sys
import toml
import json
import orjson
import platform
import inspect
import tomllib
import ini2toml
from typing import Callable, Any, Optional, Pattern, Type, Tuple, Dict, List
from pathlib import Path
from pydantic import BaseModel, Field, FieldValidationInfo, AliasChoices, model_validator, FilePath, DirectoryPath, computed_field, TypeAdapter
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource
from pydantic_settings.sources import InitSettingsSource, ConfigFileSourceMixin, TomlConfigSettingsSource
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import PydanticOmit, core_schema, to_jsonable_python, ValidationError
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
import ini_to_toml # ini_to_toml.py source: https://gist.github.com/pirate/7b4a07e77dc24ae829546ccf4c72d115
class JSONSchemaWithLambdas(GenerateJsonSchema):
    """JSON-schema generator that can render callable (lambda) field defaults.

    Callable defaults are rendered as text: ``{{lambda <args>: <body>}}`` for
    lambdas written as ``default=lambda ...``, or ``{{<qualname>}}`` when the
    source is unavailable or the callable is not such a lambda. All other
    defaults are encoded with ``to_jsonable_python`` using the active config's
    timedelta/bytes serialization modes.
    """

    def encode_default(self, default: Any) -> Any:
        """Encode a field default for inclusion in the JSON schema."""
        if isinstance(default, Callable):
            return self._encode_callable(default)
        config = self._config
        return to_jsonable_python(
            default,
            timedelta_mode=config.ser_json_timedelta,
            bytes_mode=config.ser_json_bytes,
            serialize_unknown=True,
        )

    @staticmethod
    def _encode_callable(func: Callable) -> str:
        """Render a callable default as a ``{{...}}`` placeholder string."""
        fallback = '{{' + str(getattr(func, '__qualname__', repr(func))) + '}}'
        try:
            source = inspect.getsource(func)
        except (OSError, TypeError):
            # Source not retrievable (defined in a REPL/exec, builtin, C function...).
            return fallback
        if '=lambda ' not in source:
            # Named function or differently-formatted lambda: the slicing below
            # would produce garbage, so use the qualified name instead.
            return fallback
        # Assumes the lambda is the last argument on its source line, e.g.
        # `Field(default=lambda c: ...)`, so the final `)` belongs to Field().
        return '{{lambda ' + source.split('=lambda ')[-1].strip()[:-1] + '}}'
def load_config_file(ini_path: Path) -> Dict[str, str]:
"""load the ini-formatted config file from OUTPUT_DIR/Archivebox.conf"""
if ini_path.exists():
config_file = ConfigParser()
config_file.optionxform = str
config_file.read(ini_path)
# flatten into one namespace
config_file_vars = {
key.upper(): val
for section, options in config_file.items()
for key, val in options.items()
}
print('[i] Loaded config file', os.path.abspath(ini_path))
print(config_file_vars)
return config_file_vars
return {}
class IniConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that loads variables from a collection's ArchiveBox.conf INI file.

    The file is parsed by ``load_config_file`` (all sections flattened, keys
    upper-cased), and the resulting dict is passed to ``InitSettingsSource`` as
    if it were init kwargs.
    """

    def __init__(self, settings_cls: type[BaseSettings], ini_file: Optional[Path] = None):
        # When no path is passed, fall back to an 'ini_file' entry in the
        # settings class's model_config (may also be absent -> None).
        self.ini_file_path = ini_file or settings_cls.model_config.get('ini_file')
        # ConfigFileSourceMixin._read_files calls _read_file for each path;
        # presumably it tolerates None / missing files -- verify against the
        # installed pydantic-settings version.
        self.ini_data = self._read_files(self.ini_file_path)
        super().__init__(settings_cls, self.ini_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse a single INI file into a flat ``{KEY: value}`` dict."""
        return load_config_file(file_path)
class ModelWithDefaults(BaseSettings):
    """BaseSettings base class whose field defaults may be computed from other values.

    A field may declare ``Field(default=lambda c: ...)``. After validation,
    ``fill_defaults`` calls each such function with a dict of the values set so
    far, validates the result against the field's annotation and stores it.
    Fields are visited in ``model_fields`` order, and each resolved value is
    set with ``setattr``, so a default function can read fields resolved
    before it.

    Base classes whose names are ALL UPPERCASE (e.g. ``GENERAL_CONFIG``) are
    treated as config *sections* by ``as_json``, ``as_toml`` and
    ``as_legacy_schema``.
    """

    model_config = SettingsConfigDict(validate_default=False, case_sensitive=False, extra='ignore')

    @model_validator(mode='after')
    def fill_defaults(self):
        """Populate any unset values using the function provided as their default.

        Returns:
            ``self``, as required for ``mode='after'`` model validators.

        Raises:
            pydantic_core.ValidationError: if a default function returns a
                value that does not match the field's annotation.
        """
        for key, field in type(self).model_fields.items():
            value = getattr(self, key)
            if not isinstance(value, Callable):
                continue
            # Call the default function with the values set so far.
            fallback_value = field.default(self.model_dump(exclude_unset=True))
            # Keep the validated (possibly coerced) value, not the raw return value.
            validated_value = TypeAdapter(field.annotation).validate_python(fallback_value)
            # setattr also marks the field as set, so later default functions can read it.
            setattr(self, key, validated_value)
        return self

    @classmethod
    def _config_sections(cls):
        """Return the ALL_CAPS base classes, nearest first, stopping at the first other class."""
        sections = []
        for section in cls.__mro__[1:]:
            if not section.__name__.isupper():
                break
            sections.append(section)
        return sections

    @staticmethod
    def _section_field_names(section, model_fields=True, computed_fields=True):
        """Return the names of *section*'s regular and/or computed fields."""
        names = set()
        if model_fields:
            names.update(section.model_fields)
        if computed_fields:
            names.update(section.model_computed_fields)
        return names

    def as_json(self, model_fields=True, computed_fields=True):
        """Return ``{SECTION_NAME: {FIELD: json_value}}`` for every ALL_CAPS base section."""
        output_dict = {}
        for section in self._config_sections():
            include = self._section_field_names(section, model_fields, computed_fields)
            section_values = output_dict.setdefault(section.__name__, {})
            section_values.update(self.model_dump(mode='json', include=include))
        return output_dict

    def as_toml(self, model_fields=True, computed_fields=True):
        """Return the config as TOML text, one ``[SECTION_NAME]`` table per section."""
        sections = self.as_json(model_fields=model_fields, computed_fields=computed_fields)
        return ''.join(
            f'[{section_name}]\n' + toml.dumps(section_values) + '\n'
            for section_name, section_values in sections.items()
        )

    @classmethod
    def _defining_section(cls, name):
        """Return the nearest base class that declares field *name* (or this class itself)."""
        for base in cls.__mro__[1:]:
            if name in getattr(base, 'model_fields', {}) or name in getattr(base, 'model_computed_fields', {}):
                return base
        # Called on a section class directly: no base declares it, so this class does.
        return cls

    @staticmethod
    def _legacy_type_name(field) -> str:
        """Lower-cased type name for the legacy schema's 'type' entry."""
        if hasattr(field, 'annotation'):
            annotation = field.annotation
            return str(getattr(annotation, '__name__', annotation)).lower()
        # Computed fields expose return_type instead of annotation.
        return str(field.return_type).lower()

    @staticmethod
    def _legacy_aliases(field) -> list:
        """Return the alias names stored under ``json_schema_extra['aliases']``, if any."""
        extra = getattr(field, 'json_schema_extra', None)
        if not isinstance(extra, dict):
            # None, or a callable json_schema_extra: no alias information.
            return []
        return list(getattr(extra.get('aliases'), 'choices', None) or [])

    def as_legacy_schema(self, model_fields=True, computed_fields=True):
        """Convert a newer Pydantic Settings BaseModel into the old-style archivebox.config CONFIG_SCHEMA format.

        Returns:
            ``{SECTION_NAME: {FIELD_NAME: {'value', 'type', 'default', 'aliases'}}}``
            where ``default`` is the declared default (possibly a lambda) or,
            for computed fields, the property's getter function.
        """
        fields = {}
        if model_fields:
            fields.update(type(self).model_fields)
        if computed_fields:
            fields.update(type(self).model_computed_fields)

        schemas = {}
        for name, field in fields.items():
            section = self._defining_section(name)
            schemas.setdefault(section.__name__, {})[name.upper()] = {
                'value': getattr(self, name),
                'type': self._legacy_type_name(field),
                'default': field.default if hasattr(field, 'default') else field.wrapped_property.fget,
                'aliases': self._legacy_aliases(field),
            }
        return schemas

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Choose where settings are loaded from, highest priority first.

        Sources are explicit init kwargs, then environment variables, then the
        collection config file. ``.env`` files and secret files are not used.

        The config file is ``model_config['ini_file']`` when set, otherwise
        ``ArchiveBox.conf`` in the current working directory. If that file is
        not valid TOML (it is normally INI), it is converted with
        ``ini_to_toml.convert``, written to ``.ArchiveBox.toml`` next to it, and
        loaded from there.
        """
        config_file = Path(settings_cls.model_config.get('ini_file') or Path('.').resolve() / 'ArchiveBox.conf')
        try:
            # TomlConfigSettingsSource parses the file in its constructor.
            file_settings = TomlConfigSettingsSource(settings_cls, toml_file=config_file)
        except tomllib.TOMLDecodeError:
            converted_file = config_file.parent / '.ArchiveBox.toml'
            toml_str = ini_to_toml.convert(config_file.read_text(encoding='utf-8'))
            converted_file.write_text(toml_str, encoding='utf-8')
            file_settings = TomlConfigSettingsSource(settings_cls, toml_file=converted_file)
        return (init_settings, env_settings, file_settings)
class SHELL_CONFIG(ModelWithDefaults):
    """Settings describing the terminal/process environment.

    Lambda defaults are resolved by ModelWithDefaults.fill_defaults, which
    calls them with a dict of the values set so far (``c``).
    """
    # NOTE: keep each lambda default on one line with no trailing comment:
    # JSONSchemaWithLambdas renders a lambda by slicing its source line.
    # True when sys.stdout is a terminal.
    IS_TTY: bool = Field(default=lambda c: sys.stdout.isatty())
    # Defaults to the value of IS_TTY (assumes IS_TTY is already resolved -- it is declared first).
    USE_COLOR: bool = Field(default=lambda c: c['IS_TTY'])
    # Defaults to IS_TTY, but off whenever platform.system() is 'Darwin' (macOS).
    SHOW_PROGRESS: bool = Field(default=lambda c: c['IS_TTY'] and (platform.system() != 'Darwin'))
    IN_DOCKER: bool = Field(default=False)
    IN_QEMU: bool = Field(default=False)
    # Default to the current process's uid/gid (os.getuid/os.getgid exist only on Unix).
    PUID: int = Field(default=lambda c: os.getuid())
    PGID: int = Field(default=lambda c: os.getgid())
class GENERAL_CONFIG(ModelWithDefaults):
    """General collection-wide settings: timeouts, permissions, URL filters and admin credentials."""

    # OUTPUT_DIR is provided by the computed field at the bottom of this class.
    # NOTE(review): FilePath requires the path to exist and be a file, so
    # resolving this default fails if OUTPUT_DIR has no ArchiveBox.conf.
    CONFIG_FILE: FilePath = Field(default=lambda c: c['OUTPUT_DIR'] / 'ArchiveBox.conf')
    ONLY_NEW: bool = Field(default=True)
    TIMEOUT: int = Field(default=60)
    MEDIA_TIMEOUT: int = Field(default=3600)
    ENFORCE_ATOMIC_WRITES: bool = Field(default=True)
    OUTPUT_PERMISSIONS: str = Field(default='644')
    RESTRICT_FILE_NAMES: str = Field(default='windows')
    # URL_BLACKLIST / URL_WHITELIST are the legacy names of the two settings below.
    # validation_alias makes pydantic actually accept them as input (`aliases=` is
    # not a Field parameter); json_schema_extra['aliases'] is what
    # as_legacy_schema() reports.
    URL_DENYLIST: Pattern = Field(
        default=re.compile(r'\.(css|js|otf|ttf|woff|woff2|gstatic\.com|googleapis\.com/css)(\?.*)?$'),
        validation_alias=AliasChoices('URL_DENYLIST', 'URL_BLACKLIST'),
        json_schema_extra={'aliases': AliasChoices('URL_BLACKLIST')},
    )
    URL_ALLOWLIST: Pattern = Field(
        default=re.compile(r''),
        validation_alias=AliasChoices('URL_ALLOWLIST', 'URL_WHITELIST'),
        json_schema_extra={'aliases': AliasChoices('URL_WHITELIST')},
    )
    ADMIN_USERNAME: Optional[str] = Field(default=None, min_length=1, max_length=63, pattern=r'^[\S]+$')
    ADMIN_PASSWORD: Optional[str] = Field(default=None, min_length=1, max_length=63)
    TAG_SEPARATOR_PATTERN: Pattern = Field(default=re.compile(r'[,]'))

    @computed_field
    @property
    def OUTPUT_DIR(self) -> DirectoryPath:
        """The collection directory: the current working directory, resolved to an absolute path."""
        return Path('.').resolve()
# Every config section, in the order they are combined into USER_CONFIG.
CONFIG_SECTIONS = (
    GENERAL_CONFIG,
    SHELL_CONFIG,
)


class USER_CONFIG(*CONFIG_SECTIONS):
    # Merged view of all sections; it declares no fields of its own.
    # ModelWithDefaults.as_json/as_toml walk this class's MRO (skipping
    # USER_CONFIG itself) to group values by section.
    pass
if __name__ == '__main__':
    # Demo: build a merged config from explicit kwargs, environment variables
    # and the collection config file, then dump it in every supported format.
    print()
    os.environ['URL_ALLOWLIST'] = r'worked!!!!!\\.com'
    demo_overrides = {'SHOW_PROGRESS': False, 'ADMIN_USERNAME': 'kip', 'PGID': 911}
    config = USER_CONFIG(**demo_overrides)

    # (banner, renderer) pairs; each renderer runs only after its banner is printed.
    demo_dumps = (
        ('==========archivebox.config.CONFIG_SCHEMA======================',
         lambda: json.dumps(config.as_legacy_schema(), indent=4, default=str)),
        ('==========JSON=================================================',
         lambda: json.dumps(config.as_json(), indent=4)),
        ('==========TOML=================================================',
         config.as_toml),
    )
    for banner, render in demo_dumps:
        print(banner)
        print(render())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment