Skip to content

Instantly share code, notes, and snippets.

@EnriqueSoria
Forked from rochacbruno/validate_dataclass.py
Last active July 18, 2024 13:25
Show Gist options
  • Save EnriqueSoria/c270ff18a4793df12a0b11ba68bae7f6 to your computer and use it in GitHub Desktop.
Validate Dataclass Python
import logging
from dataclasses import dataclass, field, fields
from typing import List, Union
logger = logging.getLogger(__name__)


class Validations:
    """Mixin that runs per-field validators after a dataclass is initialised.

    Subclass it from a ``@dataclass``. The dataclass-generated ``__init__``
    calls ``__post_init__``, which then does two things:

    1. For every dataclass field ``<name>`` (``ClassVar``/``InitVar``
       pseudo-fields excluded), it calls
       ``self.validate_<name>(value, field=<Field>)`` if such a callable exists.
       It stores the returned value back on the field. A validator can be a
       plain check that raises ``ValueError``, or a transformation that returns
       a new value. The return value always replaces the field, so a validator
       that forgets to ``return`` sets the field to ``None``.
    2. Finally, it calls ``self.validate()`` (if it is callable) for checks
       that depend on several fields.
    """

    def __post_init__(self):
        """Run the ``validate_<field_name>`` validators, then ``validate()``.

        Raises:
            TypeError: if ``self`` is not a dataclass instance.
            Any exception raised by a validator (by convention ``ValueError``).
        """
        # dataclasses.fields() returns only real instance fields. Iterating
        # __dataclass_fields__ directly would also visit ClassVar/InitVar
        # pseudo-fields and write per-instance copies of them.
        for fld in fields(self):
            validator_name = f"validate_{fld.name}"
            validator = getattr(self, validator_name, None)
            # Skip missing or non-callable attributes (same rule as `validate`).
            if not callable(validator):
                continue
            logger.debug("Calling validator: %s", validator_name)
            new_value = validator(getattr(self, fld.name), field=fld)
            setattr(self, fld.name, new_value)

        validate = getattr(self, "validate", None)
        if callable(validate):
            logger.debug("Calling validator: %s", "validate")
            validate()
@dataclass
class Product(Validations):
    """Example model showing per-field and cross-field validation."""

    name: str
    # repr=False keeps the secret out of repr()/str() of the instance, and so
    # out of any log line or traceback that formats it.
    password: str = field(repr=False)
    tags: Union[str, List[str]]

    def validate_name(self, value, **_) -> str:
        """Reject names shorter than 3 or longer than 20 characters."""
        if not 3 <= len(value) <= 20:
            raise ValueError("name must have between 3 and 20 chars.")
        return value

    def validate_tags(self, value, **_) -> List[str]:
        """Ensure tags are always ``List[str]``.

        A comma-separated string such as ``"tag1, tag2"`` is split and each
        entry is stripped. Empty entries (e.g. from ``"a,,b"`` or a trailing
        comma) are dropped. A non-string value is returned unchanged.
        """
        if isinstance(value, str):
            stripped = (part.strip() for part in value.split(","))
            value = [tag for tag in stripped if tag]
        return value

    def validate(self):
        """Cross-field check: the password must not contain the name."""
        if self.name in self.password:
            raise ValueError("Your name can't be part of your password")


if __name__ == "__main__":
    product = Product(name="product", password="123", tags="tag1, tag2")
    assert product.tags == ["tag1", "tag2"]  # transformed to List[str]
    assert "123" not in repr(product)  # password is excluded from repr

    try:
        Product(name="pr", password="123", tags="tag1, tag2, tag3")
    except ValueError as e:
        assert str(e) == "name must have between 3 and 20 chars."
    else:
        raise AssertionError("a 2-char name should have been rejected")

    try:
        Product(name="name", password="name123", tags="tag1, tag2, tag3")
    except ValueError as e:
        assert str(e) == "Your name can't be part of your password"
    else:
        raise AssertionError("a password containing the name should have been rejected")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment