Last active
January 20, 2023 04:31
-
-
Save jauhararifin/13dcf59c963c86c3059b9a5d5e747357 to your computer and use it in GitHub Desktop.
Super Simple Redis Clone In Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import dataclasses
import logging
import os
import socket
from typing import Any, Callable, Dict, List, Optional, Union
@dataclasses.dataclass
class Error:
    """An error reply for the client, written on the wire as ``-<code> <message>``.

    Command handlers return an instance of this class instead of raising,
    so the failure travels back to the client as an ordinary reply.
    """

    # Short error prefix; every handler in this module uses "ERR".
    code: str
    # Human-readable description sent after the prefix.
    message: str
@dataclasses.dataclass
class Config:
    """Server settings, built once at start-up by ``load_config()``."""

    # Host name or address the listening socket binds to.
    bind: str
    # TCP port the listening socket binds to.
    port: int
    # Number of independent key/value databases to create.
    databases: int
@dataclasses.dataclass
class Session:
    """Per-connection state; one instance is created for each client."""

    # Server configuration shared by every session.
    config: Config
    # Index of the database this client currently operates on. Starts at 0
    # and is changed by the SELECT handler.
    selected_database: int = 0
@dataclasses.dataclass
class Command:
    """Metadata and handler for one entry in the ``commands`` registry.

    The first six fields are what the COMMAND handler reports to clients,
    in this order.
    """

    # Upper-cased command name as bytes (the registry stores e.g. b"GET").
    name: bytes
    # Arity reported to clients.
    arguments_len: int
    # Flag names reported to clients, as bytes (e.g. b"readonly").
    flags: List[bytes]
    # Position of the first key argument as reported to clients.
    first_key: int
    # Position of the last key argument as reported to clients.
    last_key: int
    # Step between key arguments as reported to clients.
    key_step: int
    # Coroutine function called as ``await handler(session, args)``; its
    # result is serialised back to the client.
    handler: Callable[[Session, Any], Any]
# One in-memory key/value store per logical database. Stays None until
# init_database() populates it.
databases: Optional[List[Dict[bytes, Any]]] = None
# Lock guarding the store at the same index in `databases`. Also None until
# init_database() runs.
locks: Optional[List[asyncio.Lock]] = None
def init_database(config: Config):
    """Create the empty stores and their locks, one pair per configured database.

    Rebinds the module-level ``databases`` and ``locks`` lists, discarding
    whatever they referred to before.

    Args:
        config: Server configuration; only ``config.databases`` is read.
    """
    global databases, locks
    count = config.databases
    databases = [dict() for _ in range(count)]
    locks = [asyncio.Lock() for _ in range(count)]
async def handle_command_command(_: Session, __: Any) -> Any:
    """Handle ``COMMAND``: describe every registered command.

    Both the session and the arguments are ignored.

    Returns:
        One six-element list per entry in ``commands``: name, arity, flags,
        first key position, last key position and key step.
    """
    described = []
    for spec in commands.values():
        described.append([
            spec.name,
            spec.arguments_len,
            spec.flags,
            spec.first_key,
            spec.last_key,
            spec.key_step,
        ])
    return described
async def handle_get_command(session: Session, args: Any) -> Any:
    """Handle ``GET key``: look the key up in the session's selected database.

    Args:
        session: Supplies the index of the database to read from.
        args: Command arguments; must be a list holding exactly one bytes key.

    Returns:
        The stored value, ``None`` when the key is absent, or an ``Error``
        when the arguments are malformed.
    """
    if not (isinstance(args, list) and len(args) == 1):
        return Error(code="ERR", message="wrong number of arguments for 'get' command")

    (key,) = args
    if not isinstance(key, bytes):
        return Error(code="ERR", message="invalid key")

    db_index = session.selected_database
    async with locks[db_index]:
        return databases[db_index].get(key)
async def handle_set_command(session: Session, args: Any) -> Any:
    """Handle ``SET key value``: store the value in the session's selected database.

    Arguments after the first two are accepted and ignored.

    Args:
        session: Supplies the index of the database to write to.
        args: Command arguments; must be a list of at least two items whose
            first two are bytes.

    Returns:
        ``b'OK'`` on success, or an ``Error`` when the arguments are malformed.
    """
    if not (isinstance(args, list) and len(args) >= 2):
        return Error(code="ERR", message="wrong number of arguments for 'set' command")

    key, value = args[:2]
    if not isinstance(key, bytes):
        return Error(code="ERR", message="invalid key")
    if not isinstance(value, bytes):
        return Error(code="ERR", message="invalid value")

    db_index = session.selected_database
    async with locks[db_index]:
        databases[db_index][key] = value
    return b'OK'
async def handle_select_command(session: Session, args: Any) -> Any:
    """Handle ``SELECT index``: switch the session to another database.

    Args:
        session: The session whose ``selected_database`` is updated.
        args: Command arguments; must be a list holding exactly one bytes
            value that parses as an integer.

    Returns:
        ``b'OK'`` on success, or an ``Error`` when the argument count is
        wrong, the index is not an integer, or it falls outside
        ``0 <= index < session.config.databases``.
    """
    if not isinstance(args, list) or len(args) != 1:
        return Error(code="ERR", message="wrong number of arguments for 'select' command")

    raw_index = args[0]
    if not isinstance(raw_index, bytes):
        return Error(code="ERR", message="invalid DB index")
    try:
        target_db = int(raw_index)
    except ValueError:
        return Error(code="ERR", message="invalid DB index")

    # Reject negative indexes too: Python list indexing would otherwise
    # silently pick a database counted from the end of the list.
    if not 0 <= target_db < session.config.databases:
        return Error(code="ERR", message="DB index is out of range")

    session.selected_database = target_db
    return b'OK'
# Registry of supported commands, keyed by upper-cased command name.
# handle_client_request() dispatches through it and handle_command_command()
# reports its metadata to clients.
#
# `arguments_len` follows the Redis arity convention: the command name itself
# counts, a positive value is an exact count and a negative value means
# "at least that many".
commands = {
    b"COMMAND": Command(
        name=b"COMMAND",
        # The handler ignores its arguments, so any number is accepted.
        arguments_len=-1,
        flags=[b"readonly", b"random"],
        # COMMAND takes no key arguments.
        first_key=0,
        last_key=0,
        key_step=0,
        handler=handle_command_command,
    ),
    b"GET": Command(
        name=b"GET",
        arguments_len=2,
        flags=[b"readonly", b"random"],
        first_key=1,
        last_key=1,
        key_step=1,
        handler=handle_get_command,
    ),
    b"SET": Command(
        name=b"SET",
        # SET needs the name, a key and a value; the handler tolerates
        # extra trailing arguments.
        arguments_len=-3,
        flags=[b"write", b"string", b"slow"],
        first_key=1,
        last_key=1,
        key_step=1,
        handler=handle_set_command,
    ),
    b"SELECT": Command(
        name=b"SELECT",
        arguments_len=2,
        flags=[b"write", b"string", b"slow"],
        first_key=0,
        last_key=0,
        key_step=0,
        handler=handle_select_command,
    ),
}
def print_banner():
    """Print the five-line ASCII-art start-up banner to standard output."""
    banner_rows = (
        " ___________________________________ ",
        "| _______________________________ |",
        "| | Redipy (A Simple Redis Clone) | |",
        "| |_______________________________| |",
        "|___________________________________|",
    )
    for row in banner_rows:
        print(row)
def load_config() -> Config:
    """Build the server configuration from environment variables.

    Reads ``REDIPY_BIND`` (default ``"localhost"``), ``REDIPY_PORT``
    (default 5000) and ``REDIPY_DATABASES`` (default 16).

    Raises:
        ValueError: If the port or database count is not a valid integer.
    """
    env = os.environ
    return Config(
        bind=env.get("REDIPY_BIND", "localhost"),
        port=int(env.get("REDIPY_PORT", 5000)),
        databases=int(env.get("REDIPY_DATABASES", 16)),
    )
async def run_server(config: Config):
    """Listen on ``config.bind:config.port`` and serve clients until cancelled.

    Every accepted connection is handed to ``handle_client`` in its own task.
    The listening socket is closed when this coroutine exits for any reason.

    Args:
        config: Server configuration; also passed on to each client handler.
    """
    loop = asyncio.get_running_loop()
    # The event loop keeps only weak references to tasks, so hold strong
    # ones here until each task finishes.
    client_tasks = set()

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_socket:
        if os.name == "posix":
            # Let a restarted server rebind while old connections linger.
            tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        tcp_socket.bind((config.bind, config.port))
        tcp_socket.listen()
        # loop.sock_accept() requires a non-blocking socket.
        tcp_socket.setblocking(False)
        logging.info("Server is listening to: %s:%d", config.bind, config.port)

        while True:
            connection, client = await loop.sock_accept(tcp_socket)
            logging.info("New client is connected: %s", client)
            task = loop.create_task(handle_client(connection, client, config))
            client_tasks.add(task)
            task.add_done_callback(client_tasks.discard)
async def handle_client(connection: socket.socket, client: Any, config: Config):
    """Serve one client connection until it disconnects or misbehaves.

    Requests are processed one at a time through ``handle_client_request``.
    The connection is always closed before this coroutine returns.

    Args:
        connection: The accepted, non-blocking client socket.
        client: The peer address as returned by ``sock_accept``; used only
            in log messages.
        config: Server configuration for the new session.
    """
    session = Session(config=config)
    try:
        while True:
            await handle_client_request(session, connection)
    except EOFError:
        logging.info("Client is disconnected: %s", client)
    except (OSError, ValueError, AssertionError) as exc:
        # Socket failures and malformed protocol data end only this
        # connection instead of dying as an unhandled task exception.
        logging.warning("Dropping client %s: %r", client, exc)
    finally:
        connection.close()
async def handle_client_request(session: Session, connection: socket.socket):
    """Read one request from the client, run it and write the reply.

    A request must be a non-empty list whose first element is a bytes
    command name. Anything else is logged and dropped without a reply.
    Unknown commands get an ``ERR`` reply.

    Args:
        session: Per-connection state passed to the command handler.
        connection: The client socket to read from and reply on.

    Raises:
        EOFError: Propagated from ``read_value`` when the client disconnects.
    """
    request = await read_value(connection)

    is_valid = (
        isinstance(request, list)
        and len(request) > 0
        and isinstance(request[0], bytes)
    )
    if not is_valid:
        logging.warning("Invalid request from client: %s", request)
        return

    logging.debug("Got request: %s", request)
    # Command names are matched case-insensitively.
    command, args = request[0].upper(), request[1:]

    spec = commands.get(command)
    if spec is None:
        msg = "unknown command `{}`, with args beginning with: {}".format(
            command, args)
        await write_value(connection, Error(code="ERR", message=msg))
        return

    result = await spec.handler(session, args)
    await write_value(connection, result)
async def read_value(connection: socket.socket) -> Any:
    """Read one value from the socket, dispatching on its one-byte type marker.

    Markers: ``*`` array, ``$`` bulk string, ``+`` simple string,
    ``:`` integer.

    Returns:
        The decoded value, or ``None`` when the marker is not recognised.

    Raises:
        EOFError: If the peer closed the connection before sending a marker.
    """
    type_byte = await asyncio.get_event_loop().sock_recv(connection, 1)

    if type_byte == b'':
        raise EOFError
    if type_byte == b'*':
        return await read_array(connection)
    if type_byte == b'$':
        return await read_bulk_string(connection)
    if type_byte == b'+':
        return await read_simple_string(connection)
    if type_byte == b':':
        return await read_number(connection)
    # Unrecognised marker: nothing further is consumed.
    return None
async def read_array(connection: socket.socket) -> List[Any]:
    """Read an array: an element count line followed by that many values.

    The list grows as elements arrive instead of being pre-allocated from
    the client-supplied count, so a bogus huge count cannot force a huge
    allocation up front.

    Returns:
        The decoded elements in order. A non-positive count yields ``[]``.
    """
    array_len = await read_number(connection)
    return [await read_value(connection) for _ in range(array_len)]
async def _recv_exactly(connection: socket.socket, size: int) -> bytes:
    """Receive exactly ``size`` bytes, raising EOFError if the peer closes first."""
    loop = asyncio.get_running_loop()
    chunks = []
    remaining = size
    while remaining > 0:
        # sock_recv may return fewer bytes than asked for. Cap each request
        # so a huge declared length does not allocate a huge buffer.
        chunk = await loop.sock_recv(connection, min(remaining, 65536))
        if not chunk:
            raise EOFError
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


async def read_bulk_string(connection: socket.socket) -> Union[None, bytes]:
    """Read a bulk string: a length line, the payload, then CRLF.

    Returns:
        The payload bytes, or ``None`` when the declared length is negative.

    Raises:
        EOFError: If the peer closes the connection before the payload and
            its terminator have fully arrived.
        ValueError: If the payload is not followed by CRLF.
    """
    str_len = await read_number(connection)
    if str_len < 0:
        return None

    str_value = await _recv_exactly(connection, str_len)
    if await _recv_exactly(connection, 2) != b'\r\n':
        raise ValueError("bulk string is not terminated by CRLF")
    return str_value
async def read_simple_string(connection: socket.socket) -> bytes:
    """Read a simple string: the bytes up to the line terminator, which is consumed."""
    line = await read_until_newline(connection)
    return line
async def read_number(connection: socket.socket) -> int:
    """Read one line from the socket and parse it as an integer.

    Raises:
        ValueError: If the line is not a valid integer literal.
    """
    digits = await read_until_newline(connection)
    return int(digits)
async def read_until_newline(connection: socket.socket) -> bytes:
    """Read bytes up to a CRLF terminator and return them without it.

    Raises:
        EOFError: If the peer closes the connection before the line is
            complete.
        ValueError: If the carriage return is not followed by a line feed.
    """
    loop = asyncio.get_running_loop()
    line = bytearray()
    while True:
        byte = await loop.sock_recv(connection, 1)
        if not byte:
            # Empty read means the peer closed; without this check the loop
            # would spin forever on a half-received line.
            raise EOFError
        if byte == b'\r':
            break
        line += byte

    terminator = await loop.sock_recv(connection, 1)
    if not terminator:
        raise EOFError
    if terminator != b'\n':
        raise ValueError("expected LF after CR")
    return bytes(line)
async def write_value(connection: socket.socket, value: Any):
    """Serialise ``value`` to the socket, choosing the encoding by its type.

    ``None`` is written as a null bulk string, ``bytes`` as a bulk string,
    ``int`` as an integer, ``list`` as an array and ``Error`` as an error
    reply. Any other type is logged and reported to the client as an
    internal server error.
    """
    if value is None:
        await write_null_value(connection)
        return
    if isinstance(value, bytes):
        await write_string_value(connection, value)
        return
    if isinstance(value, int):
        await write_number_value(connection, value)
        return
    if isinstance(value, list):
        await write_array_value(connection, value)
        return
    if isinstance(value, Error):
        await write_error_value(connection, value)
        return

    logging.error(
        "Found unrecognized value with type (%s): %s", type(value), value
    )
    fallback = Error(code="ERR", message="internal server error")
    await write_error_value(connection, fallback)
async def write_null_value(connection: socket.socket):
    """Write a null bulk string (``$-1\\r\\n``) to the socket.

    Sends through the event loop like the other writers; a direct
    ``connection.sendall`` on a non-blocking socket can raise
    ``BlockingIOError`` when the send buffer is full.
    """
    loop = asyncio.get_running_loop()
    await loop.sock_sendall(connection, b'$-1\r\n')
async def write_string_value(connection: socket.socket, value: bytes):
    """Write ``value`` as a bulk string: ``$<length>\\r\\n<value>\\r\\n``.

    The reply is assembled first and sent with a single call, not one
    send per fragment.
    """
    payload = b''.join((
        b'$',
        str(len(value)).encode('ascii'),
        b'\r\n',
        value,
        b'\r\n',
    ))
    await asyncio.get_running_loop().sock_sendall(connection, payload)
async def write_number_value(connection: socket.socket, value: int):
    """Write ``value`` as an integer reply: ``:<value>\\r\\n``.

    The reply is assembled first and sent with a single call.
    """
    payload = b':' + str(value).encode('ascii') + b'\r\n'
    await asyncio.get_running_loop().sock_sendall(connection, payload)
async def write_array_value(connection: socket.socket, value: List[Any]):
    """Write ``value`` as an array: ``*<count>\\r\\n`` followed by each element.

    The header is sent with a single call; every element is then
    serialised through ``write_value``, so nested lists are supported.
    """
    header = b'*' + str(len(value)).encode('ascii') + b'\r\n'
    await asyncio.get_running_loop().sock_sendall(connection, header)
    for element in value:
        await write_value(connection, element)
async def write_error_value(connection: socket.socket, value: Error):
    """Write ``value`` as an error reply: ``-<code> <message>\\r\\n``.

    The reply is a single line, so any CR or LF in the code or message is
    replaced by a space. Non-ASCII characters are backslash-escaped so
    encoding cannot raise. The reply is sent with a single call.
    """
    text = "{} {}".format(value.code, value.message)
    # A stray line break would end the reply early and corrupt the stream.
    text = text.replace('\r', ' ').replace('\n', ' ')
    payload = b'-' + text.encode('ascii', 'backslashreplace') + b'\r\n'
    await asyncio.get_running_loop().sock_sendall(connection, payload)
if __name__ == "__main__":
    # Script entry point: set up logging, show the banner, load settings,
    # create the databases, then serve until interrupted.
    logging.basicConfig(level=logging.INFO)
    print_banner()

    config = load_config()
    init_database(config)

    server = run_server(config)
    try:
        asyncio.run(server)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server.
        logging.info("Stopping")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment