Skip to content

Instantly share code, notes, and snippets.

@jauhararifin
Last active January 20, 2023 04:31
Show Gist options
  • Save jauhararifin/13dcf59c963c86c3059b9a5d5e747357 to your computer and use it in GitHub Desktop.
Save jauhararifin/13dcf59c963c86c3059b9a5d5e747357 to your computer and use it in GitHub Desktop.
Super Simple Redis Clone In Python
import socket
import os
import logging
import dataclasses
import asyncio
from typing import Callable, Dict, List, Any, Union
@dataclasses.dataclass()
class Error:
    """An error reply sent to the client as ``-<code> <message>``.

    Handlers return an instance of this instead of raising; the writer turns
    it into a RESP error line.
    """

    code: str  # leading error word, e.g. "ERR"
    message: str  # human-readable text written after the code
@dataclasses.dataclass()
class Config:
    """Server settings, as produced by load_config()."""

    bind: str  # host/interface the listening socket binds to
    port: int  # TCP port the server listens on
    databases: int  # number of logical databases that SELECT may address
@dataclasses.dataclass()
class Session:
    """State belonging to a single client connection."""

    config: Config  # server configuration, shared by every session
    selected_database: int = 0  # index chosen via SELECT; GET/SET use it
@dataclasses.dataclass
class Command:
    """Metadata and handler for one supported command.

    The metadata fields (everything except ``handler``) are returned as-is
    by the COMMAND command.
    """

    # Command names and flags are stored as bytes (e.g. b"GET", b"readonly"),
    # matching the bytes the protocol reader produces; the previous ``str``
    # annotations did not match any actual value.
    name: bytes  # command name, e.g. b"GET"
    arguments_len: int  # arity reported by COMMAND
    flags: List[bytes]  # flags reported by COMMAND
    first_key: int  # position of the first key argument (0 when none)
    last_key: int  # position of the last key argument
    key_step: int  # distance between consecutive key arguments
    # Coroutine function awaited as handler(session, args), where args is the
    # list of request elements after the command name; its result is the reply.
    handler: Callable[[Session, Any], Any]
# Module-level storage filled in by init_database(): one dict per logical
# database (SET stores bytes keys and bytes values) and one asyncio.Lock per
# database, taken around each access. Both stay None until init_database()
# runs, so the annotations must admit None.
databases: Union[None, List[Dict[bytes, Any]]] = None
locks: Union[None, List[asyncio.Lock]] = None
def init_database(config: Config):
    """(Re)create the module-level storage for ``config.databases`` databases.

    Each database is an empty dict with its own asyncio.Lock, both stored at
    the same index in ``databases`` and ``locks``.
    """
    global databases, locks
    count = config.databases
    databases = [dict() for _ in range(count)]
    locks = [asyncio.Lock() for _ in range(count)]
async def handle_command_command(_: Session, __: Any) -> Any:
    """Handle ``COMMAND``: describe every entry of the command table.

    Each row is [name, arity, flags, first_key, last_key, key_step]; both
    the session and the arguments are ignored.
    """
    rows = []
    for spec in commands.values():
        rows.append([
            spec.name,
            spec.arguments_len,
            spec.flags,
            spec.first_key,
            spec.last_key,
            spec.key_step,
        ])
    return rows
async def handle_get_command(session: Session, args: Any) -> Any:
    """Handle ``GET key`` against the session's selected database.

    Returns the stored bytes, None when the key is absent, or an Error for a
    wrong argument count or a non-bytes key.
    """
    arity_ok = isinstance(args, list) and len(args) == 1
    if not arity_ok:
        return Error(code="ERR", message="wrong number of arguments for 'get' command")

    (key,) = args
    if not isinstance(key, bytes):
        return Error(code="ERR", message="invalid key")

    index = session.selected_database
    async with locks[index]:
        return databases[index].get(key)
async def handle_set_command(session: Session, args: Any) -> Any:
    """Handle ``SET key value`` against the session's selected database.

    Returns b'OK' after storing the value, or an Error when the argument
    count is wrong, the key or value is not bytes, or options follow the
    value.
    """
    if not isinstance(args, list) or len(args) < 2:
        return Error(code="ERR", message="wrong number of arguments for 'set' command")
    if len(args) > 2:
        # Options such as EX/PX/NX/XX are not implemented. Previously they
        # were ignored while still replying OK, so "SET k v NX" overwrote an
        # existing key (breaking lock patterns) and "SET k v EX 10" never
        # expired. Refuse them rather than silently misbehave.
        return Error(code="ERR", message="syntax error")
    key, value = args
    if not isinstance(key, bytes):
        return Error(code="ERR", message="invalid key")
    if not isinstance(value, bytes):
        return Error(code="ERR", message="invalid value")
    async with locks[session.selected_database]:
        databases[session.selected_database][key] = value
    return b'OK'
async def handle_select_command(session: Session, args: Any) -> Any:
    """Handle ``SELECT index``: switch this session to another database.

    Returns b'OK' on success, or an Error when the argument count is wrong,
    the index is not an integer, or it lies outside
    ``[0, session.config.databases)``.
    """
    if not isinstance(args, list) or len(args) != 1:
        return Error(code="ERR", message="wrong number of arguments for 'select' command")
    target_db = args[0]
    if not isinstance(target_db, bytes):
        return Error(code="ERR", message="invalid DB index")
    try:
        target_db = int(target_db)
    except ValueError:  # int() on non-numeric bytes raises ValueError
        return Error(code="ERR", message="invalid DB index")
    # Negative indexes must be rejected too: Python would otherwise accept
    # databases[-1] and silently alias the last database.
    if not 0 <= target_db < session.config.databases:
        return Error(code="ERR", message="DB index is out of range")
    session.selected_database = target_db
    return b'OK'
# Command table keyed by the upper-cased command name. The metadata is what
# COMMAND reports to clients, using Redis conventions: arity counts the
# command name itself and a negative arity means "at least |arity|"; the key
# positions are 0 for commands that take no keys.
commands = {
    b"COMMAND": Command(
        name=b"COMMAND",
        arguments_len=-1,  # COMMAND accepts optional sub-commands
        flags=[b"random", b"loading", b"stale"],
        first_key=0,  # no key arguments
        last_key=0,
        key_step=0,
        handler=handle_command_command,
    ),
    b"GET": Command(
        name=b"GET",
        arguments_len=2,  # GET key
        flags=[b"readonly", b"fast"],
        first_key=1,
        last_key=1,
        key_step=1,
        handler=handle_get_command,
    ),
    b"SET": Command(
        name=b"SET",
        arguments_len=-3,  # SET key value [options...]
        flags=[b"write", b"denyoom"],
        first_key=1,
        last_key=1,
        key_step=1,
        handler=handle_set_command,
    ),
    b"SELECT": Command(
        name=b"SELECT",
        arguments_len=2,  # SELECT index
        flags=[b"loading", b"stale", b"fast"],  # does not write any data
        first_key=0,  # no key arguments
        last_key=0,
        key_step=0,
        handler=handle_select_command,
    ),
}
def print_banner():
    """Write the ASCII-art startup banner to stdout, one line per row."""
    banner = (
        " ___________________________________ ",
        "| _______________________________ |",
        "| | Redipy (A Simple Redis Clone) | |",
        "| |_______________________________| |",
        "|___________________________________|",
    )
    print("\n".join(banner))
def load_config() -> Config:
    """Build a Config from the REDIPY_* environment variables.

    Defaults: bind "localhost", port 5000, 16 databases. Port and database
    count are parsed with int(), so malformed values raise ValueError.
    """
    env = os.environ
    return Config(
        bind=env.get("REDIPY_BIND", "localhost"),
        port=int(env.get("REDIPY_PORT", 5000)),
        databases=int(env.get("REDIPY_DATABASES", 16)),
    )
async def run_server(config: Config):
    """Listen on ``config.bind:config.port`` and serve each client in a task.

    Runs until cancelled; the listening socket is closed on the way out.
    """
    loop = asyncio.get_event_loop()
    tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # asyncio keeps only weak references to tasks, so hold strong ones here
    # until each client task finishes; otherwise a task may be
    # garbage-collected mid-execution.
    client_tasks = set()
    try:
        if os.name == "posix":
            # Allow an immediate restart while old connections sit in
            # TIME_WAIT (the same default asyncio's create_server uses on POSIX).
            tcp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        tcp_socket.bind((config.bind, config.port))
        tcp_socket.listen()
        tcp_socket.setblocking(False)
        logging.info("Server is listening to: %s:%d", config.bind, config.port)
        while True:
            connection, client = await loop.sock_accept(tcp_socket)
            logging.info("New client is connected: %s", client)
            task = loop.create_task(handle_client(connection, client, config))
            client_tasks.add(task)
            task.add_done_callback(client_tasks.discard)
    finally:
        tcp_socket.close()
async def handle_client(connection: socket.socket, client: Any, config: Config):
    """Serve requests from one connected client until it goes away.

    ``client`` is the peer address returned by ``sock_accept`` and is only
    used for logging. The connection is always closed before returning.
    """
    session = Session(config=config)
    try:
        while True:
            await handle_client_request(session, connection)
    except (EOFError, ConnectionError):
        # Orderly close (EOF) or reset/broken pipe: both mean the peer is gone.
        logging.info("Client is disconnected: %s", client)
    except Exception:
        # Per-connection boundary: a malformed request must not leave an
        # unretrieved task exception behind; log it and drop the client.
        logging.exception("Error while serving client %s", client)
    finally:
        # Previously the socket was never closed, leaking one fd per client.
        connection.close()
async def handle_client_request(session: Session, connection: socket.socket):
    """Read one request from ``connection``, dispatch it, and write the reply.

    A valid request is a non-empty array whose first element (the command
    name, matched case-insensitively) is bytes. Invalid requests are logged
    and skipped without a reply; unknown commands get an ERR reply.
    """
    request = await read_value(connection)
    if not (isinstance(request, list) and request and isinstance(request[0], bytes)):
        # logging.warning replaces logging.warn, deprecated since Python 3.3.
        logging.warning("Invalid request from client: %s", request)
        return
    logging.debug("Got request: %s", request)
    command, args = request[0].upper(), request[1:]
    spec = commands.get(command)
    if spec is None:
        msg = "unknown command `{}`, with args beginning with: {}".format(
            command, args)
        await write_value(connection, Error(code="ERR", message=msg))
        return
    result = await spec.handler(session, args)
    await write_value(connection, result)
async def read_value(connection: socket.socket) -> Any:
    """Read one RESP value, choosing the parser from its one-byte type marker.

    Raises EOFError when the peer has closed the connection. An unrecognised
    marker yields None.
    """
    loop = asyncio.get_event_loop()
    marker = await loop.sock_recv(connection, 1)
    if not marker:
        raise EOFError

    parsers = {
        b'*': read_array,
        b'$': read_bulk_string,
        b'+': read_simple_string,
        b':': read_number,
    }
    parser = parsers.get(marker)
    if parser is None:
        return None
    return await parser(connection)
async def read_array(connection: socket.socket) -> Union[None, List[Any]]:
    """Read a RESP array body (after its '*' marker) and return the elements.

    A negative length denotes the RESP null array and yields None, the same
    way read_bulk_string treats a negative bulk length.
    """
    array_len = await read_number(connection)
    if array_len < 0:
        return None
    # Build the list as elements arrive instead of preallocating
    # [None] * array_len: the length comes from the client, and a bogus huge
    # value would otherwise allocate that much memory before any data arrives.
    return [await read_value(connection) for _ in range(array_len)]
# Upper bound for a single sock_recv() request. CPython's socket.recv(n)
# allocates an n-byte buffer before receiving, and bulk lengths come from the
# client, so asking for the whole length at once would let a bogus
# "$<huge>" header exhaust memory.
_RECV_CHUNK_SIZE = 64 * 1024


async def _recv_exactly(connection: socket.socket, size: int) -> bytes:
    """Receive exactly ``size`` bytes, raising EOFError if the peer closes first."""
    loop = asyncio.get_event_loop()
    buffer = bytearray()
    while len(buffer) < size:
        # sock_recv may return fewer bytes than requested (TCP is a stream),
        # so keep reading until the full amount has arrived.
        chunk = await loop.sock_recv(
            connection, min(size - len(buffer), _RECV_CHUNK_SIZE))
        if not chunk:
            raise EOFError
        buffer += chunk
    return bytes(buffer)


async def read_bulk_string(connection: socket.socket) -> Union[None, bytes]:
    """Read a RESP bulk string body (after its '$' marker).

    Returns None for a negative length (the null bulk string), otherwise the
    payload bytes.

    Raises:
        EOFError: the connection closed before the whole value arrived.
        ValueError: the payload is not followed by CRLF.
    """
    str_len = await read_number(connection)
    if str_len < 0:
        return None
    str_value = await _recv_exactly(connection, str_len)
    terminator = await _recv_exactly(connection, 2)
    if terminator != b'\r\n':
        # Explicit check instead of assert, which is stripped under -O.
        raise ValueError("bulk string not terminated by CRLF: {!r}".format(terminator))
    return str_value
async def read_simple_string(connection: socket.socket) -> bytes:
    """Read a RESP simple string body (after its '+' marker), without CRLF."""
    line = await read_until_newline(connection)
    return line
async def read_number(connection: socket.socket) -> int:
    """Read a CRLF-terminated decimal integer (RESP integers and length headers)."""
    digits = await read_until_newline(connection)
    value = int(digits)
    return value
async def read_until_newline(connection: socket.socket) -> bytes:
    """Read bytes up to a CRLF terminator and return them without the CRLF.

    Raises:
        EOFError: the peer closed the connection before the terminator.
        ValueError: a CR was not followed by LF.
    """
    loop = asyncio.get_event_loop()
    # bytearray appends are amortised O(1); bytes += copies the whole
    # buffer every time (quadratic in line length).
    line = bytearray()
    while True:
        byte = await loop.sock_recv(connection, 1)
        if not byte:
            # sock_recv returns b'' once the peer has closed. Without this
            # check the loop would spin forever appending empty bytes.
            raise EOFError
        if byte == b'\r':
            break
        line += byte
    terminator = await loop.sock_recv(connection, 1)
    if not terminator:
        raise EOFError
    if terminator != b'\n':
        # Explicit check instead of assert, which is stripped under -O.
        raise ValueError("expected LF after CR, got {!r}".format(terminator))
    return bytes(line)
async def write_value(connection: socket.socket, value: Any):
    """Serialise ``value`` as RESP and send it on ``connection``.

    None becomes a null bulk string, bytes a bulk string, int an integer,
    list an array (recursively), and Error an error line. Any other type is
    logged and answered with a generic "internal server error".
    """
    if value is None:
        await write_null_value(connection)
        return

    writers = (
        (bytes, write_string_value),
        (int, write_number_value),
        (list, write_array_value),
        (Error, write_error_value),
    )
    for kind, writer in writers:
        if isinstance(value, kind):
            await writer(connection, value)
            return

    logging.error(
        "Found unrecognized value with type (%s): %s", type(value), value
    )
    fallback = Error(code="ERR", message="internal server error")
    await write_error_value(connection, fallback)
async def write_null_value(connection: socket.socket):
    """Send the RESP null bulk string (``$-1`` followed by CRLF)."""
    loop = asyncio.get_event_loop()
    # The connection is non-blocking (it comes from loop.sock_accept), so use
    # the awaitable loop.sock_sendall. A plain connection.sendall() is not
    # awaited and can fail with BlockingIOError when the send buffer is full.
    await loop.sock_sendall(connection, b'$-1\r\n')
async def write_string_value(connection: socket.socket, value: bytes):
    """Send ``value`` as a RESP bulk string: ``$<len>`` CRLF ``<payload>`` CRLF."""
    loop = asyncio.get_event_loop()
    # Build the whole frame and send it with one call. Five tiny writes per
    # reply cost five system calls and, with Nagle's algorithm and delayed
    # ACKs, can stall the reply on the wire.
    frame = b'$%d\r\n%b\r\n' % (len(value), value)
    await loop.sock_sendall(connection, frame)
async def write_number_value(connection: socket.socket, value: int):
    """Send ``value`` as a RESP integer: ``:<value>`` followed by CRLF."""
    loop = asyncio.get_event_loop()
    # %d renders any int subclass as its numeric value. str() would have
    # produced the invalid ":True" for a bool (write_value routes bools here,
    # since bool is an int).
    # The frame goes out in a single send instead of three.
    await loop.sock_sendall(connection, b':%d\r\n' % value)
async def write_array_value(connection: socket.socket, value: List[Any]):
    """Send ``value`` as a RESP array.

    The ``*<count>`` CRLF header is written first, then each element is
    serialised with write_value.
    """
    loop = asyncio.get_event_loop()
    # Send the header in a single call rather than three tiny writes.
    await loop.sock_sendall(connection, b'*%d\r\n' % len(value))
    for element in value:
        await write_value(connection, element)
async def write_error_value(connection: socket.socket, value: Error):
    """Send ``value`` as a RESP error line: ``-<code> <message>`` CRLF."""
    loop = asyncio.get_event_loop()
    text = "{} {}".format(value.code, value.message)
    # A RESP error is a single line. An embedded CR or LF would end the reply
    # early and desynchronise the client, so replace them with spaces, as
    # Redis does.
    text = text.replace("\r", " ").replace("\n", " ")
    # backslashreplace keeps the line ASCII. The strict ascii codec raised
    # UnicodeEncodeError on any non-ASCII character in the message.
    payload = text.encode("ascii", errors="backslashreplace")
    await loop.sock_sendall(connection, b'-' + payload + b'\r\n')
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print_banner()
    config = load_config()

    async def _serve(config: Config) -> None:
        """Create the databases inside the running event loop, then serve."""
        # init_database() constructs asyncio.Lock objects. Before Python 3.10
        # a Lock binds to the loop current at construction time; building
        # them before asyncio.run() tied them to a different loop than the
        # one that uses them, which fails as soon as a lock is contended.
        init_database(config)
        await run_server(config)

    try:
        asyncio.run(_serve(config))
    except KeyboardInterrupt:
        logging.info("Stopping")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment