Skip to content

Instantly share code, notes, and snippets.

@dejanceltra
Created August 1, 2023 23:39
Show Gist options
  • Save dejanceltra/ed4778d691448c23acaf7e42cdfc4446 to your computer and use it in GitHub Desktop.
Save dejanceltra/ed4778d691448c23acaf7e42cdfc4446 to your computer and use it in GitHub Desktop.
"""
Simple KV store accessible via TCP.
Supports two operations:
- get(key): fetching value (or None if not found)
- set(key, value): setting value
The server is single-threaded and asyncio-based, which simplifies the implementation;
a multi-threaded server would probably be slower, since it would need to share the `_cache` dictionary between threads.
Both sync and async clients are provided.
"""
import asyncio
from enum import Enum
import socket
from typing import Awaitable, Dict, Optional
import aiorwlock
class Commands(Enum):
    """Single-byte opcodes and status codes used on the wire.

    Each member's value is the exact byte written to / read from the socket.
    """

    # Requests written by the clients.
    GET = b'g'
    SET = b's'
    QUIT = b'q'

    # Status bytes written by the server.
    NOT_FOUND = b'n'
    OK = b'o'
class SyncClient:
    """Blocking TCP client for the KV store.

    Wire format (every length is a 4-byte big-endian count of *bytes*):
      GET : b'g' + len(key) + key                       -> b'n'  or  b'o' + len(value) + value
      SET : b's' + len(key) + key + len(value) + value  -> b'o'
      QUIT: b'q'                                        -> (no reply)
    """

    def __init__(self):
        # None until connect() succeeds, and again after close().
        self._socket: Optional[socket.socket] = None

    def connect(self, ip: str, port: int) -> None:
        """Open a TCP connection to the server at (ip, port).

        Raises:
            OSError: if the connection cannot be established.
        """
        sock = socket.socket()
        try:
            sock.connect((ip, port))
        except BaseException:
            # Do not leak the descriptor or keep a half-initialised socket around.
            sock.close()
            raise
        self._socket = sock

    def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under `key`.

        Returns:
            The raw value bytes exactly as received from the server,
            or None when the server answers NOT_FOUND.

        Raises:
            Exception: if the client is not connected, `key` is not a str,
                the server sends an unknown status byte, or the stream ends early.
        """
        self._ensure_connected()
        self._ensure_str('key', key)
        # The length prefix must count encoded bytes, not characters,
        # otherwise non-ASCII keys desynchronise the stream.
        self._socket.sendall(Commands.GET.value + self._frame(key.encode()))
        status = self._readexactly(1)
        if status == Commands.NOT_FOUND.value:
            return None
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')
        size = int.from_bytes(self._readexactly(4), 'big')
        return self._readexactly(size)

    def set(self, key: str, value: str) -> None:
        """Store `value` under `key`.

        Raises:
            Exception: if the client is not connected, `key` or `value` is not a str,
                the server does not answer OK, or the stream ends early.
        """
        self._ensure_connected()
        self._ensure_str('key', key)
        self._ensure_str('value', value)
        request = (
            Commands.SET.value
            + self._frame(key.encode())
            + self._frame(value.encode())
        )
        self._socket.sendall(request)
        status = self._readexactly(1)
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')

    def close(self) -> None:
        """Send QUIT and close the socket; a no-op when not connected."""
        if self._socket is None:
            return
        sock, self._socket = self._socket, None
        try:
            sock.sendall(Commands.QUIT.value)
        finally:
            # Release the descriptor even if the peer is already gone.
            sock.close()

    def _ensure_connected(self) -> None:
        """Raise if connect() has not been called successfully."""
        if not self._socket:
            raise Exception('client not connected')

    @staticmethod
    def _ensure_str(name: str, obj) -> None:
        """Raise if `obj` (the argument called `name`) is not a str."""
        if not isinstance(obj, str):
            raise Exception(f'{name} must be of type str, received: {str(type(obj))}')

    @staticmethod
    def _frame(payload: bytes) -> bytes:
        """Return `payload` prefixed with its 4-byte big-endian byte length."""
        return len(payload).to_bytes(4, 'big') + payload

    def _readexactly(self, length: int) -> bytes:
        """Read exactly `length` bytes from the socket.

        Raises:
            Exception: if the peer closes the connection before `length` bytes arrive.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            part = self._socket.recv(remaining)
            if not part:
                # recv() returning b'' means the peer closed the connection.
                received = length - remaining
                raise Exception(f'stream reading failed; expected {length} bytes, got {received} bytes')
            chunks.append(part)
            remaining -= len(part)
        return b''.join(chunks)
class AsyncClient:
    """asyncio TCP client for the KV store.

    Wire format (every length is a 4-byte big-endian count of *bytes*):
      GET : b'g' + len(key) + key                       -> b'n'  or  b'o' + len(value) + value
      SET : b's' + len(key) + key + len(value) + value  -> b'o'
      QUIT: b'q'                                        -> (no reply)
    """

    def __init__(self):
        # Both are None until connect() succeeds, and again after close().
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None

    async def connect(self, ip: str, port: int) -> None:
        """Open a TCP connection to the server at (ip, port)."""
        self._reader, self._writer = await asyncio.open_connection(ip, port)

    async def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under `key`.

        Returns:
            The raw value bytes exactly as received from the server,
            or None when the server answers NOT_FOUND.

        Raises:
            Exception: if the client is not connected, `key` is not a str,
                or the server sends an unknown status byte.
            asyncio.IncompleteReadError: if the stream ends mid-response.
        """
        self._ensure_connected()
        self._ensure_str('key', key)
        # The length prefix must count encoded bytes, not characters.
        self._writer.write(Commands.GET.value + self._frame(key.encode()))
        await self._writer.drain()
        # read(n) may legally return fewer than n bytes; readexactly() cannot.
        status = await self._reader.readexactly(1)
        if status == Commands.NOT_FOUND.value:
            return None
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')
        size = int.from_bytes(await self._reader.readexactly(4), 'big')
        return await self._reader.readexactly(size)

    async def set(self, key: str, value: str) -> None:
        """Store `value` under `key`.

        Raises:
            Exception: if the client is not connected, `key` or `value` is not a str,
                or the server does not answer OK.
            asyncio.IncompleteReadError: if the stream ends before the status byte.
        """
        self._ensure_connected()
        self._ensure_str('key', key)
        self._ensure_str('value', value)
        request = (
            Commands.SET.value
            + self._frame(key.encode())
            + self._frame(value.encode())
        )
        self._writer.write(request)
        await self._writer.drain()
        status = await self._reader.readexactly(1)
        if status != Commands.OK.value:
            raise Exception(f'unknown response: {status}')

    async def close(self) -> None:
        """Send QUIT and close the connection; a no-op when not connected."""
        if self._writer is None:
            return
        writer = self._writer
        self._reader = None
        self._writer = None
        try:
            writer.write(Commands.QUIT.value)
            await writer.drain()
        finally:
            # Close the transport even if the peer is already gone.
            writer.close()
        await writer.wait_closed()

    def _ensure_connected(self) -> None:
        """Raise if connect() has not been awaited successfully."""
        if self._reader is None or self._writer is None:
            raise Exception('client not connected')

    @staticmethod
    def _ensure_str(name: str, obj) -> None:
        """Raise if `obj` (the argument called `name`) is not a str."""
        if not isinstance(obj, str):
            raise Exception(f'{name} must be of type str, received: {str(type(obj))}')

    @staticmethod
    def _frame(payload: bytes) -> bytes:
        """Return `payload` prefixed with its 4-byte big-endian byte length."""
        return len(payload).to_bytes(4, 'big') + payload
class AsyncServer:
    """asyncio TCP server holding an in-memory bytes -> bytes KV store.

    Protocol handled per connection (lengths are 4-byte big-endian byte counts):
      b'g' + len + key               -> b'n'  or  b'o' + len + value
      b's' + len + key + len + value -> b'o'
      b'q'                           -> connection is closed
    """

    def __init__(self) -> None:
        self._cache: Dict[bytes, bytes] = {}

    async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Serve one connection until QUIT, disconnect, or an unknown command.

        Raises:
            Exception: if the client sends a command byte that is not GET, SET or QUIT.
        """
        try:
            while True:
                command = await reader.readexactly(1)
                if command == Commands.QUIT.value:
                    break
                if command == Commands.GET.value:
                    await self._handle_get(reader, writer)
                elif command == Commands.SET.value:
                    await self._handle_set(reader, writer)
                else:
                    raise Exception(f'unknown command: {command}')
        except (asyncio.IncompleteReadError, ConnectionError):
            # The peer disconnected (possibly mid-request); there is nobody to answer.
            pass
        finally:
            # Always release the transport, whichever way the loop ended.
            writer.close()

    @staticmethod
    async def _read_frame(reader: asyncio.StreamReader) -> bytes:
        """Read one length-prefixed payload from `reader`."""
        # readexactly() is required: read(4) may return fewer than 4 bytes.
        # NOTE(review): the length comes from the client and is not bounded here.
        size = int.from_bytes(await reader.readexactly(4), 'big')
        return await reader.readexactly(size)

    async def _handle_get(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Answer a GET: OK + framed value, or NOT_FOUND when the key is absent."""
        key = await self._read_frame(reader)
        async with self._lock.reader_lock:
            value = self._cache.get(key)
        # Compare against None so that a stored empty value is still reported as found.
        if value is None:
            writer.write(Commands.NOT_FOUND.value)
        else:
            writer.write(Commands.OK.value + len(value).to_bytes(4, 'big') + value)
        await writer.drain()

    async def _handle_set(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Answer a SET: store the value under the key and reply OK."""
        key = await self._read_frame(reader)
        value = await self._read_frame(reader)
        async with self._lock.writer_lock:
            self._cache[key] = value
        writer.write(Commands.OK.value)
        await writer.drain()

    async def run(self, ip: str, port: int) -> None:
        """Listen on (ip, port) and serve clients until cancelled."""
        # NOTE(review): handlers never await while touching `_cache`, so on a
        # single-threaded event loop this lock is likely redundant; it is kept
        # to preserve the original behaviour.
        self._lock = aiorwlock.RWLock()
        server = await asyncio.start_server(self._handle_client, ip, port)
        async with server:
            await server.serve_forever()
import asyncio
from python.util.remote_kv_store import AsyncServer
if __name__ == '__main__':
    # Script entry point: serve the KV store on the loopback interface, port 15555.
    server = AsyncServer()
    serve_coroutine = server.run('127.0.0.1', 15555)
    asyncio.run(serve_coroutine)
import asyncio
from enum import Enum
from python.util.remote_kv_store import AsyncClient, SyncClient
async def tcp_echo_client(message):
    """Exercise the async client against a server on 127.0.0.1:15555.

    Prints the value of 'omg', sets it to 'second', then prints it again.

    Args:
        message: Unused; kept so existing callers keep working.
    """
    client = AsyncClient()
    await client.connect('127.0.0.1', 15555)
    try:
        print(await client.get('omg'))
        print(await client.set('omg', 'second'))
        print(await client.get('omg'))
    finally:
        # Close the connection even if one of the requests above fails.
        await client.close()
# Run the async demo first, then repeat the same get/set/get sequence
# with the blocking client.
asyncio.run(tcp_echo_client('Hello World!'))

client = SyncClient()
client.connect('127.0.0.1', 15555)
try:
    print(client.get('omg'))
    print(client.set('omg', 'second'))
    print(client.get('omg'))
finally:
    # Close the socket even if one of the requests above fails.
    client.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment