Skip to content

Instantly share code, notes, and snippets.

@guibeira
Created February 23, 2023 19:10
Show Gist options
  • Save guibeira/547b12eb88f1133ca493093e5d3e70b0 to your computer and use it in GitHub Desktop.
Save guibeira/547b12eb88f1133ca493093e5d3e70b0 to your computer and use it in GitHub Desktop.
Simple solution for batch-creating and querying nested objects
from uuid import uuid4
from collections import defaultdict
import uuid
from pydantic import BaseModel, Field
from pydantic.types import UUID4
from sqlalchemy import Column, MetaData, Table, create_engine, text, ForeignKey
from sqlalchemy.sql import select
from sqlalchemy.dialects import postgresql
DATABASE_NAME = "sqlalchemy_test"
# Server-level URL with no database selected. A plain string: nothing is interpolated.
DATABASE_HOST = "postgresql+psycopg2://postgres:admin@localhost:5432/"
DATABASE_URL = DATABASE_HOST + DATABASE_NAME

## Drop and recreate the whole database each run for a quick development loop
psql_engine = create_engine(DATABASE_HOST)
# PostgreSQL refuses DROP/CREATE DATABASE inside a transaction block, so the
# connection is switched to AUTOCOMMIT for these two statements.
with psql_engine.connect().execution_options(
    isolation_level="AUTOCOMMIT"
) as connection:
    connection.execute(text(f"DROP DATABASE IF EXISTS {DATABASE_NAME}"))
    connection.execute(text(f"CREATE DATABASE {DATABASE_NAME}"))
# The bootstrap engine is only needed for the reset above; release the
# connection its pool is still holding open against the server.
psql_engine.dispose()

# Engine bound to the freshly created database; used by everything below.
engine = create_engine(DATABASE_URL)
## Create a few very simple pydantic models
class Child(BaseModel):
    # Identifier of the child; a random UUID is generated when none is supplied.
    id: UUID4 = Field(default_factory=uuid.uuid4)
class Parent(BaseModel):
    # Identifier of the parent; a random UUID is generated when none is supplied.
    id: UUID4 = Field(default_factory=uuid.uuid4)
    # Nested Child objects owned by this parent. Required (no default), may be empty.
    children: list[Child]
## Write the schema definitions
# Registry that collects every Table declared below.
metadata_obj = MetaData()

# "parent": a single UUID primary-key column.
parent_table = Table(
    "parent",
    metadata_obj,
    Column(
        "id",
        type_=postgresql.UUID(as_uuid=True),
        primary_key=True,
    ),
)
# "child": UUID primary key, mandatory link to its parent row, and a
# creation timestamp filled in by the database server.
child_table = Table(
    "child",
    metadata_obj,
    Column(
        "id",
        type_=postgresql.UUID(as_uuid=True),
        primary_key=True,
    ),
    Column(
        "parent_id",
        postgresql.UUID(as_uuid=True),
        ForeignKey("parent.id"),
        nullable=False,
    ),
    Column(
        "created_at",
        type_=postgresql.TIMESTAMP,
        server_default=text("now()"),
    ),
)
# Create every table registered on metadata_obj in the target database
metadata_obj.create_all(bind=engine)
class ParentSocket:
    """Store and load ``Parent`` objects (with their children) via SQLAlchemy Core."""

    @staticmethod
    def _insert_in_batches(conn, table, records: list[dict], limit: int) -> int:
        """Insert ``records`` into ``table`` in chunks of at most ``limit`` rows.

        Returns the sum of the ``rowcount`` reported for each chunk. An empty
        ``records`` list issues no statement and returns 0.
        """
        inserted = 0
        for start in range(0, len(records), limit):
            batch = records[start : start + limit]
            inserted += conn.execute(table.insert().values(batch)).rowcount
        return inserted

    def create_many(self, objs: list[Parent], limit=100) -> int:
        """Bulk create a list of parent objects together with their children.

        Args:
            objs: Parents to persist; each parent's ``children`` are stored too.
            limit: Maximum number of rows sent per INSERT statement.

        Returns:
            Total number of rows inserted (parent rows plus child rows).

        Raises:
            ValueError: If ``limit`` is smaller than 1.
        """
        # A zero step makes range() raise, and a negative step yields no
        # batches at all, which would silently skip rows; reject both up front.
        if limit < 1:
            raise ValueError(f"limit must be a positive integer, got {limit!r}")
        # Nothing to do; also avoids emitting an INSERT with an empty VALUES list.
        if not objs:
            return 0

        parent_records = [{"id": obj.id} for obj in objs]
        child_records = [
            {"id": child.id, "parent_id": obj.id}
            for obj in objs
            for child in obj.children
        ]

        # One transaction for everything: parents first so the children's
        # foreign key is satisfied; any failure rolls the whole batch back.
        with engine.begin() as conn:
            total_inserted = self._insert_in_batches(
                conn, parent_table, parent_records, limit
            )
            total_inserted += self._insert_in_batches(
                conn, child_table, child_records, limit
            )
        return total_inserted

    def query(self, id: UUID4 | list[UUID4]) -> list[Parent]:
        """Query either a single id or multiple ids from the parent table.

        Args:
            id: One parent id, or a list of parent ids.

        Returns:
            Pydantic ``Parent`` objects ordered by parent id, each carrying its
            children ordered by child id. Parents without children are returned
            with an empty ``children`` list; ids with no matching row are
            omitted.
        """
        ids = [id] if isinstance(id, uuid.UUID) else id
        if not ids:
            return []

        # LEFT OUTER JOIN so that parents with zero children still yield a row
        # (with a NULL child id) instead of vanishing from the result.
        stmt = (
            select(parent_table.c.id, child_table.c.id.label("child_id"))
            .select_from(parent_table.outerjoin(child_table))
            .where(parent_table.c.id.in_(ids))
            .order_by(parent_table.c.id, child_table.c.id)
        )

        # Group the children by parent id while the connection is still open.
        children_by_parent_id = defaultdict(list)
        with engine.connect() as conn:
            for parent_id, child_id in conn.execute(stmt):
                # Touching the key registers the parent even when it has no child.
                siblings = children_by_parent_id[parent_id]
                if child_id is not None:
                    siblings.append(child_id)

        # Map each parent id to a Parent object with its children.
        return [
            Parent(
                id=parent_id,
                children=[Child(id=child_id) for child_id in child_ids],
            )
            for parent_id, child_ids in children_by_parent_id.items()
        ]
parent_socket = ParentSocket()

# Build three parents owning 2, 1 and 3 children respectively
parent1 = Parent(children=[Child(), Child()])
parent2 = Parent(children=[Child()])
parent3 = Parent(children=[Child(), Child(), Child()])

# 3 parent rows + 6 child rows are expected to be written
assert 9 == parent_socket.create_many([parent1, parent2, parent3])

# Lookup by a single id, then by a list of ids
assert parent1.id == parent_socket.query(id=parent1.id)[0].id
assert 2 == len(parent_socket.query(id=[parent1.id, parent2.id]))

# Write any additional tests as desired!
# Every parent must come back with exactly the children it was created with.
assert all(
    len(parent_socket.query(id=parent.id)[0].children) == expected_children
    for parent, expected_children in ((parent1, 2), (parent2, 1), (parent3, 3))
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment