Skip to content

Instantly share code, notes, and snippets.

@catwell
Last active May 28, 2024 17:57
Show Gist options
  • Save catwell/e47e70b47550ba2fb07d04a41bb8baf0 to your computer and use it in GitHub Desktop.
Save catwell/e47e70b47550ba2fb07d04a41bb8baf0 to your computer and use it in GitHub Desktop.
Avoiding N+1 queries in Strawberry GraphQL with DataLoaders
# DataLoaders version
import logging
from collections import defaultdict
from functools import cached_property
from typing import Any
import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.dataloader import DataLoader
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from strawberry.types import Info
from quart import Quart, Request, Response
app = Quart("sfdl")
app.logger.setLevel(logging.INFO)
# Database settings read by QuartDB below. Every resolver opens its own
# connection with ``db.connection()``; the AUTO_REQUEST_CONNECTION flag
# presumably turns off quart_db's implicit per-request connection — confirm.
app.config.update(
    QUART_DB_DATABASE_URL="postgresql://postgres@localhost/cwl_sfdl",
    QUART_DB_AUTO_REQUEST_CONNECTION=False,
)
db = QuartDB(app)
# Allow cross-origin GET/POST from anywhere (e.g. a separately hosted IDE).
cors(app, allow_methods=["GET", "POST"], allow_origin="*")
@strawberry.type
class Song:
    """A row of the ``songs`` table exposed as a GraphQL object type."""

    id: int
    name: str
    album_id: int  # id of the album this song belongs to (songs.album_id)
@strawberry.type
class Album:
    """An album; its songs are resolved through the per-request DataLoader."""

    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self, info: Info) -> list[Song]:
        """Return this album's songs.

        The lookup goes through the ``songs_for_albums`` loader stored in the
        request context, so sibling albums share one batched query.
        """
        loaders = info.context["dataloaders"]
        return await loaders.songs_for_albums.load(self.id)
@strawberry.type
class Band:
    """A band; its albums are resolved through the per-request DataLoader."""

    id: int
    name: str

    @strawberry.field
    async def albums(self, info: Info) -> list[Album]:
        """Return this band's albums.

        The lookup goes through the ``albums_for_bands`` loader stored in the
        request context, so sibling bands share one batched query.
        """
        loaders = info.context["dataloaders"]
        return await loaders.albums_for_bands.load(self.id)
@strawberry.type
class Query:
    """Root GraphQL query type."""

    @strawberry.field
    async def bands(self) -> list[Band]:
        """Return every band, ordered by id.

        Albums and songs are resolved lazily by the nested field resolvers.
        """
        # Without ORDER BY, PostgreSQL returns rows in an unspecified order,
        # which would make the GraphQL list order nondeterministic.
        query = """
            SELECT id, name
            FROM bands
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info("Got %d bands.", len(bands))
        return bands
class DataLoaders:
    """Bundle of DataLoaders that batch child-row lookups.

    Each batch function fetches the rows for every collected key with a single
    ``= ANY(:keys)`` query, instead of one query per parent object.
    """

    @staticmethod
    def _group_by(items: list[Any], keys: list[int], attr: str) -> list[list[Any]]:
        """Bucket *items* by ``getattr(item, attr)``; return one bucket per key.

        DataLoader batch functions must return exactly one result per requested
        key, in the same order as *keys*. Keys with no matching rows get ``[]``.
        """
        by_key: defaultdict[int, list[Any]] = defaultdict(list)
        for item in items:
            by_key[getattr(item, attr)].append(item)
        return [by_key[key] for key in keys]

    @staticmethod
    async def load_songs_for_albums(keys: list[int]) -> list[list[Song]]:
        """Batch-load songs for the album ids in *keys*.

        Returns one list of songs (ordered by song id) per album id, aligned
        with *keys*.
        """
        # ORDER BY keeps each album's song list deterministic; grouping below
        # preserves this order within every bucket.
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = ANY(:keys)
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        songs = [Song(**row) for row in result]
        app.logger.info("Got %d songs.", len(songs))
        return DataLoaders._group_by(songs, keys, "album_id")

    @staticmethod
    async def load_albums_for_bands(keys: list[int]) -> list[list[Album]]:
        """Batch-load albums for the band ids in *keys*.

        Returns one list of albums (ordered by album id) per band id, aligned
        with *keys*.
        """
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = ANY(:keys)
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"keys": keys})
        albums = [Album(**row) for row in result]
        app.logger.info("Got %d albums.", len(albums))
        return DataLoaders._group_by(albums, keys, "band_id")

    @cached_property
    def songs_for_albums(self) -> DataLoader[int, list[Song]]:
        """Loader mapping an album id to its songs (built on first access)."""
        return DataLoader(self.load_songs_for_albums)

    @cached_property
    def albums_for_bands(self) -> DataLoader[int, list[Album]]:
        """Loader mapping a band id to its albums (built on first access)."""
        return DataLoader(self.load_albums_for_bands)
class GraphQLView(QuartGraphQLView):
    """Quart GraphQL view whose resolver context carries the DataLoaders."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Build the context dict passed to resolvers as ``info.context``."""
        context: dict[str, Any] = {"request": request, "response": response}
        # A new DataLoaders object on every call, so loader caches are not
        # shared across calls (Strawberry calls get_context per request).
        context["dataloaders"] = DataLoaders()
        return context
# Serve the GraphQL schema at "/" with the GraphiQL IDE enabled.
view = GraphQLView.as_view(
    "graphql_view",
    graphql_ide="graphiql",
    schema=strawberry.Schema(query=Query),
)
app.add_url_rule("/", view_func=view)
from quart_db import Connection
async def create_schema(cnx: Connection) -> None:
    """Create the ``bands``, ``albums`` and ``songs`` tables.

    The foreign-key columns are NOT NULL because the GraphQL types expose
    ``Album.band_id`` and ``Song.album_id`` as non-null ``int`` fields. A NULL
    there would fail serialization.

    They are also indexed because every album/song lookup filters on them.
    PostgreSQL does not create indexes on referencing columns automatically.
    """
    await cnx.execute(
        """
        CREATE TABLE bands (
            id bigint PRIMARY KEY,
            name text NOT NULL
        );
        CREATE TABLE albums (
            id bigint PRIMARY KEY,
            name text NOT NULL,
            band_id bigint NOT NULL REFERENCES bands(id)
        );
        CREATE INDEX albums_band_id_idx ON albums (band_id);
        CREATE TABLE songs (
            id bigint GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
            name text NOT NULL,
            album_id bigint NOT NULL REFERENCES albums(id)
        );
        CREATE INDEX songs_album_id_idx ON songs (album_id);
        """
    )
async def populate(cnx: Connection) -> None:
    """Seed the tables with 3 bands, 6 albums and 15 songs in one statement batch.

    Band and album ids are given explicitly so the rows that follow can
    reference them. Song ids are omitted from the INSERT and left to the
    table's identity column.
    """
    seed_sql = """
        INSERT INTO bands (id, name) VALUES
        (1, 'Dark Tranquillity'),
        (2, 'Pineapple Thief'),
        (3, 'Wintersun');
        INSERT INTO albums (id, name, band_id) VALUES
        (1, 'Haven', 1),
        (2, 'Fiction', 1),
        (3, 'Atoma', 1),
        (4, 'Time I', 3),
        (5, 'Your Wilderness', 2),
        (6, 'Versions of the Truth', 2);
        INSERT INTO songs (name, album_id) VALUES
        ('The Wonders at Your Feet', 1),
        ('Not Built to Last', 1),
        ('Indifferent Suns', 1),
        ('At Loss for Words', 1),
        ('Terminus', 2),
        ('Inside the Particle Storm', 2),
        ('Focus Shift', 2),
        ('Forward Momentum', 3),
        ('Caves and Embers', 3),
        ('When Mountains Fall', 4),
        ('Sons of Winter and Stars', 4),
        ('Land of Snow and Sorrow', 4),
        ('Time', 4),
        ('The Final Thing on My Mind', 5),
        ('Tear You Up', 5);
    """
    await cnx.execute(seed_sql)
async def migrate(cnx: Connection) -> None:
    """Create the schema, then load the seed data, on the same connection.

    Presumably the entry point quart_db's migration runner calls — confirm.
    """
    for step in (create_schema, populate):
        await step(cnx)
# Naive version
import logging
from typing import Any
import strawberry
from quart_cors import cors
from quart_db import QuartDB
from strawberry.quart.views import GraphQLView as QuartGraphQLView
from quart import Quart, Request, Response
app = Quart("sfdl")
app.logger.setLevel(logging.INFO)
# Database settings read by QuartDB below. Every resolver opens its own
# connection with ``db.connection()``; the AUTO_REQUEST_CONNECTION flag
# presumably turns off quart_db's implicit per-request connection — confirm.
app.config.update(
    QUART_DB_DATABASE_URL="postgresql://postgres@localhost/cwl_sfdl",
    QUART_DB_AUTO_REQUEST_CONNECTION=False,
)
db = QuartDB(app)
# Allow cross-origin GET/POST from anywhere (e.g. a separately hosted IDE).
cors(app, allow_methods=["GET", "POST"], allow_origin="*")
@strawberry.type
class Song:
    """A row of the ``songs`` table exposed as a GraphQL object type."""

    id: int
    name: str
    album_id: int  # id of the album this song belongs to (songs.album_id)
@strawberry.type
class Album:
    """An album whose songs are fetched with one query per album (N+1 pattern)."""

    id: int
    name: str
    band_id: int

    @strawberry.field
    async def songs(self) -> list[Song]:
        """Return this album's songs, ordered by id.

        Runs a separate SELECT for each album resolved.
        """
        # Without ORDER BY, PostgreSQL returns rows in an unspecified order.
        query = """
            SELECT id, name, album_id
            FROM songs
            WHERE album_id = :album_id
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"album_id": self.id})
        songs = [Song(**row) for row in result]
        app.logger.info("Got %d songs.", len(songs))
        return songs
@strawberry.type
class Band:
    """A band whose albums are fetched with one query per band (N+1 pattern)."""

    id: int
    name: str

    @strawberry.field
    async def albums(self) -> list[Album]:
        """Return this band's albums, ordered by id.

        Runs a separate SELECT for each band resolved.
        """
        # Without ORDER BY, PostgreSQL returns rows in an unspecified order.
        query = """
            SELECT id, name, band_id
            FROM albums
            WHERE band_id = :band_id
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query, {"band_id": self.id})
        albums = [Album(**row) for row in result]
        app.logger.info("Got %d albums.", len(albums))
        return albums
@strawberry.type
class Query:
    """Root GraphQL query type."""

    @strawberry.field
    async def bands(self) -> list[Band]:
        """Return every band, ordered by id.

        Albums and songs are resolved lazily by the nested field resolvers.
        """
        # Without ORDER BY, PostgreSQL returns rows in an unspecified order,
        # which would make the GraphQL list order nondeterministic.
        query = """
            SELECT id, name
            FROM bands
            ORDER BY id
        """
        async with db.connection() as cnx:
            result = await cnx.fetch_all(query)
        bands = [Band(**row) for row in result]
        app.logger.info("Got %d bands.", len(bands))
        return bands
class GraphQLView(QuartGraphQLView):
    """Quart GraphQL view exposing the request and response to resolvers."""

    async def get_context(self, request: Request, response: Response) -> dict[str, Any]:
        """Build the context dict passed to resolvers as ``info.context``."""
        return dict(request=request, response=response)
# Serve the GraphQL schema at "/" with the GraphiQL IDE enabled.
view = GraphQLView.as_view(
    "graphql_view",
    graphql_ide="graphiql",
    schema=strawberry.Schema(query=Query),
)
app.add_url_rule("/", view_func=view)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment