-
-
Save coleifer/dbbedc287605dcc22990a6e549de9f36 to your computer and use it in GitHub Desktop.
Redis-like Socket Server with Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gevent import socket | |
from gevent.pool import Pool | |
from gevent.server import StreamServer | |
from collections import namedtuple | |
from io import BytesIO | |
from socket import error as socket_error | |
import logging | |
logger = logging.getLogger(__name__) | |
class CommandError(Exception): pass | |
class Disconnect(Exception): pass | |
Error = namedtuple('Error', ('message',)) | |
class ProtocolHandler(object): | |
def __init__(self): | |
self.handlers = { | |
'+': self.handle_simple_string, | |
'-': self.handle_error, | |
':': self.handle_integer, | |
'$': self.handle_string, | |
'*': self.handle_array, | |
'%': self.handle_dict} | |
def handle_request(self, socket_file): | |
first_byte = socket_file.read(1) | |
if not first_byte: | |
raise Disconnect() | |
try: | |
# Delegate to the appropriate handler based on the first byte. | |
return self.handlers[first_byte](socket_file) | |
except KeyError: | |
raise CommandError('bad request') | |
def handle_simple_string(self, socket_file): | |
return socket_file.readline().rstrip('\r\n') | |
def handle_error(self, socket_file): | |
return Error(socket_file.readline().rstrip('\r\n')) | |
def handle_integer(self, socket_file): | |
return int(socket_file.readline().rstrip('\r\n')) | |
def handle_string(self, socket_file): | |
# First read the length ($<length>\r\n). | |
length = int(socket_file.readline().rstrip('\r\n')) | |
if length == -1: | |
return None # Special-case for NULLs. | |
length += 2 # Include the trailing \r\n in count. | |
return socket_file.read(length)[:-2] | |
def handle_array(self, socket_file): | |
num_elements = int(socket_file.readline().rstrip('\r\n')) | |
return [self.handle_request(socket_file) for _ in range(num_elements)] | |
def handle_dict(self, socket_file): | |
num_items = int(socket_file.readline().rstrip('\r\n')) | |
elements = [self.handle_request(socket_file) | |
for _ in range(num_items * 2)] | |
return dict(zip(elements[::2], elements[1::2])) | |
def write_response(self, socket_file, data): | |
buf = BytesIO() | |
self._write(buf, data) | |
buf.seek(0) | |
socket_file.write(buf.getvalue()) | |
socket_file.flush() | |
def _write(self, buf, data): | |
if isinstance(data, str): | |
data = data.encode('utf-8') | |
if isinstance(data, bytes): | |
buf.write('$%s\r\n%s\r\n' % (len(data), data)) | |
elif isinstance(data, int): | |
buf.write(':%s\r\n' % data) | |
elif isinstance(data, Error): | |
buf.write('-%s\r\n' % error.message) | |
elif isinstance(data, (list, tuple)): | |
buf.write('*%s\r\n' % len(data)) | |
for item in data: | |
self._write(buf, item) | |
elif isinstance(data, dict): | |
buf.write('%%%s\r\n' % len(data)) | |
for key in data: | |
self._write(buf, key) | |
self._write(buf, data[key]) | |
elif data is None: | |
buf.write('$-1\r\n') | |
else: | |
raise CommandError('unrecognized type: %s' % type(data)) | |
class Server(object): | |
def __init__(self, host='127.0.0.1', port=31337, max_clients=64): | |
self._pool = Pool(max_clients) | |
self._server = StreamServer( | |
(host, port), | |
self.connection_handler, | |
spawn=self._pool) | |
self._protocol = ProtocolHandler() | |
self._kv = {} | |
self._commands = self.get_commands() | |
def get_commands(self): | |
return { | |
'GET': self.get, | |
'SET': self.set, | |
'DELETE': self.delete, | |
'FLUSH': self.flush, | |
'MGET': self.mget, | |
'MSET': self.mset} | |
def connection_handler(self, conn, address): | |
logger.info('Connection received: %s:%s' % address) | |
# Convert "conn" (a socket object) into a file-like object. | |
socket_file = conn.makefile('rwb') | |
# Process client requests until client disconnects. | |
while True: | |
try: | |
data = self._protocol.handle_request(socket_file) | |
except Disconnect: | |
logger.info('Client went away: %s:%s' % address) | |
break | |
try: | |
resp = self.get_response(data) | |
except CommandError as exc: | |
logger.exception('Command error') | |
resp = Error(exc.args[0]) | |
self._protocol.write_response(socket_file, resp) | |
def run(self): | |
self._server.serve_forever() | |
def get_response(self, data): | |
if not isinstance(data, list): | |
try: | |
data = data.split() | |
except: | |
raise CommandError('Request must be list or simple string.') | |
if not data: | |
raise CommandError('Missing command') | |
command = data[0].upper() | |
if command not in self._commands: | |
raise CommandError('Unrecognized command: %s' % command) | |
else: | |
logger.debug('Received %s', command) | |
return self._commands[command](*data[1:]) | |
def get(self, key): | |
return self._kv.get(key) | |
def set(self, key, value): | |
self._kv[key] = value | |
return 1 | |
def delete(self, key): | |
if key in self._kv: | |
del self._kv[key] | |
return 1 | |
return 0 | |
def flush(self): | |
kvlen = len(self._kv) | |
self._kv.clear() | |
return kvlen | |
def mget(self, *keys): | |
return [self._kv.get(key) for key in keys] | |
def mset(self, *items): | |
data = zip(items[::2], items[1::2]) | |
for key, value in data: | |
self._kv[key] = value | |
return len(data) | |
class Client(object): | |
def __init__(self, host='127.0.0.1', port=31337): | |
self._protocol = ProtocolHandler() | |
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
self._socket.connect((host, port)) | |
self._fh = self._socket.makefile('rwb') | |
def execute(self, *args): | |
self._protocol.write_response(self._fh, args) | |
resp = self._protocol.handle_request(self._fh) | |
if isinstance(resp, Error): | |
raise CommandError(resp.message) | |
return resp | |
def get(self, key): | |
return self.execute('GET', key) | |
def set(self, key, value): | |
return self.execute('SET', key, value) | |
def delete(self, key): | |
return self.execute('DELETE', key) | |
def flush(self): | |
return self.execute('FLUSH') | |
def mget(self, *keys): | |
return self.execute('MGET', *keys) | |
def mset(self, *items): | |
return self.execute('MSET', *items) | |
if __name__ == '__main__': | |
from gevent import monkey; monkey.patch_all() | |
logger.addHandler(logging.StreamHandler()) | |
logger.setLevel(logging.DEBUG) | |
Server().run() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment