Last active
June 6, 2021 11:58
-
-
Save madeinoz67/df67e01f9421f049e646c2820a8ed8f2 to your computer and use it in GitHub Desktop.
conftest.py for pytest when testing with async SQLAlchemy 1.4 against SQLite via the aiosqlite driver, using Alembic for database setup and teardown. This is what I finally got working when testing services for FastAPI.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import warnings | |
from contextvars import ContextVar | |
from typing import AsyncGenerator | |
from unittest import mock | |
import alembic | |
import pytest | |
from alembic.config import Config | |
# from example.routers.utils.db import get_db | |
from fastapi import FastAPI | |
from httpx import AsyncClient # noqa: | |
from pytest_factoryboy import register | |
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine | |
from sqlalchemy.orm import sessionmaker | |
from tests.factories import PartCreateSchemaFactory | |
# The database URL can be overridden with the TEST_DATABASE_URL environment
# variable, e.g. to run the suite in CI against other database engines.
SQLALCHEMY_DATABASE_URL = os.getenv(
    "TEST_DATABASE_URL", "sqlite+aiosqlite:///./tests/files/test.db"
)

# Register factories; pytest_factoryboy exposes them as fixtures
# (e.g. ``part_create_schema_factory``).
register(PartCreateSchemaFactory)

# "check_same_thread" is an option of Python's sqlite3 driver only. Other
# drivers (e.g. asyncpg) do not accept it, so pass it only for SQLite URLs;
# otherwise the TEST_DATABASE_URL override above could not work.
_CONNECT_ARGS = (
    {"check_same_thread": False}
    if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
    else {}
)

engine = create_async_engine(
    SQLALCHEMY_DATABASE_URL, echo=True, connect_args=_CONNECT_ARGS
)
# Session factory for tests. expire_on_commit=False keeps loaded attributes
# usable after commit instead of expiring them.
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
# Apply migrations before each test function and roll them back after it, so
# every test starts from a freshly migrated schema.
@pytest.fixture(scope="function")
def apply_migrations():
    """Run ``alembic upgrade head`` before a test and ``downgrade base`` after.

    ``DATABASE_URL`` points at the test database for the whole test. Because
    ``clear=True`` is passed to ``mock.patch.dict``, every other environment
    variable is removed until the test finishes.
    NOTE(review): confirm wiping the whole environment is intended.
    """
    with mock.patch.dict(
        os.environ, {"DATABASE_URL": SQLALCHEMY_DATABASE_URL}, clear=True
    ):
        config = Config("alembic.ini")
        # Silence DeprecationWarnings from the migration run only, instead of
        # installing a process-wide filter that also hides them in the test.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            alembic.command.upgrade(config, "head")
        yield
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            alembic.command.downgrade(config, "base")
@pytest.fixture()
def app(apply_migrations: None) -> FastAPI:
    """Build a FastAPI application for a single test case.

    Depending on ``apply_migrations`` ensures the test database has been
    migrated before the application is created.
    """
    # Imported inside the fixture so the application module is only loaded
    # once a test asks for it.
    # NOTE(review): presumably so settings see the patched DATABASE_URL; confirm.
    from app.main import get_application

    application = get_application()
    return application
@pytest.fixture(scope="function")
async def db_session(
    apply_migrations: None,
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the freshly migrated test database.

    Depends on ``apply_migrations`` so the schema exists before the session is
    used. When the test finishes, any open transaction is rolled back. Leaving
    the ``async with`` block then closes the session.
    """
    # Marks such as @pytest.mark.asyncio have no effect on fixtures. Running
    # this async fixture relies on pytest-asyncio's fixture support instead
    # (auto/legacy mode, or @pytest_asyncio.fixture in strict mode).
    async with async_session() as session:
        try:
            yield session
        finally:
            await session.rollback()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import factory | |
from faker import Faker | |
from faker_optional import OptionalProvider | |
from app.models.part import PartModel | |
from app.schema.part import PartPublicSchema # noqa: | |
from app.schema.part import PartUpdateSchema # noqa: | |
from app.schema.part import PartCreateSchema | |
# Seed Faker's shared random generator before any values are drawn, so the
# data produced by the factories below is reproducible between runs.
Faker.seed(0)
fake: Faker = Faker()
# Make faker_optional's OptionalProvider available on this instance.
fake.add_provider(OptionalProvider)
class PartModelFactory(factory.alchemy.SQLAlchemyModelFactory):
    """factory_boy factory producing ``PartModel`` objects filled with fake data.

    Every attribute is wrapped in ``factory.LazyFunction``, so each built
    instance gets freshly drawn values. A bare ``fake.word()`` would run once,
    when the class body executes, and be shared by all instances.
    NOTE(review): ``Meta`` sets no ``sqlalchemy_session``, so only the
    ``build`` strategy is usable as written; confirm before using ``create``.
    """

    class Meta:
        model = PartModel

    name = factory.LazyFunction(fake.word)
    description = factory.LazyFunction(fake.text)
    footprint = factory.LazyFunction(
        lambda: fake.bothify(text="SO?-#", letters="PT")
    )
    manufacturer = factory.LazyFunction(fake.company)
    mpn = factory.LazyFunction(lambda: fake.bothify(text="???-####-###-???"))
    notes = factory.LazyFunction(fake.text)
class PartCreateSchemaFactory(factory.Factory):
    """factory_boy factory for ``PartCreateSchema`` payloads.

    NOTE(review): each value below is drawn from ``fake`` once, when the class
    body executes, so all built schemas share the same data. Code that reads
    these as class attributes (instead of building an instance) relies on them
    being plain values, so check callers before making them lazy.
    """

    class Meta:
        model = PartCreateSchema

    # Faker is seeded, so the draw order below determines the generated
    # values; keep it stable.
    name = fake.word()
    description = fake.text()
    footprint = fake.bothify(letters="PT", text="SO?-#")
    manufacturer = fake.company()
    mpn = fake.bothify(text="???-####-###-???")
    notes = fake.text()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
from typing import List | |
from loguru import logger | |
from nanoid import generate | |
from sqlalchemy import delete, func, update | |
from sqlalchemy.exc import IntegrityError | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from sqlalchemy.future import select | |
from app.core.config import settings | |
from app.db.session import db_session_context | |
from app.models.part import PartModel | |
from app.schema.datatable import DataTableRequest, PartDataTableResponse | |
from app.schema.part import PartCreateSchema, PartUpdateSchema | |
async def create_part(db: AsyncSession, details: PartCreateSchema) -> PartModel:
    """Create a new Part and save it to the database.

    Args:
        db (AsyncSession): session used to persist the part. It is entered
            with ``async with`` below, which closes it when the block exits.
        details (PartCreateSchema): details of the Part to create. The object
            is not modified.

    Returns:
        PartModel: the newly created Part. If committing raises
        ``IntegrityError``, the transaction is rolled back, the error is
        logged, and the (unsaved) part is still returned.
    """
    part = PartModel()
    # NOTE(review): ``_alphabet`` and ``_size`` are not defined in the visible
    # module. Presumably they are module-level nanoid settings defined
    # elsewhere; confirm.
    part.id = generate(_alphabet, _size)
    # Fall back to the generated id when no name is supplied, without
    # mutating the caller's schema object.
    part.name = details.name if details.name is not None else part.id
    part.description = details.description
    part.notes = details.notes
    part.footprint = details.footprint
    part.manufacturer = details.manufacturer
    part.mpn = details.mpn

    # TODO: use contextvar for db dependencies when https://github.com/pytest-dev/pytest-asyncio/pull/161 is merged
    # db_session: AsyncSession = db_session_context.get()
    db_session: AsyncSession = db
    async with db_session as session:
        session.add(part)
        try:
            await session.commit()
        except IntegrityError as ex:
            await session.rollback()
            # loguru formats messages with str.format(), so the exception
            # needs an explicit "{}" placeholder or it is silently dropped.
            logger.error("Part ID already exists in the database: {}", ex)
    return part
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import sqlalchemy as sa | |
from app.models.modelbase import SqlAlchemyBase | |
class PartModel(SqlAlchemyBase):
    """SQLAlchemy ORM model mapped to the ``part`` table."""

    __tablename__ = "part"

    id: str = sa.Column(sa.String, primary_key=True)
    name: str = sa.Column(sa.String, index=True)
    description: str = sa.Column(sa.String, nullable=True, index=True)
    # Timestamps use naive local time via datetime.datetime.now.
    created_at: datetime.datetime = sa.Column(
        sa.DateTime, default=datetime.datetime.now, index=True
    )
    # ``onupdate`` refreshes the value whenever SQLAlchemy emits an UPDATE for
    # the row that doesn't set the column itself. With only ``default``, it
    # would keep the creation time forever.
    updated_at: datetime.datetime = sa.Column(
        sa.DateTime,
        default=datetime.datetime.now,
        onupdate=datetime.datetime.now,
        index=True,
    )
    notes: str = sa.Column(sa.String, nullable=True)
    footprint: str = sa.Column(sa.String, nullable=True, index=True)
    manufacturer: str = sa.Column(sa.String, nullable=True, index=True)
    # Manufacturer's part number.
    mpn: str = sa.Column(sa.String, nullable=True, index=True)

    # Planned stock relationship (not enabled yet):
    # stock: List[Stock] = orm.relationship(
    #     Stock, order_by=[Stock.last_updated], back_populates="part"
    # )

    def __repr__(self):
        return "<Part {}>".format(self.id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
from typing import Optional | |
from pydantic import HttpUrl | |
from app.schema.core import CoreSchema, IDSchemaMixin | |
class PartBaseSchema(CoreSchema):
    """Fields shared by the Part schemas; every field defaults to ``None``."""

    # Field order is kept as-is: pydantic preserves declaration order.
    name: Optional[str] = None
    description: Optional[str] = None
    notes: Optional[str] = None
    footprint: Optional[str] = None
    manufacturer: Optional[str] = None
    mpn: Optional[str] = None
    created_at: Optional[datetime.datetime] = None
    updated_at: Optional[datetime.datetime] = None

    class Config:
        # Allow constructing the schema from ORM objects (pydantic v1 orm_mode).
        orm_mode = True
class PartCreateSchema(PartBaseSchema):
    """Payload for creating a Part; unlike the base schema, ``name`` is required."""

    name: str
class PartUpdateSchema(PartBaseSchema):
    """Payload for updating a Part; inherits the all-optional base fields unchanged."""
class PartInDBSchema(IDSchemaMixin, PartBaseSchema):
    """Part schema combining ``IDSchemaMixin`` with the base fields; ``name`` is required.

    NOTE(review): presumably models the stored row, with the id field(s)
    supplied by IDSchemaMixin; confirm.
    """

    name: str
class PartPublicSchema(IDSchemaMixin, PartBaseSchema):
    """Part schema with a required string ``id`` and an optional ``href`` URL.

    NOTE(review): presumably the shape returned to API clients; confirm.
    """

    id: str
    href: Optional[HttpUrl] = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_new_part_create_schema(capsys):
    """Building the factory as a dict produces output that includes a 'name' key."""
    built = factory.build(dict, FACTORY_CLASS=PartCreateSchemaFactory)
    print(built)
    out = capsys.readouterr().out
    assert built is not None
    assert "'name':" in out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from loguru import logger # noqa: | |
from sqlalchemy.ext.asyncio import AsyncSession # noqa: | |
from app.models.part import PartModel | |
from app.services import part_service | |
class TestPartService:
    """Tests for ``app.services.part_service`` against the migrated test DB."""

    @pytest.mark.asyncio
    async def test_add_single_part(
        self,
        part_create_schema_factory,
        db_session,
    ):
        # The fixture registered by pytest_factoryboy is the factory class;
        # call it to build an actual PartCreateSchema instance.
        obj_in = part_create_schema_factory()
        item: PartModel = await part_service.create_part(db_session, obj_in)
        assert item.name == obj_in.name
        assert item.id is not None
        # NOTE(review): 26 presumably matches the nanoid size used by
        # create_part; keep the two in sync.
        assert len(item.id) == 26

    @pytest.mark.asyncio
    async def test_part_count_with_empty_db(self, db_session):
        # db_session depends on apply_migrations, which migrates a fresh
        # schema for every test, so no parts exist yet.
        assert await part_service.get_part_count(db_session) == 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment