Skip to content

Instantly share code, notes, and snippets.

@Anton-2
Created February 26, 2025 15:42
SurrealDB async POC
"""
A basic async connection to a SurrealDB instance.
"""
import asyncio
import uuid
from asyncio import Queue, Task, Future
from typing import Optional, Any, Dict, Union, List, AsyncGenerator
from uuid import UUID
import websockets
from surrealdb.connections.async_template import AsyncTemplate
from surrealdb.connections.url import Url
from surrealdb.connections.utils_mixin import UtilsMixin
from surrealdb.data.cbor import decode
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.table import Table
from surrealdb.request_message.message import RequestMessage
from surrealdb.request_message.methods import RequestMethod
class AsyncWsSurrealConnection(AsyncTemplate, UtilsMixin):
    """
    A single async WebSocket connection to a SurrealDB instance.

    Every request is tagged with a fresh id. One background task
    (``waiter_task``) is the only reader of the socket: it hands each response
    to the coroutine awaiting that id and forwards live-query notifications to
    their subscribers, so several requests can be in flight at the same time.

    Attributes:
        url: The parsed URL of the database.
        raw_url: The RPC endpoint derived from ``url`` (``<url>/rpc``).
        host: The hostname taken from ``url``.
        port: The port taken from ``url``.
        id: The ID of the connection.
        token: The result of the last successful ``signin``; reset to None by
            ``invalidate``.
        socket: The websocket, or None while not connected.
        loop: The event loop the receiver task was started on, or None.
        qry: Futures of the in-flight requests, keyed by request id.
        live_queues: Queues of the active live subscriptions, keyed by the
            string form of the live query id.
        waiter_task: The background task reading from ``socket``, or None.
    """

    def __init__(
        self,
        url: str,
    ) -> None:
        """
        Store the connection settings; no network I/O happens here.

        :param url: The URL of the database to process queries for.
        """
        self.url: Url = Url(url)
        self.raw_url: str = f"{self.url.raw_url}/rpc"
        self.host: Optional[str] = self.url.hostname
        self.port: Optional[int] = self.url.port
        self.id: str = str(uuid.uuid4())
        self.token: Optional[str] = None
        self.socket = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.qry: Dict[str, Future] = {}
        self.live_queues: Dict[str, List[Queue]] = {}
        self.waiter_task: Optional[Task] = None

    def _dispatch(self, response: Any) -> None:
        """
        Route one decoded message to whoever is waiting for it.

        Messages with a top-level ``id`` are request responses. Messages
        without one are treated as live notifications shaped like
        ``{"result": {"id": <live query id>, "result": <payload>}}``.
        Anything else is ignored.
        """
        if not isinstance(response, dict):
            return
        request_id = response.get("id")
        if request_id is not None:
            fut = self.qry.get(request_id)
            # The waiter may have been cancelled in the meantime; setting a
            # result on a finished future would raise InvalidStateError.
            if fut is not None and not fut.done():
                fut.set_result(response)
            return
        notification = response.get("result")
        if not isinstance(notification, dict) or "result" not in notification:
            return
        # Keys are normalised with str() so UUID objects and their textual
        # form address the same subscription.
        subscribers = self.live_queues.get(str(notification.get("id")), ())
        for queue in subscribers:
            queue.put_nowait(notification["result"])

    def _fail_pending(self, error: BaseException) -> None:
        """
        Wake every pending request and live subscriber with ``error``.

        Called when the receiver stops, so that nobody waits forever on a
        connection that will never deliver another message.
        """
        for fut in list(self.qry.values()):
            if not fut.done():
                fut.set_exception(error)
        for subscribers in list(self.live_queues.values()):
            for queue in subscribers:
                queue.put_nowait(error)

    async def _recv_task(self):
        """
        Read messages from the socket until it closes, fails or the task is
        cancelled, dispatching each one; then fail everything still pending.
        """
        assert self.socket is not None  # connect() opens it before starting us
        failure: BaseException = ConnectionError("websocket connection closed")
        try:
            async for data in self.socket:
                self._dispatch(decode(data))
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            # Not re-raised: the error is delivered to the waiters below.
            failure = exc
        finally:
            self._fail_pending(failure)

    async def _send(
        self, message: RequestMessage, process: str, bypass: bool = False
    ) -> dict:
        """
        Send ``message`` and wait for the response carrying the same id.

        :param message: The request to send; its ``id`` is overwritten with a
            fresh request id.
        :param process: Label passed to ``check_response_for_error``.
        :param bypass: When True the response is returned without the error
            check.
        :return: The decoded response.
        :raises Exception: whatever stopped the receiver (or ConnectionError)
            if the connection goes away before the response arrives.
        """
        await self.connect()
        assert (
            self.socket is not None
        )  # will always not be None as the self.connect ensures there's a connection
        query_id = str(uuid.uuid4())
        # Register the future before sending so the response cannot be missed.
        fut = asyncio.get_running_loop().create_future()
        self.qry[query_id] = fut
        try:
            message.id = query_id
            await self.socket.send(message.WS_CBOR_DESCRIPTOR)
            response = await fut
        finally:
            self.qry.pop(query_id, None)
        if bypass is False:
            self.check_response_for_error(response, process)
        return response

    async def connect(self, url: Optional[str] = None) -> None:
        """
        Open the websocket and start the receiver task if not already done.

        :param url: Optional replacement URL. It only takes effect for a
            socket opened afterwards; an already open socket is kept.
        """
        # overwrite params if passed in
        if url is not None:
            self.url = Url(url)
            self.raw_url = f"{self.url.raw_url}/rpc"
            self.host = self.url.hostname
            self.port = self.url.port
        if self.socket is None:
            socket = await websockets.connect(
                self.raw_url,
                max_size=None,
                subprotocols=[websockets.Subprotocol("cbor")],
            )
            if self.socket is None:
                self.socket = socket
            else:
                # Another coroutine connected while we were waiting.
                await socket.close()
        if self.waiter_task is None:
            self.loop = asyncio.get_running_loop()
            self.waiter_task = asyncio.create_task(self._recv_task())

    async def authenticate(self, token: str) -> dict:
        """Send an AUTHENTICATE request with ``token`` and return the response."""
        message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token)
        return await self._send(message, "authenticating")

    async def invalidate(self) -> None:
        """Send an INVALIDATE request and forget the stored token."""
        message = RequestMessage(self.id, RequestMethod.INVALIDATE)
        await self._send(message, "invalidating")
        self.token = None

    async def signup(self, vars: Dict) -> str:
        """Send a SIGN_UP request with ``vars`` and return its result."""
        message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars)
        response = await self._send(message, "signup")
        self.check_response_for_result(response, "signup")
        return response["result"]

    async def signin(self, vars: Dict[str, Any]) -> str:
        """
        Send a SIGN_IN request and remember the returned result as the token.

        :param vars: May contain ``username``, ``password``, ``access``,
            ``database``, ``namespace`` and ``variables``; missing keys are
            sent as None.
        :return: The result of the request, also stored in ``self.token``.
        """
        message = RequestMessage(
            self.id,
            RequestMethod.SIGN_IN,
            username=vars.get("username"),
            password=vars.get("password"),
            access=vars.get("access"),
            database=vars.get("database"),
            namespace=vars.get("namespace"),
            variables=vars.get("variables"),
        )
        response = await self._send(message, "signing in")
        self.check_response_for_result(response, "signing in")
        self.token = response["result"]
        return response["result"]

    async def info(self) -> Optional[dict]:
        """Send an INFO request and return its result."""
        message = RequestMessage(self.id, RequestMethod.INFO)
        outcome = await self._send(message, "getting database information")
        self.check_response_for_result(outcome, "getting database information")
        return outcome["result"]

    async def use(self, namespace: str, database: str) -> None:
        """Send a USE request for ``namespace`` and ``database``."""
        message = RequestMessage(
            self.id,
            RequestMethod.USE,
            namespace=namespace,
            database=database,
        )
        await self._send(message, "use")

    async def query(self, query: str, params: Optional[dict] = None) -> dict:
        """
        Run ``query`` and return the result of its first statement only.

        :param query: The query text.
        :param params: Query parameters; None is sent as an empty dict.
        """
        if params is None:
            params = {}
        message = RequestMessage(
            self.id,
            RequestMethod.QUERY,
            query=query,
            params=params,
        )
        response = await self._send(message, "query")
        self.check_response_for_result(response, "query")
        return response["result"][0]["result"]

    async def query_raw(self, query: str, params: Optional[dict] = None) -> dict:
        """
        Run ``query`` and return the whole response without any error check.

        :param query: The query text.
        :param params: Query parameters; None is sent as an empty dict.
        """
        if params is None:
            params = {}
        message = RequestMessage(
            self.id,
            RequestMethod.QUERY,
            query=query,
            params=params,
        )
        response = await self._send(message, "query", bypass=True)
        return response

    async def version(self) -> str:
        """Send a VERSION request and return its result."""
        message = RequestMessage(self.id, RequestMethod.VERSION)
        response = await self._send(message, "getting database version")
        self.check_response_for_result(response, "getting database version")
        return response["result"]

    async def let(self, key: str, value: Any) -> None:
        """Send a LET request binding ``key`` to ``value``."""
        message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value)
        await self._send(message, "letting")

    async def unset(self, key: str) -> None:
        """Send an UNSET request for ``key``."""
        message = RequestMessage(self.id, RequestMethod.UNSET, params=[key])
        await self._send(message, "unsetting")

    async def select(
        self, thing: Union[str, RecordID, Table]
    ) -> Union[List[dict], dict]:
        """Send a SELECT request for ``thing`` and return its result."""
        message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing])
        response = await self._send(message, "select")
        self.check_response_for_result(response, "select")
        return response["result"]

    async def create(
        self,
        thing: Union[str, RecordID, Table],
        data: Optional[Union[List[dict], dict]] = None,
    ) -> Union[List[dict], dict]:
        """
        Send a CREATE request and return its result.

        :param thing: Target of the request. A string containing ``:`` is
            converted to a ``RecordID``.
        :param data: Optional content for the request.
        """
        if isinstance(thing, str) and ":" in thing:
            # Split on the first colon only, so the rest of the string is kept
            # as the identifier instead of being silently truncated.
            table_name, identifier = thing.split(":", 1)
            thing = RecordID(table_name=table_name, identifier=identifier)
        message = RequestMessage(
            self.id, RequestMethod.CREATE, collection=thing, data=data
        )
        response = await self._send(message, "create")
        self.check_response_for_result(response, "create")
        return response["result"]

    async def update(
        self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None
    ) -> Union[List[dict], dict]:
        """Send an UPDATE request for ``thing`` with ``data``; return its result."""
        message = RequestMessage(
            self.id, RequestMethod.UPDATE, record_id=thing, data=data
        )
        response = await self._send(message, "update")
        self.check_response_for_result(response, "update")
        return response["result"]

    async def merge(
        self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None
    ) -> Union[List[dict], dict]:
        """Send a MERGE request for ``thing`` with ``data``; return its result."""
        message = RequestMessage(
            self.id, RequestMethod.MERGE, record_id=thing, data=data
        )
        response = await self._send(message, "merge")
        self.check_response_for_result(response, "merge")
        return response["result"]

    async def patch(
        self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None
    ) -> Union[List[dict], dict]:
        """Send a PATCH request for ``thing`` with ``data``; return its result."""
        message = RequestMessage(
            self.id, RequestMethod.PATCH, collection=thing, params=data
        )
        response = await self._send(message, "patch")
        self.check_response_for_result(response, "patch")
        return response["result"]

    async def delete(
        self, thing: Union[str, RecordID, Table]
    ) -> Union[List[dict], dict]:
        """Send a DELETE request for ``thing`` and return its result."""
        message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing)
        response = await self._send(message, "delete")
        self.check_response_for_result(response, "delete")
        return response["result"]

    async def insert(
        self, table: Union[str, Table], data: Union[List[dict], dict]
    ) -> Union[List[dict], dict]:
        """Send an INSERT request for ``table`` with ``data``; return its result."""
        message = RequestMessage(
            self.id, RequestMethod.INSERT, collection=table, params=data
        )
        response = await self._send(message, "insert")
        self.check_response_for_result(response, "insert")
        return response["result"]

    async def insert_relation(
        self, table: Union[str, Table], data: Union[List[dict], dict]
    ) -> Union[List[dict], dict]:
        """Send an INSERT_RELATION request for ``table``; return its result."""
        message = RequestMessage(
            self.id, RequestMethod.INSERT_RELATION, table=table, params=data
        )
        response = await self._send(message, "insert_relation")
        self.check_response_for_result(response, "insert_relation")
        return response["result"]

    async def live(self, table: Union[str, Table], diff: bool = False) -> UUID:
        """
        Send a LIVE request for ``table`` and return its result (the id to
        pass to ``subscribe_live`` and ``kill``).

        :param table: The table to watch.
        :param diff: Accepted for interface compatibility; it is currently
            not forwarded in the request.
        """
        message = RequestMessage(
            self.id,
            RequestMethod.LIVE,
            table=table,
        )
        response = await self._send(message, "live")
        self.check_response_for_result(response, "live")
        return response["result"]

    async def subscribe_live(
        self, query_uuid: Union[str, UUID]
    ) -> AsyncGenerator[dict, None]:
        """
        Yield the payload of every live notification for ``query_uuid``.

        Notifications are delivered by the shared receiver task; this method
        never reads the socket itself, so it can run next to ordinary requests
        and other subscriptions.

        :param query_uuid: The live query id, as a UUID or its string form.
        :raises Exception: ``"Error in live subscription: ..."`` when the
            connection is closed or fails.
        """
        await self.connect()
        if self.waiter_task is None or self.waiter_task.done():
            # Nothing would ever feed the queue; fail instead of hanging.
            raise Exception("Error in live subscription: connection is closed")
        key = str(query_uuid)
        result_queue: Queue = Queue()
        self.live_queues.setdefault(key, []).append(result_queue)
        try:
            while True:
                result = await result_queue.get()
                # Failures travel through the queue as exception objects, so a
                # record that merely has an "error" field is not mistaken for
                # a failure.
                if isinstance(result, BaseException):
                    raise Exception(
                        f"Error in live subscription: {result}"
                    ) from result
                yield result
        finally:
            subscribers = self.live_queues.get(key, [])
            if result_queue in subscribers:
                subscribers.remove(result_queue)
            if not subscribers:
                self.live_queues.pop(key, None)

    async def kill(self, query_uuid: Union[str, UUID]) -> None:
        """Send a KILL request for the live query ``query_uuid``."""
        message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid)
        await self._send(message, "kill")

    async def upsert(
        self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None
    ) -> Union[List[dict], dict]:
        """Send an UPSERT request for ``thing`` with ``data``; return its result."""
        message = RequestMessage(
            self.id, RequestMethod.UPSERT, record_id=thing, data=data
        )
        response = await self._send(message, "upsert")
        self.check_response_for_result(response, "upsert")
        return response["result"]

    async def close(self):
        """
        Stop the receiver task and close the websocket.

        Pending requests and live subscriptions are failed with a
        ConnectionError. Safe to call when not connected or more than once.
        Because ``socket`` is reset to None, a later call on this object opens
        a fresh connection that carries none of the previous session state.
        """
        task, self.waiter_task = self.waiter_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        socket, self.socket = self.socket, None
        if socket is not None:
            await socket.close()

    async def __aenter__(self) -> "AsyncWsSurrealConnection":
        """
        Asynchronous context manager entry.
        Opens the websocket connection (reusing one opened by ``connect``),
        starts the receiver task and returns the connection instance.
        """
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """
        Asynchronous context manager exit.
        Stops the receiver task and closes the websocket connection.
        """
        await self.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment