Skip to content

Instantly share code, notes, and snippets.

@RahulDas-dev
Last active November 5, 2023 07:10
Show Gist options
  • Save RahulDas-dev/9ca784d51e3bf4e883ad8e0944db95b3 to your computer and use it in GitHub Desktop.
Save RahulDas-dev/9ca784d51e3bf4e883ad8e0944db95b3 to your computer and use it in GitHub Desktop.
Save or load a complex nested Pydantic model into or from SQLAlchemy fields

Save or load a complex nested Pydantic model into or from SQLAlchemy fields

This gist demonstrates a custom SQLAlchemy column type that allows you to save and load Pydantic models directly in your SQLAlchemy database. It simplifies the process of storing and retrieving complex pydantic data models in your database without manual conversion to JSON or dictionaries.

Kindly check the Git repo or follow the steps below.

Steps

  1. Building custom Column Type
  2. Building a nested data structure using Pydantic model
  3. Building a Sqlalchemy Table models using 1 and 2
  4. Insert and select Script for testing

Building custom Column Type [custom_column.py]

from pydantic import BaseModel, TypeAdapter  # pydantic version > 2.0.0
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON, TypeDecorator

# from pydantic import parse_obj_as               # pydantic version <2.0.0


class PydanticColumn(TypeDecorator):
    """
    PydanticColumn type: stores a pydantic model as JSON in the database.

    * For custom column type implementation check
      https://docs.sqlalchemy.org/en/20/core/custom_types.html
    * Uses sqlalchemy.dialects.postgresql.JSONB if dialect == postgresql,
      else generic sqlalchemy.types.JSON.
    * Save:
        - Accepts the pydantic model and converts it to a dict on save.
        - SQLAlchemy engine JSON-encodes the dict to a string.
    * Load:
        - Pulls the string from the database.
        - SQLAlchemy engine JSON-decodes the string to a dict.
        - Uses the dict to create a pydantic model.
    """

    impl = JSON
    cache_ok = True

    def __init__(self, pydantic_type):
        """:param pydantic_type: the pydantic ``BaseModel`` subclass to (de)serialize."""
        super().__init__()
        if not issubclass(pydantic_type, BaseModel):
            raise ValueError("Column Type Should be Pydantic Class")
        self.pydantic_type = pydantic_type
        # Build the TypeAdapter once instead of on every row load.
        self._type_adapter = TypeAdapter(pydantic_type)

    def load_dialect_impl(self, dialect):
        # Use JSONB for PostgreSQL and JSON for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Fixed: compare against None instead of truthiness, so a model whose
        # __bool__ happens to be falsy is still persisted; only None maps to NULL.
        # return value.dict() if value is not None else None   # pydantic <2.0.0
        return value.model_dump() if value is not None else None

    def process_result_value(self, value, dialect):
        # Fixed: guard NULL columns. The original called validate_python(None),
        # which raises for non-optional models (the pydantic<2 commented
        # variant below had this guard; the v2 line lost it).
        # return parse_obj_as(self.pydantic_type, value) if value else None # pydantic < 2.0.0
        if value is None:
            return None
        return self._type_adapter.validate_python(value)

Building a nested data structure using Pydantic model [pydantic_model.py]

from enum import Enum
from typing import List, Literal, Optional

from pydantic import BaseModel, Field


# Kind of project; str mixin so members compare/serialize as plain strings.
ProjecType = Enum(
    "ProjecType",
    [(label, label) for label in ("DEPLOYABLE", "INMEMORY", "SINGLESHOT")],
    type=str,
)


class ProjecStatus(str, Enum):
    """Lifecycle stage of a project; str-valued for plain JSON serialization."""

    INIT = "INIT"
    # Fixed copy-paste bug: the value was "DEPLOYABLE" (a ProjecType value),
    # which broke value lookup ProjecStatus("DATALOAD") and made the
    # serialized status misleading. SQLAlchemy's Enum column stores member
    # *names*, so existing DB rows are unaffected by this value fix.
    DATALOAD = "DATALOAD"
    PREPROCESS = "PREPROCESS"
    POSTPROCESS = "POSTPROCESS"

class Dtypes(str, Enum):
    """Column data types recognised by the dataset descriptor."""

    INTERGER = "integers"  # historical misspelling kept so existing callers don't break
    INTEGER = "integers"  # correctly-spelled alias of INTERGER (same value)
    FLOAT = "float"
    BOOLEAN = "bool"
    CATEGORICAL = "categorical"
    DATE = "date"


# Role a column plays in the dataset; str mixin keeps JSON round-trips plain.
ColumnType = Enum(
    "ColumnType",
    {
        "FEATURES": "features",
        "TARGET": "target",
        "INDEX": "index",
        "UNIQUEID": "unique-id",
    },
    type=str,
)
    
# Strategy used to fill missing values in a column.
ImputationScheme = Enum(
    "ImputationScheme",
    {
        "MEAN": "mean",
        "MEDIAN": "median",
        "MODE": "mode",
        "VALUE": "value",
    },
    type=str,
)

class ColumnsDescription(BaseModel):
    """Per-column profile: role, dtype, summary statistics and imputation choice.

    NOTE(review): ``unique_valus`` looks like a misspelling of ``unique_values``,
    but the name is part of the public schema, so it is kept as-is.
    """

    name: str
    col_type: ColumnType
    dtype: Dtypes
    # Summary statistics; None when not computed / not applicable.
    mean: Optional[float] = None
    median: Optional[float] = None
    mode: Optional[float] = None
    null_count: Optional[int] = None
    unique_valus: Optional[int] = None
    imputation_scheme: ImputationScheme

class DatasetDescriptor(BaseModel):
    """Dataset-level profile persisted alongside a project row.

    NOTE(review): ``cloumns_info`` looks like a misspelling of ``columns_info``,
    but the name is part of the public schema, so it is kept as-is.
    """

    row_count: int = Field(ge=0)  # required, must be non-negative
    cloumns_info: List[ColumnsDescription] = Field(default_factory=list)
    duplicate_row_count: int = Field(default=0, ge=0)
    duplicate_columns: List[str] = Field(default_factory=list)
    outlier_count: int = Field(ge=0)  # required, must be non-negative
    # None presumably means "not evaluated" — TODO confirm with callers.
    is_imbalance: Optional[bool] = None

Building a Sqlalchemy Table models using 1[custom_column.py] and 2 [pydantic_model.py]

from typing import Optional

from sqlalchemy import Enum, MetaData, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from src.custom_pydantic_column import PydanticColumn    # custom_column.py
from src.pydantic_model import DatasetDescriptor, ProjecStatus, ProjecType  # pydantic_model.py

# Shared MetaData registry so all mapped tables can be created/dropped together.
meta = MetaData()


class Base(DeclarativeBase):
    """Declarative base class for every ORM model in this module."""

    metadata = meta


class Projects(Base):
    """ORM model for the ``projects`` table.

    ``dataset_info`` uses the custom ``PydanticColumn`` type, so a nested
    ``DatasetDescriptor`` pydantic model is stored as JSON and re-hydrated
    into a model instance on load.
    """

    __tablename__ = "projects"

    # Surrogate integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    title: Mapped[str] = mapped_column(String(100), nullable=False)
    # Optional free-text description, capped at 100 characters.
    descriptions: Mapped[Optional[str]] = mapped_column(
        String(100), nullable=True, default=None
    )
    # Enum columns store the member *name* by default in SQLAlchemy.
    ptype: Mapped[ProjecType] = mapped_column(Enum(ProjecType), nullable=False)
    status: Mapped[ProjecStatus] = mapped_column(Enum(ProjecStatus), nullable=False)
    # Nested pydantic model persisted as JSON/JSONB via PydanticColumn.
    dataset_info: Mapped[Optional[DatasetDescriptor]] = mapped_column(
        PydanticColumn(DatasetDescriptor), nullable=True
    )

4. Insert and select Script for testing

import os
import logging
from typing import Dict, Optional

from sqlalchemy import URL, create_engine, insert, select 
from sqlalchemy.orm import Session, sessionmaker

from src.orms import Projects
from src.pydantic_model import (ColumnsDescription, ColumnType,
                                DatasetDescriptor, Dtypes, ImputationScheme,
                                ProjecStatus, ProjecType)

logger = logging.getLogger(__name__)

def insert_project(db: Session, object_in: Dict) -> Optional[Projects]:
    """Insert a project row and return the mapped object, or None on failure.

    On error the session is rolled back and the error is logged; the
    exception is deliberately NOT re-raised (best-effort semantics kept
    from the original implementation).
    """
    query_stmt = insert(Projects).values(**object_in).returning(Projects)
    try:
        result = db.execute(query_stmt)
        data = result.scalars().one()
    except Exception as err:
        db.rollback()
        logger.error("Session id %s| Error while Insert, Error %s", id(db), err)
        return None
    # Fixed: the original returned from ``finally``, which silently swallows
    # any exception raised after the except block (flake8-bugbear B012).
    logger.info("Session id %s| Sucessfully inserted, id %s", id(db), data.id)
    return data
        
def select_project(db: Session, project_id: int) -> Optional[Projects]:
    """Fetch a project by primary key; return None when missing or on error."""
    query_stmt = select(Projects).where(Projects.id == project_id)
    try:
        result = db.execute(query_stmt)
        data = result.scalars().one()
    except Exception as err:
        db.rollback()
        # Fixed copy-paste bug: the message said "Insert" in the select path.
        logger.error("Session id %s| Error while Select, Error %s", id(db), err)
        return None
    # Fixed: no ``return`` inside ``finally`` — that pattern swallows any
    # exception raised after the except block (flake8-bugbear B012).
    logger.info("Session id %s| Sucessfully Selected, id %s", id(db), data.id)
    return data
        

if __name__ == "__main__":
    # Fixed: configure a handler so the INFO messages below are actually
    # emitted; without this the root logger only shows WARNING and above.
    logging.basicConfig(level=logging.INFO)

    dataset = DatasetDescriptor(
        row_count=100,
        duplicate_columns=[],
        duplicate_row_count=0,
        outlier_count=0,
        is_imbalance=False,
        cloumns_info=[
            ColumnsDescription(
                name="job_type",
                col_type=ColumnType.FEATURES,
                dtype=Dtypes.CATEGORICAL,
                mean=None,
                median=None,
                mode=None,
                null_count=0,
                unique_valus=5,
                imputation_scheme=ImputationScheme.MODE,
            )
        ],
    )
    logger.info("dataset %s", dataset.model_dump())

    object_in = {
        "title": "Example Project 1",
        "descriptions": "Example Project 1 description",
        "ptype": ProjecType.DEPLOYABLE,
        "status": ProjecStatus.INIT,
        "dataset_info": dataset,
    }

    DB_PATH = os.path.join(os.curdir, "test_sqlite3.db")
    db_url = URL.create(drivername="sqlite", database=DB_PATH)
    logger.info("DB URL %s", db_url)
    engine = create_engine(db_url, echo=False)

    # Fixed: the script never created the schema, so the insert failed with
    # "no such table" on a fresh database.
    Projects.metadata.create_all(engine)

    # Renamed from ``Session`` to avoid shadowing the imported ORM class.
    SessionLocal = sessionmaker(engine)
    with SessionLocal.begin() as session:
        data = insert_project(db=session, object_in=object_in)
        if data is None:
            # Fixed: insert_project returns None on failure — the original
            # dereferenced data.id unconditionally and crashed.
            logger.error("Insert failed; skipping select")
        else:
            project = select_project(db=session, project_id=data.id)
            if project is not None:
                logger.info("dataset_info %s, %s", project.ptype, project.dataset_info)

Kindly check the Git repo.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment