Skip to content

Instantly share code, notes, and snippets.

@afrase
Created May 20, 2016 19:29
Show Gist options
  • Save afrase/cc8a36f02b67d3bdf3e576d586fd051f to your computer and use it in GitHub Desktop.
Save afrase/cc8a36f02b67d3bdf3e576d586fd051f to your computer and use it in GitHub Desktop.
Python HTTP server using coroutines
# -*- coding: utf-8 -*-
import threading
from collections import deque
from functools import wraps
from select import select
from socket import socket, AF_INET, SOCK_STREAM
from urllib.request import urlopen
class TaskException(Exception):
    """Raised by Scheduler when it is handed a task that is None."""
class Scheduler(object):
    """Cooperative scheduler that drives generator tasks and parks them on
    socket readiness via select()."""

    def __init__(self):
        super().__init__()
        self._numtasks = 0          # tasks registered via new() and not yet finished
        self._ready = deque()       # (task, value-to-send) pairs ready to run
        self._read_waiting = {}     # fileno -> (event, task) waiting to read
        self._write_waiting = {}    # fileno -> (event, task) waiting to write

    def _iopoll(self):
        """Block in select() until a parked fd is ready, then let each ready
        event resume its task through handle_resume.

        Note: select() has no timeout here, so this waits indefinitely for
        socket activity.
        """
        readable, writable, _ = select(self._read_waiting, self._write_waiting, [])
        # all readable fds are resumed first, then all writable ones
        for parked, ready_fds in ((self._read_waiting, readable),
                                  (self._write_waiting, writable)):
            for fd in ready_fds:
                event, task = parked.pop(fd)
                event.handle_resume(self, task)

    def new(self, task):
        """Register a brand-new task and queue it to start.

        Args:
            task (object): the task to add to the scheduler.
        Raises:
            TaskException: if `task` is None.
        """
        self.add_ready(task)
        self._numtasks += 1

    def add_ready(self, task, msg=None):
        """Queue an already-registered task to be resumed.

        Args:
            task (object): the task to resume.
            msg (object): value sent into the task when it resumes.
        Raises:
            TaskException: if `task` is None.
        """
        if task is None:
            raise TaskException('Task cannot be None')
        self._ready.append((task, msg))

    def add_read(self, fileno, evt, task):
        """Park `task` until `fileno` is readable; `evt` resumes it.

        Args:
            fileno (int): file descriptor to watch.
            evt (YieldEvent): event whose handle_resume runs when readable.
            task (generator): the parked task.
        """
        self._read_waiting[fileno] = (evt, task)

    def add_write(self, fileno, evt, task):
        """Park `task` until `fileno` is writable; `evt` resumes it.

        Args:
            fileno (int): file descriptor to watch.
            evt (YieldEvent): event whose handle_resume runs when writable.
            task (generator): the parked task.
        """
        self._write_waiting[fileno] = (evt, task)

    def run(self):
        """Run tasks until none remain or a _SentinelEvent is dequeued."""
        while self._numtasks:
            if not self._ready:
                self._iopoll()
            task, msg = self._ready.popleft()
            if isinstance(task, _SentinelEvent):
                break
            try:
                # advance the coroutine to its next yield point
                event = task.send(msg)
                if not isinstance(event, YieldEvent):
                    raise RuntimeError('unrecognized yield event')
                event.handle_yield(self, task)
            except StopIteration:
                self._numtasks -= 1
class YieldEvent(object):
    """Base class for the objects a task yields to the Scheduler.

    The scheduler calls handle_yield when a task yields the event. It calls
    handle_resume when the fd the event registered becomes ready. Both hooks
    do nothing by default.
    """

    def handle_yield(self, sched, task):
        """Called when `task` yields this event; no-op by default."""

    def handle_resume(self, sched, task):
        """Called when the awaited fd is ready; no-op by default."""
class ReadSocket(YieldEvent):
    """Park the task until `sock` is readable, then resume it with the bytes
    returned by ``sock.recv(nbytes)``."""

    def __init__(self, sock, nbytes):
        super().__init__()
        self.sock, self.nbytes = sock, nbytes

    def handle_yield(self, sched, task):
        fd = self.sock.fileno()
        sched.add_read(fd, self, task)

    def handle_resume(self, sched, task):
        sched.add_ready(task, self.sock.recv(self.nbytes))
class WriteSocket(YieldEvent):
    """Park the task until `sock` is writable, then send `data` and resume
    the task with the count returned by ``sock.send`` (possibly fewer bytes
    than ``len(data)``)."""

    def __init__(self, sock, data):
        super().__init__()
        self.sock, self.data = sock, data

    def handle_yield(self, sched, task):
        fd = self.sock.fileno()
        sched.add_write(fd, self, task)

    def handle_resume(self, sched, task):
        sched.add_ready(task, self.sock.send(self.data))
class AcceptSocket(YieldEvent):
    """Park the task until the listening `sock` is readable, then resume it
    with the result of ``sock.accept()``."""

    def __init__(self, sock):
        super().__init__()
        self.sock = sock

    def handle_yield(self, sched, task):
        fd = self.sock.fileno()
        sched.add_read(fd, self, task)

    def handle_resume(self, sched, task):
        sched.add_ready(task, self.sock.accept())
class CloseSocket(YieldEvent):
    """Close `sock` and immediately resume the task that yielded this event."""

    def __init__(self, sock):
        super(CloseSocket, self).__init__()
        self.sock = sock

    def handle_yield(self, sched, task):
        self.sock.close()
        # Closing does not wait on I/O, so put the task straight back on the
        # ready queue. Without this the task stays suspended forever: code
        # after `yield sock.close()` never runs, and the task never reaches
        # StopIteration, so the scheduler's task count is never decremented.
        sched.add_ready(task, None)
class _SentinelEvent(YieldEvent):
    """Marker queued as a 'task'; Scheduler.run() stops when it dequeues one."""
class Socket(object):
    """Coroutine-friendly wrapper around a raw socket.

    recv/send/accept/close perform no I/O themselves. Each returns a
    YieldEvent that a task yields to the Scheduler. Any other attribute
    (bind, listen, fileno, ...) is forwarded to the wrapped socket.
    """

    def __init__(self, sock):
        super().__init__()
        self._sock = sock

    def recv(self, maxbytes):
        """Event resuming the task with up to `maxbytes` received bytes."""
        return ReadSocket(self._sock, maxbytes)

    def send(self, data):
        """Event resuming the task with the number of bytes sent."""
        return WriteSocket(self._sock, data)

    def accept(self):
        """Event resuming the task with the accepted connection."""
        return AcceptSocket(self._sock)

    def close(self):
        """Event that closes the wrapped socket."""
        return CloseSocket(self._sock)

    def __getattr__(self, name):
        # only consulted for names not defined on the wrapper itself
        raw = self._sock
        return getattr(raw, name)
class HTTPRequest(object):
    """HTTP request parsed incrementally from a coroutine `Socket`.

    Attributes (filled in by read()):
        method, request_uri, http_version (str or None): request-line parts.
        headers (dict): header name (as sent) -> stripped value.
        body (str): UTF-8 decoded body; only read for POST requests.
    """

    def __init__(self, sock):
        super(HTTPRequest, self).__init__()
        self._sock = sock
        self.headers = {}
        self.body = ''
        self.method = None
        self.request_uri = None
        self.http_version = None

    def read(self):
        """Coroutine: read the request line, headers and (for POST) body."""
        while True:
            line = yield from readline(self._sock)
            # EOF, or the blank line that ends the header section
            if not line or line in (b'\r\n', b'\n'):
                break
            self._parse_header_line(line.decode('utf-8'))
        if self.method == 'POST':
            yield from self._read_body()

    def _content_length(self):
        """Return the Content-Length header as an int, or 0 when absent.

        Header names are matched case-insensitively, as HTTP requires.
        """
        for name, value in self.headers.items():
            if name.strip().lower() == 'content-length':
                return int(value)
        return 0

    def _read_body(self):
        """Coroutine: read up to Content-Length bytes and append them to `body`."""
        expected = self._content_length()
        chunks = []
        received = 0
        while received < expected:
            chunk = yield from readline(self._sock, expected - received)
            if not chunk:
                # Peer closed before sending the whole body. readline keeps
                # returning b'' at EOF, so without this the loop never ends.
                break
            chunks.append(chunk)
            received += len(chunk)
        # decode once, after all bytes are collected
        self.body += b''.join(chunks).decode('utf-8')

    def _parse_header_line(self, line):
        # drop the line terminator so it doesn't leak into http_version
        line = line.rstrip('\r\n')
        # first line of an HTTP request will always be
        # "Method SP Request-URI SP HTTP-Version CRLF" according to rfc1945
        if self.method is None:
            self.method, self.request_uri, self.http_version = line.split(' ')
        else:
            # split on the first ':' only; values (e.g. Host) may contain ':'
            name, _, value = line.partition(':')
            self.headers[name] = value.strip()
class HTTPResponse(object):
    """HTTP response written back over a coroutine `Socket`.

    Set `data`, `status_code`, `status_msg` and `headers` before driving
    send(). send() fills in Content-Length and Connection itself.
    """

    def __init__(self, sock, data=None):
        super(HTTPResponse, self).__init__()
        self._sock = sock
        # response body: str (sent UTF-8 encoded) or bytes (sent as-is)
        self.data = '' if data is None else data
        self.status_code = 200
        self.http_version = 'HTTP/1.1'
        self.status_msg = 'OK'
        self.headers = {}

    def send(self):
        """Coroutine: serialize the response, send all of it, then close."""
        body = self.data
        if isinstance(body, str):
            body = body.encode('utf-8')
        # Content-Length counts bytes on the wire, not str characters
        self.headers['Content-Length'] = len(body)
        self.headers['Connection'] = 'close'
        status_line = '{} {} {}\r\n'.format(
            self.http_version, self.status_code, self.status_msg)
        header_lines = ''.join('{}: {}\r\n'.format(k, v)
                               for k, v in self.headers.items())
        # an empty line separates the header section from the body
        head = (status_line + header_lines + '\r\n').encode('utf-8')
        response = head + body
        # send() may write only part of the buffer; keep going with the rest
        while response:
            nsent = yield self._sock.send(response)
            response = response[nsent:]
        yield self._sock.close()
class HTTPServer(object):
    """Listening server whose accept loop runs as a task on a Scheduler."""

    def __init__(self, addr, sched, event, handler=None):
        """
        Args:
            addr (tuple[str,int]): (host, port) the listening socket binds to.
            sched (Scheduler): runs the accept loop and every client task.
            event (threading.Event): once set, the next accepted connection
                stops the server.
            handler (callable): generator function taking a `Socket` for each
                client; the default just closes the connection.
        """
        super(HTTPServer, self).__init__()
        self.num_connections = 0
        self._scheduler = sched
        self._addr = addr
        self.event = event
        if handler is None:
            def null_handler(sock):
                yield sock.close()
            self.client_handler = null_handler
        else:
            self.client_handler = handler
        # register last so the loop never sees a half-initialized server
        self._scheduler.new(self.server_loop(addr))

    def server_loop(self, addr):
        """Accept loop task: spawn a client_handler task per connection.

        Args:
            addr (tuple[str,int]): (host, port) to bind and listen on.
        """
        raw = socket(AF_INET, SOCK_STREAM)
        s = Socket(raw)
        try:
            s.bind(addr)
            s.listen(5)
            while True:
                # wait for a client to connect
                c, _ = yield s.accept()
                if self.event.is_set():
                    # this connection only exists to wake the loop; drop it
                    c.close()
                    self._scheduler.new(_SentinelEvent())
                    break
                self.num_connections += 1
                self._scheduler.new(self.client_handler(Socket(c)))
        finally:
            # Socket.close() only builds a yield event; close the raw socket
            raw.close()

    def run(self):
        """Print a start message and run the scheduler (blocks)."""
        print('Starting server on port %s' % self._addr[1])
        self._scheduler.run()
def readline(sock, amt=None):
    """Coroutine that reads from `sock` one byte per yield.

    Reading stops after the first newline byte (which is kept), at EOF, or
    once `amt` bytes have been read when `amt` is not None.

    Args:
        sock (Socket): wrapper whose ``recv`` returns a yield event.
        amt (int): maximum number of bytes to read; None means no limit.
    Returns:
        bytes: everything read.
    """
    pieces = []
    count = 0
    while True:
        if amt is not None and count >= amt:
            break
        ch = yield sock.recv(1)
        if not ch:  # EOF
            break
        pieces.append(ch)
        count += len(ch)
        if ch == b'\n':
            break
    return b''.join(pieces)
def http_handler(func):
    """Turn ``func(request, response)`` into a client task for HTTPServer.

    The returned generator function reads an HTTPRequest from the client
    socket and calls `func` to fill in the HTTPResponse. It then sends the
    response and closes the socket.

    * A request that fails to parse gets a 400 reply.
    * An exception from `func` is printed to stderr and gets a 500 reply.
    """
    @wraps(func)
    def wrapper(sock):
        req = HTTPRequest(sock)
        resp = HTTPResponse(sock)
        try:
            yield from req.read()
        except (ValueError, KeyError):
            # Malformed request line/header/body (UnicodeDecodeError is a
            # ValueError). Answer the client instead of letting the error
            # escape into the scheduler, which would stop the whole server.
            resp.status_code = 400
            resp.status_msg = 'Bad Request'
        else:
            try:
                func(req, resp)
            except Exception:
                # report the failure instead of swallowing it silently
                import traceback
                traceback.print_exc()
                resp.status_code = 500
                resp.status_msg = 'Server Error'
                # don't send whatever partial body the handler had set
                resp.data = ''
        yield from resp.send()
    return wrapper
def build_http_server(port, ip_address=None, handler_func=None):
    """Create an HTTPServer with its own Scheduler and stop Event.

    Args:
        port (int): port to listen on.
        ip_address (str): address to bind; None binds '' (all interfaces).
        handler_func (callable): client handler passed to HTTPServer.
    Returns:
        HTTPServer: the server, ready for ``run()``.
    """
    host = '' if ip_address is None else ip_address
    return HTTPServer((host, port), Scheduler(), threading.Event(), handler_func)
@http_handler
def echo_handler(request, response):
    """Reply with the request body unchanged."""
    echoed = request.body
    response.data = echoed
if __name__ == '__main__':
    # Interactive demo: every input line is POSTed to the echo server and the
    # reply printed; 'stop' or 'quit' (or EOF / Ctrl-C) shuts it down.
    server_port = 8080
    server_url = 'http://localhost:%s' % server_port
    server = build_http_server(server_port, handler_func=echo_handler)
    thread = threading.Thread(target=server.run)
    thread.start()
    while True:
        try:
            i = input()
        except (EOFError, KeyboardInterrupt):
            # Treat a closed stdin or Ctrl-C as 'stop'. Otherwise the main
            # thread dies while the non-daemon server thread keeps the
            # process alive.
            i = 'stop'
        if i in ('stop', 'quit'):
            server.event.set()
            # the server checks the event only after accept() returns, so
            # make one throwaway connection to wake it up
            try:
                urlopen(server_url).close()
            except Exception as e:
                print(e)
            thread.join()
            break
        else:
            try:
                with urlopen(server_url, i.encode()) as r:
                    print(r.read().decode())
            except Exception as e:
                print(e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment