Skip to content

Instantly share code, notes, and snippets.

@e-kondr01
Last active July 8, 2024 22:25
Show Gist options
  • Save e-kondr01/969ae24f2e2f31bd52a81fa5a1fe0f96 to your computer and use it in GitHub Desktop.
Save e-kondr01/969ae24f2e2f31bd52a81fa5a1fe0f96 to your computer and use it in GitHub Desktop.
Pytest + FastAPI + Async SQLAlchemy

Run API tests with Pytest, FastAPI and async SQLAlchemy. Changes made in test functions are not persisted to the database, even if `await session.commit()` is called. This keeps tests independent, so they can run in parallel or in shuffled order without affecting the results.

This snippet does not create the database tables: I manage migrations with Alembic and recommend doing the same, even in tests.

fastapi
pytest
sqlalchemy>=2.0.0
httpx
from typing import AsyncGenerator
from uuid import UUID, uuid4

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncSession,
    AsyncTransaction,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# Importing fastapi.Depends that is used to retrieve SQLAlchemy's session
from app.api.deps import get_async_session

# Importing main FastAPI instance
from app.main import app
# Run every async test in this module through anyio's pytest plugin.
pytestmark = pytest.mark.anyio

# Supply the connection string. create_async_engine needs an asyncio-capable
# driver such as asyncpg; the synchronous psycopg2 driver is rejected by
# SQLAlchemy's asyncio extension.
engine = create_async_engine("postgresql+asyncpg://...")
# SQLAlchemy models for demo purposes.
# In SQLAlchemy 2.0, DeclarativeBase is subclassed once to create the
# declarative base. Mapped classes then inherit from that base. A class that
# subclasses DeclarativeBase directly is treated as the base itself and is
# never mapped.
class Base(DeclarativeBase):
    pass


class Profile(Base):
    """Demo ORM model: a profile with a UUID primary key and a name."""

    # NOTE(review): table name assumed; it must match the table created by the
    # Alembic migrations.
    __tablename__ = "profiles"

    # uuid4 is the Python-side default used for ORM inserts; the
    # gen_random_uuid() server default covers inserts that bypass SQLAlchemy.
    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        default=uuid4,
        server_default=func.gen_random_uuid(),
    )
    name: Mapped[str]
# anyio's pytest plugin reads this fixture to choose the event-loop backend.
# It is session-scoped so that higher-scoped async fixtures (such as the
# session-wide connection) are allowed; see
# https://anyio.readthedocs.io/en/stable/testing.html#using-async-fixtures-with-higher-scopes
@pytest.fixture(scope="session")
def anyio_backend():
    """Select the asyncio backend for every anyio-driven test and fixture."""
    backend_name = "asyncio"
    return backend_name
@pytest.fixture(scope="session")
async def connection(anyio_backend) -> AsyncGenerator[AsyncConnection, None]:
    """Yield a single database connection shared by the whole test session.

    ``anyio_backend`` is requested so this session-scoped async fixture runs
    on the backend that fixture selects. When the session ends, the
    connection is returned to the pool and the engine is disposed.
    """
    try:
        async with engine.connect() as conn:
            yield conn
    finally:
        # An AsyncEngine's pool should be disposed explicitly before the event
        # loop shuts down; otherwise pooled connections are left to garbage
        # collection instead of being closed cleanly.
        await engine.dispose()
@pytest.fixture()
async def transaction(
    connection: AsyncConnection,
) -> AsyncGenerator[AsyncTransaction, None]:
    """Begin an outer transaction on the shared connection for one test.

    Sessions that join ``connection`` with
    ``join_transaction_mode="create_savepoint"`` work inside this
    transaction. On teardown it is always rolled back, never committed, so
    test data is not persisted even if the consuming fixture does not roll
    back itself.
    """
    outer = await connection.begin()
    try:
        yield outer
    finally:
        # `async with connection.begin()` would COMMIT on a clean exit; roll
        # back instead. The guard skips a transaction that a consuming
        # fixture has already rolled back.
        if outer.is_active:
            await outer.rollback()
# Use this fixture to get SQLAlchemy's AsyncSession.
# All changes that occur in a test function are rolled back
# after the function exits, even if session.commit() is called
# in inner functions.
@pytest.fixture()
async def session(
    connection: AsyncConnection, transaction: AsyncTransaction
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession whose commits never reach the database.

    The session is bound to the shared connection with
    ``join_transaction_mode="create_savepoint"``, so ``session.commit()``
    only ends a SAVEPOINT inside the outer ``transaction``. On teardown the
    session is closed and the outer transaction is rolled back.
    """
    async_session = AsyncSession(
        bind=connection,
        join_transaction_mode="create_savepoint",
    )
    try:
        yield async_session
    finally:
        # Close the session first so its own savepoint state is ended before
        # the enclosing transaction is discarded.
        await async_session.close()
        if transaction.is_active:
            await transaction.rollback()
# Tests showing that changes made through SQLAlchemy's session are rolled
# back between test functions
async def test_create_profile(session: AsyncSession):
    """A committed Profile is visible later in the same test."""
    profiles_before = (await session.scalars(select(Profile))).all()
    assert len(profiles_before) == 0

    expected_name = "test"
    session.add(Profile(name=expected_name))
    await session.commit()

    profiles_after = (await session.scalars(select(Profile))).all()
    assert len(profiles_after) == 1
    assert profiles_after[0].name == expected_name
async def test_rollbacks_between_functions(session: AsyncSession):
    """Profiles committed by other tests must not be visible here."""
    result = await session.execute(select(Profile))
    leftover_profiles = result.scalars().all()
    assert len(leftover_profiles) == 0
# Use this fixture to get HTTPX's client to test the API.
# All changes that occur in a test function are rolled back
# after the function exits, even if session.commit() is called
# in FastAPI's application endpoints.
@pytest.fixture()
async def client(
    connection: AsyncConnection, transaction: AsyncTransaction
) -> AsyncGenerator[AsyncClient, None]:
    """Yield an unopened AsyncClient that calls ``app`` in-process.

    Tests open the client themselves (``async with client as ac``). While
    the fixture is active, the session dependency is overridden so that
    endpoints share the test's outer transaction. On teardown the override
    is removed and that transaction is rolled back.
    """

    async def override_get_async_session() -> AsyncGenerator[AsyncSession, None]:
        # Each request gets its own session on the shared connection. With
        # "create_savepoint", an endpoint's commit only ends a SAVEPOINT
        # inside the outer test transaction.
        async with AsyncSession(
            bind=connection,
            join_transaction_mode="create_savepoint",
        ) as async_session:
            yield async_session

    # Here you have to override the dependency that is used in FastAPI's
    # endpoints to get SQLAlchemy's AsyncSession. In my case, it is
    # get_async_session
    app.dependency_overrides[get_async_session] = override_get_async_session
    try:
        # AsyncClient(app=...) was deprecated in httpx 0.27 and removed in
        # 0.28; ASGITransport is the supported way to call an ASGI app.
        yield AsyncClient(transport=ASGITransport(app=app), base_url="http://test")
    finally:
        app.dependency_overrides.pop(get_async_session, None)
        if transaction.is_active:
            await transaction.rollback()
# Tests showing that changes made through the API client are rolled back
# between test functions
async def test_api_create_profile(client: AsyncClient):
    """Create a profile via POST, then read it back via the list and detail GETs."""
    test_name = "test"
    async with client as ac:
        response = await ac.post(
            "/api/profiles",
            json={"name": test_name},
        )
        # Fail here with the server's response body, rather than with a
        # KeyError on the missing "id" field below.
        assert response.is_success, response.text
        created_profile_id = response.json()["id"]

        response = await ac.get(
            "/api/profiles",
        )
        assert response.status_code == 200
        assert len(response.json()) == 1

        response = await ac.get(
            f"/api/profiles/{created_profile_id}",
        )
        assert response.status_code == 200
        assert response.json()["id"] == created_profile_id
        assert response.json()["name"] == test_name
async def test_client_rollbacks(client: AsyncClient):
    """Profiles created through the API by other tests must not be listed."""
    async with client as ac:
        response = await ac.get(
            "/api/profiles",
        )
        # Check the status first: an error body such as {"detail": ...} would
        # otherwise be measured by len() and give a misleading result.
        assert response.status_code == 200
        assert len(response.json()) == 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment