Skip to content

Instantly share code, notes, and snippets.

@madeinoz67
Last active June 6, 2021 11:58
Show Gist options
  • Save madeinoz67/df67e01f9421f049e646c2820a8ed8f2 to your computer and use it in GitHub Desktop.
Save madeinoz67/df67e01f9421f049e646c2820a8ed8f2 to your computer and use it in GitHub Desktop.
conftest.py for pytest when testing async SQLAlchemy 1.4 database code against SQLite via the aiosqlite driver, with Alembic handling database setup and teardown. This is the configuration I finally got working for testing FastAPI services.
import os
import warnings
from contextvars import ContextVar
from typing import AsyncGenerator
from unittest import mock
import alembic
import pytest
from alembic.config import Config
# from example.routers.utils.db import get_db
from fastapi import FastAPI
from httpx import AsyncClient # noqa:
from pytest_factoryboy import register
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from tests.factories import PartCreateSchemaFactory
# Can be overridden by environment variable for testing in CI against other
# database engines
SQLALCHEMY_DATABASE_URL = os.getenv(
    "TEST_DATABASE_URL", "sqlite+aiosqlite:///./tests/files/test.db"
)

# Register factories (exposes pytest-factoryboy fixtures for PartCreateSchema).
register(PartCreateSchemaFactory)

# ``check_same_thread`` is a sqlite3-specific connect argument. Other drivers
# (e.g. asyncpg) do not accept it, so only pass it for SQLite URLs; otherwise
# the TEST_DATABASE_URL override above could not point at another engine.
_connect_args = (
    {"check_same_thread": False}
    if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
    else {}
)

engine = create_async_engine(
    SQLALCHEMY_DATABASE_URL, echo=True, connect_args=_connect_args
)
# Session factory for tests; expire_on_commit=False keeps ORM attributes
# readable after commit without another round-trip.
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
# Apply migrations before, and roll them back after, every test function.
@pytest.fixture(scope="function")
def apply_migrations():
    """Run Alembic ``upgrade head`` before a test and ``downgrade base`` after it.

    For the whole lifetime of the fixture (including the test body, since the
    ``yield`` is inside the patch), ``os.environ`` is replaced by a mapping that
    contains only ``DATABASE_URL`` pointing at the test database.

    DeprecationWarnings are silenced around both Alembic calls. Each call gets
    its own ``catch_warnings`` scope because pytest resets warning filters
    between the setup and teardown phases.
    """
    from alembic import command

    # NOTE(review): clear=True also removes every other environment variable
    # (PATH, HOME, ...) while tests run — confirm that is intended.
    with mock.patch.dict(
        os.environ, {"DATABASE_URL": SQLALCHEMY_DATABASE_URL}, clear=True
    ):
        config = Config("alembic.ini")
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            command.upgrade(config, "head")
        yield
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            command.downgrade(config, "base")
@pytest.fixture()
def app(apply_migrations: None) -> FastAPI:
    """Return a newly built FastAPI application for a single test.

    Requesting ``apply_migrations`` guarantees the test database has been
    migrated before the application is constructed.
    """
    # Imported inside the fixture rather than at module level — presumably so
    # the app module loads only after DATABASE_URL has been patched; confirm.
    from app.main import get_application as build_application

    application = build_application()
    return application
@pytest.fixture(scope="function")
async def db_session(apply_migrations: None) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` on the migrated test database.

    Depends on ``apply_migrations`` so the schema exists while the test runs.
    Any transaction still open when the test finishes is rolled back, and
    leaving the ``async with`` block closes the session.

    This is an async generator fixture declared with plain ``@pytest.fixture``.
    It presumably relies on pytest-asyncio running such fixtures (auto/legacy
    mode) — confirm. Test marks like ``pytest.mark.asyncio`` have no effect on
    fixtures, so none is applied here.
    """
    async with async_session() as session:
        try:
            yield session
        finally:
            await session.rollback()
import factory
from faker import Faker
from faker_optional import OptionalProvider
from app.models.part import PartModel
from app.schema.part import PartPublicSchema # noqa:
from app.schema.part import PartUpdateSchema # noqa:
from app.schema.part import PartCreateSchema
# Seed Faker's shared random generator before any data is generated so the
# values produced below (and by the factories) are identical on every run.
Faker.seed(0)  # reproducible random
fake = Faker()
# Extra provider from the faker_optional package (presumably helpers that
# sometimes return None for optional fields — confirm).
fake.add_provider(OptionalProvider)
class PartModelFactory(factory.alchemy.SQLAlchemyModelFactory):
    """factory_boy factory that builds ``PartModel`` ORM objects for tests.

    NOTE(review): each attribute below is a plain value computed once, when
    this class body executes, so every object built from this factory gets the
    same name/description/footprint/etc. Wrapping them in
    ``factory.LazyFunction`` would give per-object data, but the values (and
    their order) also fix what the seeded ``fake`` produces, and other code may
    read these class attributes directly — check callers before changing.
    """

    class Meta:
        model = PartModel
        # NOTE(review): no ``sqlalchemy_session`` is configured, so presumably
        # only the build strategy (not create) is usable — confirm.

    name = fake.word()
    description = fake.text()
    # e.g. "SOP-8" / "SOT-3": "?" is replaced by P or T, "#" by a digit.
    footprint = fake.bothify(text="SO?-#", letters="PT")
    manufacturer = fake.company()
    # Manufacturer part number pattern: letters ("?") and digits ("#").
    mpn = fake.bothify(text="???-####-###-???")
    notes = fake.text()
class PartCreateSchemaFactory(factory.Factory):
    """factory_boy factory that builds ``PartCreateSchema`` payloads for tests.

    Registered with pytest-factoryboy in conftest, which exposes it as the
    ``part_create_schema_factory`` fixture.

    NOTE(review): like PartModelFactory, these attributes are plain values
    evaluated once at class-definition time, so every built payload carries
    identical data. Existing tests may read them as class attributes, so
    confirm before switching to lazy declarations.
    """

    class Meta:
        model = PartCreateSchema

    name = fake.word()
    description = fake.text()
    # "?" is replaced by P or T and "#" by a digit, e.g. "SOP-8".
    footprint = fake.bothify(text="SO?-#", letters="PT")
    manufacturer = fake.company()
    mpn = fake.bothify(text="???-####-###-???")
    notes = fake.text()
import datetime
from typing import List
from loguru import logger
from nanoid import generate
from sqlalchemy import delete, func, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.core.config import settings
from app.db.session import db_session_context
from app.models.part import PartModel
from app.schema.datatable import DataTableRequest, PartDataTableResponse
from app.schema.part import PartCreateSchema, PartUpdateSchema
async def create_part(db: AsyncSession, details: PartCreateSchema) -> PartModel:
    """Create a new Part and save it to the database.

    Args:
        db (AsyncSession): session used to persist the Part. It is entered with
            ``async with`` below, so it is closed when this function returns.
        details (PartCreateSchema): details of the Part to create. This object
            is not modified.

    Returns:
        PartModel: the newly created Part. If the commit raises
        ``IntegrityError``, the transaction is rolled back, the error is logged
        and the (unsaved) Part is still returned.
    """
    part = PartModel()
    part.id = generate(_alphabet, _size)
    # Fall back to the generated id when no name is given, without mutating
    # the caller's schema object.
    part.name = details.name if details.name is not None else part.id
    part.description = details.description
    part.notes = details.notes
    part.footprint = details.footprint
    part.manufacturer = details.manufacturer
    part.mpn = details.mpn

    # TODO: use contextvar for db dependencies when https://github.com/pytest-dev/pytest-asyncio/pull/161 is merged
    # db_session: AsyncSession = db_session_context.get()
    db_session: AsyncSession = db
    async with db_session as session:
        session.add(part)
        try:
            await session.commit()
        except IntegrityError as ex:
            await session.rollback()
            # loguru formats with str.format: without a "{}" placeholder the
            # exception argument would be silently dropped from the message.
            logger.error("Part ID already exists in the database: {}", ex)
    return part
import datetime
import sqlalchemy as sa
from app.models.modelbase import SqlAlchemyBase
class PartModel(SqlAlchemyBase):
    """SQLAlchemy ORM model mapped to the ``part`` table."""

    __tablename__ = "part"

    id: str = sa.Column(sa.String, primary_key=True)
    name: str = sa.Column(sa.String, index=True)
    description: str = sa.Column(sa.String, nullable=True, index=True)
    created_at: datetime.datetime = sa.Column(
        sa.DateTime, default=datetime.datetime.now, index=True
    )
    # ``onupdate`` refreshes the timestamp on every UPDATE that does not set
    # this column explicitly; with only ``default`` it stayed at creation time.
    updated_at: datetime.datetime = sa.Column(
        sa.DateTime,
        default=datetime.datetime.now,
        onupdate=datetime.datetime.now,
        index=True,
    )
    notes: str = sa.Column(sa.String, nullable=True)
    footprint: str = sa.Column(sa.String, nullable=True, index=True)
    manufacturer: str = sa.Column(sa.String, nullable=True, index=True)
    mpn: str = sa.Column(
        sa.String, nullable=True, index=True
    )  # Manufacturers Part Number

    # TODO: stock relationship, pending a Stock model:
    # stock: List[Stock] = orm.relationship(
    #     Stock, order_by=[Stock.last_updated], back_populates="part"
    # )

    def __repr__(self):
        return f"<Part {self.id}>"
import datetime
from typing import Optional
from pydantic import HttpUrl
from app.schema.core import CoreSchema, IDSchemaMixin
class PartBaseSchema(CoreSchema):
    """Fields shared by every Part schema.

    Every field defaults to ``None`` here; subclasses such as PartCreateSchema
    and PartInDBSchema redeclare ``name`` as a required ``str``.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    notes: Optional[str] = None
    footprint: Optional[str] = None
    manufacturer: Optional[str] = None
    mpn: Optional[str] = None  # Manufacturer's Part Number
    created_at: Optional[datetime.datetime] = None
    updated_at: Optional[datetime.datetime] = None

    class Config:
        # pydantic v1 option: lets ``from_orm`` read attributes from ORM
        # objects (e.g. PartModel) instead of requiring a dict.
        orm_mode = True
class PartCreateSchema(PartBaseSchema):
    """Payload for creating a Part; ``name`` is required, all else optional."""

    name: str
class PartUpdateSchema(PartBaseSchema):
    """Payload for updating a Part; every field is optional, inherited as-is."""

    pass
class PartInDBSchema(IDSchemaMixin, PartBaseSchema):
    """Part as stored in the database: id fields from IDSchemaMixin plus a required name."""

    name: str
class PartPublicSchema(IDSchemaMixin, PartBaseSchema):
    """Part representation returned to API clients."""

    # NOTE(review): presumably overrides any ``id`` type declared by
    # IDSchemaMixin so string ids validate — confirm against the mixin.
    id: str
    # Optional link to this resource; None unless set by the caller.
    href: Optional[HttpUrl] = None
def test_new_part_create_schema(capsys):
    """Building a PartCreateSchema payload as a dict yields a ``name`` key."""
    built = factory.build(dict, FACTORY_CLASS=PartCreateSchemaFactory)
    print(built)
    printed = capsys.readouterr().out
    assert built is not None
    assert "'name':" in printed
import pytest
from loguru import logger # noqa:
from sqlalchemy.ext.asyncio import AsyncSession # noqa:
from app.models.part import PartModel
from app.services import part_service
class TestPartService:
    """Tests for ``app.services.part_service`` against the test database."""

    @pytest.mark.asyncio
    async def test_add_single_part(
        self,
        part_create_schema_factory,
        db_session,
    ):
        """Creating one part keeps the requested name and assigns an id."""
        # pytest-factoryboy's ``*_factory`` fixture is the factory class itself;
        # call it to build an actual PartCreateSchema instance to pass in.
        obj_in = part_create_schema_factory()
        item: PartModel = await part_service.create_part(db_session, obj_in)  # noqa:
        assert item.name == obj_in.name
        assert item.id is not None
        # NOTE(review): 26 must match the id size used by part_service — confirm.
        assert len(item.id) == 26

    @pytest.mark.asyncio
    async def test_part_count_with_empty_db(self, db_session):
        """A freshly migrated database contains no parts."""
        assert await part_service.get_part_count(db_session) == 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment