Skip to content

Instantly share code, notes, and snippets.

@Ce11an
Last active March 6, 2023 03:07
Show Gist options
Save Ce11an/6775b001bf3bbe65d1f06c9d6f1768ba to your computer and use it in GitHub Desktop.
SurrealDB WebSocket Client
"""SurrealDB websocket client library."""
import enum
import json
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import pydantic
import websockets
ID = 0


def guid() -> str:
    """Return the next request identifier as a decimal string.

    Identifiers come from the module-level counter ``ID``. It is bumped by
    one on each call and wraps modulo ``2**53 - 1`` (presumably to stay
    within JavaScript's safe-integer range).

    Returns:
        The new identifier, e.g. ``"1"``, ``"2"``, ...
    """
    global ID
    bumped = ID + 1
    ID = bumped % (2**53 - 1)
    return f"{ID}"
class SurrealException(Exception):
    """Root of the exception hierarchy raised by this SurrealDB client."""


class SurrealAuthenticationException(SurrealException):
    """Raised when an authentication call (signin/signup/invalidate/authenticate) fails."""


class SurrealPermissionException(SurrealException):
    """Raised when a data-changing call (create/delete/update/modify/change) fails."""
class WebSocketState(enum.Enum):
    """Lifecycle of the client's WebSocket connection.

    Members:
        CONNECTING: Initial state; no socket has been opened yet.
        CONNECTED: The socket is open and requests may be sent.
        DISCONNECTED: The socket has been closed.
    """

    CONNECTING = 0
    CONNECTED = 1
    DISCONNECTED = 2
class Request(pydantic.BaseModel):
    """An outgoing RPC call to a Surreal server.

    Attributes:
        id: Identifier of this request (callers generate it with ``guid``).
        method: Name of the RPC method to invoke.
        params: Positional arguments for the method. A missing or ``None``
            value is normalised to an empty tuple.
    """

    id: str
    method: str
    params: Optional[Tuple] = None

    @pydantic.validator("params", pre=True, always=True)
    def validate_params(cls, value):  # pylint: disable=no-self-argument
        """Substitute an empty tuple when no parameters were given."""
        return () if value is None else value

    class Config:
        """Model settings: request instances are read-only once built."""

        allow_mutation = False
class ResponseSuccess(pydantic.BaseModel):
    """A non-error reply received from a Surreal server.

    Attributes:
        id: The ID of the request this reply belongs to.
        result: The method-specific payload returned by the server.
    """

    id: str
    result: Any

    class Config:
        """Model settings for successful responses.

        Attributes:
            allow_mutation: Disabled, so responses are read-only.
        """

        allow_mutation = False
class ResponseError(pydantic.BaseModel):
    """The ``error`` object of a failed reply from a Surreal server.

    Attributes:
        code: Numeric error code reported by the server.
        message: Human-readable error description reported by the server.
    """

    code: int
    message: str

    class Config:
        """Model settings for error responses.

        Attributes:
            allow_mutation: Disabled, so error objects are read-only.
        """

        allow_mutation = False
def _validate_response(
    response: Union[ResponseSuccess, ResponseError],
    exception: Type[Exception] = SurrealException,
) -> ResponseSuccess:
    """Return ``response`` unchanged if it is a success, otherwise raise.

    Args:
        response: The parsed server response to check.
        exception: Exception class to raise when ``response`` is an error.
            Defaults to ``SurrealException``. Callers may pass a more
            specific subclass such as ``SurrealAuthenticationException``
            or ``SurrealPermissionException``.

    Returns:
        The same ``ResponseSuccess`` object that was passed in.

    Raises:
        Exception: An instance of ``exception`` (``SurrealException`` by
            default), built from the error's ``message``, when ``response``
            is a ``ResponseError``.
    """
    if isinstance(response, ResponseError):
        raise exception(response.message)
    return response
class Surreal:
    """Asynchronous SurrealDB client that speaks RPC over a single WebSocket.

    Each public method sends one request and then treats the next message
    read from the socket as that request's response.

    Attributes:
        url: The WebSocket URL of the Surreal server
            (e.g. ``ws://127.0.0.1:8000/rpc``).
        token: The authentication token. It is set by ``signin`` and
            cleared by ``invalidate``.
        client_state: The current ``WebSocketState`` of the connection.
        ws: The underlying websocket connection, or ``None`` until
            ``connect`` has been awaited.
    """

    def __init__(self, url: str, token: Optional[str] = None) -> None:
        self.url = url
        self.token = token
        # NOTE: the state starts as CONNECTING although nothing is opened
        # until ``connect`` is awaited. ``_validate_connection`` rejects
        # requests until the state becomes CONNECTED.
        self.client_state = WebSocketState.CONNECTING
        self.ws: Optional[websockets.WebSocketClientProtocol] = None  # type: ignore

    async def __aenter__(self) -> "Surreal":
        """Connect to the Surreal server when entering ``async with``.

        Returns:
            This client, now connected.
        """
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Disconnect from the Surreal server when leaving ``async with``.

        Exceptions raised inside the block are not suppressed, because this
        method returns ``None``.

        Args:
            exc_type: The type of the exception, if one was raised.
            exc_value: The exception instance, if one was raised.
            traceback: The traceback of the exception, if one was raised.
        """
        await self.disconnect()

    async def connect(self) -> None:
        """Open the WebSocket to ``self.url`` and mark the client connected."""
        self.ws = await websockets.connect(self.url)  # type: ignore
        self.client_state = WebSocketState.CONNECTED

    async def disconnect(self) -> None:
        """Close the connection to the Surreal server.

        This is safe to call when ``connect`` was never awaited or failed.
        In that case there is no socket to close, and only the state is
        updated.
        """
        if self.ws is not None:
            await self.ws.close()
        self.client_state = WebSocketState.DISCONNECTED

    async def ping(self) -> bool:
        """Ping the Surreal server.

        Returns:
            The ``result`` field of the server's reply.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(
                id=guid(),
                method="ping",
            ),
        )
        success: ResponseSuccess = _validate_response(response)
        return success.result

    async def use(self, namespace: str, database: str) -> None:
        """Select the namespace and database for subsequent requests.

        Args:
            namespace: The namespace to use.
            database: The database to use.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="use", params=(namespace, database)),
        )
        _validate_response(response)

    async def signin(self, auth: Dict[str, Any]) -> str:
        """Sign in to the Surreal server and store the returned token.

        Args:
            auth: The authentication parameters.

        Returns:
            The token returned by the server (also stored in ``self.token``).

        Raises:
            SurrealAuthenticationException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="signin", params=(auth,)),
        )
        success: ResponseSuccess = _validate_response(
            response, SurrealAuthenticationException
        )
        token: str = success.result
        self.token = token
        return self.token

    async def info(self) -> Optional[Dict[str, Any]]:
        """Send the ``info`` RPC to the Surreal server.

        Returns:
            The ``result`` field of the server's reply.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(
                id=guid(),
                method="info",
            ),
        )
        success: ResponseSuccess = _validate_response(response)
        return success.result

    async def signup(self, auth: Dict[str, Any]) -> None:
        """Sign up to the Surreal server.

        The server's result is not stored, so ``self.token`` is unchanged.

        Args:
            auth: The authentication parameters.

        Raises:
            SurrealAuthenticationException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="signup", params=(auth,)),
        )
        _validate_response(response, SurrealAuthenticationException)

    async def invalidate(self) -> None:
        """Invalidate the current session and clear the stored token.

        Raises:
            SurrealAuthenticationException: If the server returns an error
                (the stored token is then left in place).
        """
        response = await self._send_receive(
            Request(
                id=guid(),
                method="invalidate",
            ),
        )
        _validate_response(response, SurrealAuthenticationException)
        self.token = None

    async def authenticate(self) -> None:
        """Authenticate the connection with the stored ``self.token``.

        Raises:
            SurrealAuthenticationException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="authenticate", params=(self.token,)),
        )
        _validate_response(response, SurrealAuthenticationException)

    async def create(self, thing: str, data: Optional[Dict[str, Any]] = None) -> str:
        """Create a record in the database.

        Args:
            thing: The table or record ID.
            data: The document / record data to insert. When ``None``, only
                ``thing`` is sent.

        Returns:
            The ``result`` field of the server's reply.

        Raises:
            SurrealPermissionException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(
                id=guid(),
                method="create",
                params=(thing,) if data is None else (thing, data),
            ),
        )
        success: ResponseSuccess = _validate_response(
            response, SurrealPermissionException
        )
        return success.result

    async def delete(self, thing: str) -> None:
        """Delete all records in a table, or a specific record, from the database.

        Args:
            thing: The table name or a record ID to delete.

        Raises:
            SurrealPermissionException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="delete", params=(thing,)),
        )
        _validate_response(response, SurrealPermissionException)

    async def update(self, thing: str, data: Dict[str, Any]) -> None:
        """Replace all records in a table, or a specific record, with ``data``.

        This function replaces the current document / record data with the
        specified data.

        Args:
            thing: The table or record ID.
            data: The document / record data to write.

        Raises:
            SurrealPermissionException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="update", params=(thing, data)),
        )
        _validate_response(response, SurrealPermissionException)

    async def kill(self, query: Optional[str] = None) -> None:
        """Send the ``kill`` RPC, which stops a live query.

        Args:
            query: The live query ID, as returned by ``live``. When omitted,
                the request is sent without parameters, matching the
                previous behaviour of this method.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(
                id=guid(),
                method="kill",
                # ``Request`` turns ``None`` into an empty params tuple.
                params=None if query is None else (query,),
            ),
        )
        _validate_response(response)

    async def select(self, thing: str) -> List[Dict[str, Any]]:
        """Select all records in a table, or a specific record, from the database.

        Args:
            thing: The table or record ID to select.

        Returns:
            The records.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="select", params=(thing,)),
        )
        success: ResponseSuccess = _validate_response(response)
        return success.result

    async def modify(self, thing: str, data: Dict[str, Any]) -> None:
        """Apply JSON Patch changes to all records, or a specific record.

        This function patches the current document / record data with the
        specified JSON Patch data.

        Args:
            thing: The table or record ID.
            data: The data to modify the record with.

        Raises:
            SurrealPermissionException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="modify", params=(thing, data)),
        )
        _validate_response(response, SurrealPermissionException)

    async def change(self, thing: str, data: Dict[str, Any]) -> None:
        """Merge ``data`` into all records in a table, or a specific record.

        This function merges the current document / record data with the
        specified data.

        Args:
            thing: The table name or the specific record ID to change.
            data: The document / record data to merge in.

        Raises:
            SurrealPermissionException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="change", params=(thing, data)),
        )
        _validate_response(response, SurrealPermissionException)

    async def query(
        self, query: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Run a query against the database.

        Args:
            query: The query to execute.
            params: The query parameters. When ``None``, only the query
                string is sent.

        Returns:
            The ``result`` field of the server's reply.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(
                id=guid(),
                method="query",
                params=(query,) if params is None else (query, params),
            ),
        )
        success: ResponseSuccess = _validate_response(response)
        return success.result

    async def live(self, table: str) -> str:
        """Start a live query on a table.

        Args:
            table: The table name.

        Returns:
            The ``result`` field of the server's reply. It is annotated as a
            ``str`` and is presumably the live query ID accepted by ``kill``.

        Raises:
            SurrealException: If the server returns an error.
        """
        response = await self._send_receive(
            Request(id=guid(), method="live", params=(table,)),
        )
        success: ResponseSuccess = _validate_response(response)
        return success.result

    async def _send_receive(
        self, request: Request
    ) -> Union[ResponseSuccess, ResponseError]:
        """Send a request and return the next message received.

        Args:
            request: The request to send.

        Returns:
            The response from the Surreal server.

        Raises:
            SurrealException: If the client is not connected to the server.
        """
        await self._send(request)
        return await self._recv()

    async def _send(self, request: Request) -> None:
        """Serialize ``request`` as JSON and send it over the socket.

        Args:
            request: The request to send.

        Raises:
            SurrealException: If the client is not connected to the server.
        """
        self._validate_connection()
        await self.ws.send(json.dumps(request.dict()))  # type: ignore

    def _validate_connection(self) -> None:
        """Raise ``SurrealException`` unless the client state is CONNECTED."""
        if self.client_state != WebSocketState.CONNECTED:
            raise SurrealException("Not connected to Surreal server.")

    async def _recv(self) -> Union[ResponseSuccess, ResponseError]:
        """Read and parse the next message from the Surreal server.

        NOTE(review): the message's ``id`` is not compared with the request
        that was just sent. Any unsolicited message on the socket (possibly a
        live-query notification) would be returned as the response. Confirm
        against the server protocol.

        Returns:
            A ``ResponseError`` built from the ``error`` object when the
            message has a truthy ``error`` key, otherwise a
            ``ResponseSuccess``.

        Raises:
            SurrealException: If the client is not connected to the server.
        """
        self._validate_connection()
        response = json.loads(await self.ws.recv())  # type: ignore
        if response.get("error"):
            return ResponseError(**response["error"])
        return ResponseSuccess(**response)
async def main():
    """Demonstrate the client end to end.

    The example signs in as root, selects ``test``/``test``, creates a user
    record, starts a live query on ``user`` and finally selects every
    ``user`` record.

    Returns:
        The result of the final ``SELECT`` query.
    """
    # Example record only; the credentials here are demo values.
    new_user = {
        "user": "cellan",
        "pass": "password",
        "DB": "test",
        "NS": "test",
        "SC": "allusers",
        "marketing": True,
        "tags": ["python", "javascript"],
    }
    async with Surreal("ws://127.0.0.1:8000/rpc") as db:
        await db.signin({"user": "root", "pass": "root"})
        await db.use("test", "test")
        await db.create("user", new_user)
        await db.live("user")
        records = await db.query("SELECT * FROM type::table($tb)", {"tb": "user"})
    return records
if __name__ == "__main__":
    # Script entry point: run the async example and print its query result.
    import asyncio

    outcome = asyncio.run(main())
    print(outcome)
@Ce11an
Copy link
Author

Ce11an commented Mar 4, 2023

An example of how to connect to SurrealDB over a WebSocket in Python.

Packages (third-party dependencies imported by the code above): `pydantic` and `websockets`.

You can set up the table for users by following this tutorial. This Python implementation was greatly influenced by the JavaScript equivalent.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment