Created
May 8, 2023 20:56
-
-
Save justinturpin/52c10eda41a6076567d0e67367429e27 to your computer and use it in GitHub Desktop.
Async Sqlite Wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
An asyncio wrapper around sqlite3.
TODO: allow multiple simultaneous readers
TODO: writers should block readers
"""
import asyncio | |
import sqlite3 | |
import threading | |
from queue import Queue | |
from typing import AsyncGenerator | |
from dataclasses import dataclass, field | |
@dataclass
class AsyncQuery:
    """
    One SQL statement queued for a background database thread.

    The submitting coroutine fills in the statement and its parameters; the
    background thread answers through ``result``, a queue created fresh for
    every instance.
    """

    # SQL text to execute.
    query: str
    # Positional values bound to the statement's placeholders.
    args: tuple
    # Per-request reply channel; default_factory gives each instance its own
    # (unbounded) asyncio.Queue instead of one shared object.
    result: asyncio.Queue = field(default_factory=asyncio.Queue)
class AsyncSqlite:
    """
    An Sqlite wrapper for asyncio.

    Two daemon threads are started per instance, each opening its own
    ``sqlite3`` connection to ``db_path``: one serves :meth:`query`, the
    other serves :meth:`execute`. Each thread handles its requests one at a
    time, in submission order, and hands results back to the event loop that
    was running when the instance was created; construct and use the
    instance from inside that loop.

    Exceptions raised while running a statement are re-raised in the
    awaiting coroutine instead of terminating the background thread.
    """

    # Pushed after the last row of a query. The connections use the default
    # row type (tuples), so this private object can never be a real row.
    _END_OF_ROWS = object()

    def __init__(self, db_path: str):
        # Resolve the loop before starting any thread: without a running loop
        # this raises RuntimeError and no worker thread is left behind.
        self._loop = asyncio.get_running_loop()
        self._db_path = db_path
        self._query_queue: Queue = Queue()
        self._write_queue: Queue = Queue()
        # Launch background threads for reading from and writing to the database
        threading.Thread(target=self._background_read, daemon=True).start()
        threading.Thread(target=self._background_write, daemon=True).start()

    def _deliver(self, result_queue: asyncio.Queue, item: object) -> None:
        """
        Put *item* on *result_queue* from a background thread.

        asyncio queues are not thread-safe, so the put is scheduled onto the
        event loop. The result queues are unbounded, so ``put_nowait`` cannot
        raise ``QueueFull``. Callbacks scheduled this way run in registration
        order, so items arrive in the order they were delivered.
        """
        self._loop.call_soon_threadsafe(result_queue.put_nowait, item)

    def _background_read(self) -> None:
        """
        Background thread for reading from the database.

        For each queued query, delivers every result row followed by the
        end-of-rows marker, or, if the statement fails, the exception as the
        final item.
        """
        conn = sqlite3.connect(self._db_path)
        cursor = conn.cursor()
        while True:
            async_query: AsyncQuery = self._query_queue.get()
            result_queue = async_query.result
            try:
                cursor.execute(async_query.query, async_query.args)
                for row in cursor:
                    self._deliver(result_queue, row)
            except Exception as exc:
                # Forward the failure to the waiting coroutine; letting it
                # escape would kill this thread and hang every later query.
                self._deliver(result_queue, exc)
            else:
                self._deliver(result_queue, self._END_OF_ROWS)

    def _background_write(self) -> None:
        """
        Background thread for writing to the database.

        Each statement is executed and committed on its own, and only then is
        the last inserted row id delivered, so the write is committed before
        the awaiting caller resumes. On failure the transaction is rolled
        back and the exception is delivered instead.
        """
        conn = sqlite3.connect(self._db_path)
        cursor = conn.cursor()
        while True:
            async_query: AsyncQuery = self._write_queue.get()
            result_queue = async_query.result
            try:
                cursor.execute(async_query.query, async_query.args)
                row_id = cursor.lastrowid
                conn.commit()
            except Exception as exc:
                # End any transaction the failed statement left open so it
                # does not linger until the next successful write.
                conn.rollback()
                self._deliver(result_queue, exc)
            else:
                self._deliver(result_queue, row_id)

    async def query(self, query: str, args: tuple = ()) -> AsyncGenerator[tuple, None]:
        """
        Execute a read-only query and return an async iterator of the rows.

        Rows are yielded as the reader thread produces them. If SQLite raises
        while executing or stepping the statement, that exception is raised
        here, after any rows already produced.
        """
        async_query = AsyncQuery(query, args)
        self._query_queue.put(async_query)
        while True:
            item = await async_query.result.get()
            if item is self._END_OF_ROWS:
                return
            if isinstance(item, Exception):
                raise item
            yield item

    async def execute(self, query: str, args: tuple = ()) -> int:
        """
        Execute a write query and return the last inserted row id.

        The statement has been committed by the time this returns. Exceptions
        raised by SQLite are re-raised here. The value is the writer cursor's
        ``lastrowid``, which is ``None`` if that connection has never
        inserted a row.
        """
        async_query = AsyncQuery(query, args)
        self._write_queue.put(async_query)
        outcome = await async_query.result.get()
        if isinstance(outcome, Exception):
            raise outcome
        return outcome
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment