Skip to content

Instantly share code, notes, and snippets.

@rafaelhenrique
Created April 17, 2021 02:06
Show Gist options
Save rafaelhenrique/0ca2c49f274ba0291cd2e43ff9b2241a to your computer and use it in GitHub Desktop.
SQLAlchemy + asyncpg + pytest = <3
# reference: https://github.com/sqlalchemy/sqlalchemy/issues/5626
import pytest
import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
# Declarative base shared by every ORM model used in these tests.
Base = orm.declarative_base()


class Country(Base):
    """Minimal ORM model mapped to the ``country`` table."""

    __tablename__ = "country"

    # Integer surrogate primary key.
    id = sa.Column(sa.Integer(), primary_key=True)
    # Free-form country name; no NOT NULL constraint is declared.
    name = sa.Column(sa.Text())
@pytest.fixture(scope="session")
def engine():
    """Provide a single AsyncEngine shared by the whole test session.

    The database URL defaults to a local PostgreSQL ``test`` database and can
    be overridden with the ``TEST_DATABASE_URL`` environment variable.

    The URL must name an async driver (``postgresql+asyncpg://``). A bare
    ``postgresql://`` URL resolves to the default synchronous psycopg2
    dialect, which ``create_async_engine`` refuses to use.

    Yields:
        The AsyncEngine; its connection pool is disposed after the session.
    """
    import os

    url = os.environ.get(
        "TEST_DATABASE_URL",
        "postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test",
    )
    async_engine = create_async_engine(url)
    yield async_engine
    # AsyncEngine.dispose() is a coroutine and this fixture is synchronous,
    # so the pool of the wrapped sync engine is disposed directly.
    # NOTE(review): confirm pooled asyncpg connections close cleanly when
    # disposed outside the event loop (an `await engine.dispose()` in an async
    # session-scoped fixture may be preferable).
    async_engine.sync_engine.dispose()
@pytest.fixture()
async def create(engine):
    """Create all model tables before a test and drop them afterwards.

    Each DDL phase runs in its own ``engine.begin()`` block, which commits
    when the block exits.

    NOTE(review): under pytest-asyncio "strict" mode, async fixtures must be
    declared with ``@pytest_asyncio.fixture`` — confirm the plugin mode.
    """
    metadata = Base.metadata
    async with engine.begin() as connection:
        # MetaData.create_all is synchronous; run_sync runs it on the async connection.
        await connection.run_sync(metadata.create_all)
    yield
    async with engine.begin() as connection:
        await connection.run_sync(metadata.drop_all)
@pytest.fixture
async def session(engine, create):
    """Yield an AsyncSession bound to the test engine.

    ``create`` is requested only for ordering: the tables exist before the
    session is handed out, and they are dropped only after this fixture's
    teardown (closing the session) has run.
    """
    async with AsyncSession(bind=engine) as db_session:
        yield db_session
@pytest.mark.asyncio
async def test_one(session):
    """Persist one Country and check that exactly one row is read back."""
    session.add(Country(name="foo"))
    await session.commit()

    result = await session.execute(sa.select(Country))
    countries = result.scalars().all()
    assert len(countries) == 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment