Skip to content

Instantly share code, notes, and snippets.

@justinturpin
Created May 8, 2023 20:56
Show Gist options
  • Save justinturpin/52c10eda41a6076567d0e67367429e27 to your computer and use it in GitHub Desktop.
Save justinturpin/52c10eda41a6076567d0e67367429e27 to your computer and use it in GitHub Desktop.
Async Sqlite Wrapper
"""
An Async Sqlite wrapper.
TODO: allow multiple simultaneous readers
TODO: writers should block readers
"""
import asyncio
import sqlite3
import threading
from queue import Queue
from typing import AsyncGenerator
from dataclasses import dataclass, field
@dataclass
class AsyncQuery:
    """
    One SQL request handed from a coroutine to a background worker thread.

    The worker reports back through ``result``, a queue that is created
    fresh for every instance and consumed only by the coroutine that
    issued the request.
    """

    # SQL text, passed unchanged to ``cursor.execute``.
    query: str
    # Parameters bound to the placeholders in ``query``.
    args: tuple
    # The factory runs on every construction, so requests never share a queue.
    result: asyncio.Queue = field(default_factory=lambda: asyncio.Queue())
class AsyncSqlite:
    """
    An Sqlite wrapper for asyncio.

    Blocking ``sqlite3`` calls run on two daemon threads: one for reads
    (``query``) and one for writes (``execute``). Each thread owns its own
    connection. Every request carries its own ``asyncio.Queue``, and the
    worker pushes results onto it through the event loop captured at
    construction time.

    Must be constructed from inside a running event loop. Because the two
    threads open separate connections, ``":memory:"`` would give the reader
    and the writer two unrelated databases; use an on-disk path.
    """

    def __init__(self, db_path: str):
        self._db_path = db_path
        # Capture the loop *before* starting threads: get_running_loop() raises
        # RuntimeError outside a running loop, and raising after the threads
        # were started would leak two threads blocked forever on their queues.
        self._loop = asyncio.get_running_loop()
        # Items are AsyncQuery requests, or None as a shutdown sentinel.
        self._query_queue: Queue = Queue()
        self._write_queue: Queue = Queue()
        self._closed = False
        # Launch background threads for reading from and writing to the database
        threading.Thread(target=self._background_read, daemon=True).start()
        threading.Thread(target=self._background_write, daemon=True).start()

    def _send(self, result_queue: asyncio.Queue, item: object) -> None:
        """Deliver *item* to *result_queue* on the event loop thread."""
        # asyncio.Queue is not thread-safe, so the put must run on the loop.
        # Callbacks are run in the order they are scheduled, which keeps rows
        # in result-set order. The per-request queue is unbounded, so
        # put_nowait cannot raise QueueFull.
        self._loop.call_soon_threadsafe(result_queue.put_nowait, item)

    def _check_open(self) -> None:
        """Raise RuntimeError if close() has already been called."""
        if self._closed:
            raise RuntimeError("AsyncSqlite is closed")

    def _background_read(self) -> None:
        """
        Background thread for reading from the database.

        For each request, stream every row to the request's queue, then send
        ``None`` as an end marker. If SQLite raises, send the exception
        instead, so the awaiting coroutine fails rather than hanging.
        """
        conn = sqlite3.connect(self._db_path)
        try:
            cursor = conn.cursor()
            while (async_query := self._query_queue.get()) is not None:
                result_queue = async_query.result
                try:
                    cursor.execute(async_query.query, async_query.args)
                    for row in cursor:
                        self._send(result_queue, row)
                except Exception as exc:  # forwarded to, and re-raised by, query()
                    self._send(result_queue, exc)
                else:
                    # Signal that no more rows will be produced for this query
                    self._send(result_queue, None)
        finally:
            conn.close()

    def _background_write(self) -> None:
        """
        Background thread for writing to the database.

        Each statement is committed *before* the caller is answered, so a
        following ``query`` on the reader connection can see it. On failure
        the exception is sent to the caller and the open transaction (which
        holds only the failed statement) is rolled back.
        """
        conn = sqlite3.connect(self._db_path)
        try:
            # A single reused cursor, as before, so lastrowid semantics match.
            cursor = conn.cursor()
            while (async_query := self._write_queue.get()) is not None:
                try:
                    cursor.execute(async_query.query, async_query.args)
                    conn.commit()
                except Exception as exc:  # forwarded to, and re-raised by, execute()
                    self._send(async_query.result, exc)
                    if conn.in_transaction:
                        conn.rollback()
                else:
                    # Put the last inserted row id into the result queue
                    self._send(async_query.result, cursor.lastrowid)
        finally:
            conn.close()

    def close(self) -> None:
        """
        Ask both background threads to finish and close their connections.

        Requests queued before this call are still processed. Using ``query``
        or ``execute`` afterwards raises ``RuntimeError``. Calling ``close``
        more than once is harmless.
        """
        if self._closed:
            return
        self._closed = True
        self._query_queue.put(None)
        self._write_queue.put(None)

    async def query(self, query: str, args: tuple = ()) -> AsyncGenerator[tuple, None]:
        """
        Execute a read-only query and return an async iterator of the rows.

        Read-only is not enforced; the statement simply runs on the reader
        connection. Rows are yielded as they arrive from the reader thread.
        If SQLite rejects the statement or fails mid-iteration, the original
        ``sqlite3`` exception is re-raised here. The reader still produces
        the whole result set even if the caller stops iterating early.
        """
        self._check_open()
        async_query = AsyncQuery(query, args)
        self._query_queue.put(async_query)
        # None marks the end of the result set; rows themselves are tuples.
        while (row := await async_query.result.get()) is not None:
            if isinstance(row, BaseException):
                raise row
            yield row

    async def execute(self, query: str, args: tuple = ()) -> int:
        """
        Execute a write query and return the last inserted row id.

        The statement is committed before this returns. The returned value is
        the writer cursor's ``lastrowid``, which is only meaningful after an
        INSERT or REPLACE. SQLite errors (including a failed commit) are
        re-raised here, and the failed statement is rolled back.
        """
        self._check_open()
        async_query = AsyncQuery(query, args)
        self._write_queue.put(async_query)
        result = await async_query.result.get()
        if isinstance(result, BaseException):
            raise result
        return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment