Last active
February 3, 2024 01:55
-
-
Save FelixWolf/742e0bcb113929fe28942f75dc585f5b to your computer and use it in GitHub Desktop.
Web server with WebSocket support, written in Python. License: zlib or the Unlicense / public domain (take your pick).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Description: | |
A simple single-file web server with websocket support.
HTTPS is available by passing an ssl.SSLContext to HTTPServer.start_server() / run().
LICENSE: | |
ZLib or Unlicense, take your pick. | |
Option 1: | |
------------------------------------ ZLIB -------------------------------------- | |
Copyright (c) 2022 Kyler "Félix" Eastridge | |
This software is provided 'as-is', without any express or implied | |
warranty. In no event will the authors be held liable for any damages | |
arising from the use of this software. | |
Permission is granted to anyone to use this software for any purpose, | |
including commercial applications, and to alter it and redistribute it | |
freely, subject to the following restrictions: | |
1. The origin of this software must not be misrepresented; you must not | |
claim that you wrote the original software. If you use this software | |
in a product, an acknowledgment in the product documentation would be | |
appreciated but is not required. | |
2. Altered source versions must be plainly marked as such, and must not be | |
misrepresented as being the original software. | |
3. This notice may not be removed or altered from any source distribution. | |
-------------------------------------------------------------------------------- | |
Option 2: | |
--------------------------------- UNLICENSE ------------------------------------ | |
This is free and unencumbered software released into the public domain. | |
Anyone is free to copy, modify, publish, use, compile, sell, or | |
distribute this software, either in source code form or as a compiled | |
binary, for any purpose, commercial or non-commercial, and by any | |
means. | |
In jurisdictions that recognize copyright laws, the author or authors | |
of this software dedicate any and all copyright interest in the | |
software to the public domain. We make this dedication for the benefit | |
of the public at large and to the detriment of our heirs and | |
successors. We intend this dedication to be an overt act of | |
relinquishment in perpetuity of all present and future rights to this | |
software under copyright law. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | |
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR | |
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR | |
OTHER DEALINGS IN THE SOFTWARE. | |
For more information, please refer to <http://unlicense.org/> | |
-------------------------------------------------------------------------------- | |
""" | |
import asyncio | |
import base64 | |
import hashlib | |
import random | |
import struct | |
import traceback | |
import logging | |
import typing | |
import io | |
import uuid | |
import time | |
import ssl | |
import urllib.parse | |
# Module-wide logger; connection, request and websocket events are reported through it.
logger = logging.getLogger(name=__name__)
def chunked(data, size):
    """Lazily split a sliceable sequence into consecutive pieces.

    Each piece is ``data[k:k + size]`` for k = 0, size, 2*size, ...; the last
    piece may be shorter than *size*, and empty *data* yields nothing.
    """
    offsets = range(0, len(data), size)
    yield from (data[offset:offset + size] for offset in offsets)
class HTTPException(Exception):
    """Raised for malformed HTTP requests and misuse of responses/websockets."""
class HTTPRequest:
    """A parsed HTTP/1.x request head together with its connection streams.

    Attributes:
        method: upper-cased request method, one of ``METHODS`` once parsed.
        path: percent-decoded request path, without the query string.
        version: ``(major, minor)`` tuple of ints, e.g. ``(1, 1)``.
        headers: list of ``(name, value)`` tuples in arrival order; names keep
            their original case and may repeat.
        query: result of ``urllib.parse.parse_qs`` (blank values kept), so each
            key maps to a list of string values.
        reader, writer: the asyncio streams the request arrived on.
        remote_addr: the writer's ``peername`` extra info, set by
            ``fromConnection``.
    """
    METHODS = ("GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS",
               "TRACE", "PATCH")

    def __init__(self, method = None, path = None, version = None, headers = None,
                 reader = None, writer = None):
        self.method = method or None
        self.path = path or None
        self.version = version or (None, None)
        self.headers = headers or []
        self.reader = reader or None
        self.writer = writer or None
        self.query = {}
        self.remote_addr = None

    def getStartLine(self):
        """Rebuild the request line (with the decoded path) for logging."""
        return "{} {} HTTP/{}".format(self.method, self.path, "{}.{}".format(*self.version))

    def parseStartLine(self, data):
        """Parse a request line such as ``GET /path?a=1 HTTP/1.1`` into this request.

        ``method``, ``path``, ``query`` and ``version`` are only assigned after
        the whole line has been validated.

        Raises:
            HTTPException: if the line is malformed, the method is not in
                ``METHODS``, or the version is not ``HTTP/<int>.<int>``.
        """
        # Method
        if " " not in data:
            raise HTTPException("Poorly formatted HTTP request")
        method, data = data.split(" ", 1)
        method = method.upper()
        if method not in self.METHODS:
            raise HTTPException("Invalid HTTP method")
        # Request target, split into path and optional query string
        if " " not in data:
            raise HTTPException("Poorly formatted HTTP request")
        target, data = data.split(" ", 1)
        path, hasQuery, rawQuery = target.partition("?")
        query = urllib.parse.parse_qs(rawQuery, keep_blank_values = True) if hasQuery else {}
        path = urllib.parse.unquote(path)
        # Version
        if not data.startswith("HTTP/"):
            raise HTTPException("Not a HTTP request")
        major, dot, minor = data[len("HTTP/"):].partition(".")
        if not dot:
            raise HTTPException("HTTP version malformed")
        try:
            version = (int(major), int(minor))
        except ValueError:
            # Also covers extra dots such as "HTTP/1.1.1", which leave "1.1" in minor.
            raise HTTPException("HTTP version contains non-integer values") from None
        self.method = method
        self.path = path
        self.query = query
        self.version = version

    def parseHeader(self, data):
        """Parse one ``Name: value`` header line and append it to ``headers``.

        Only a single space after the colon is stripped from the value.

        Raises:
            HTTPException: if the line has no colon.
        """
        if ":" not in data:
            raise HTTPException("Poorly formatted HTTP header")
        name, value = data.split(":", 1)
        if value.startswith(" "):
            value = value[1:]
        self.headers.append((name, value))

    def getHeader(self, name, all = False):
        """Return the value of header *name*, compared case-insensitively.

        With ``all`` false, return the first matching value or None; with
        ``all`` true, return a (possibly empty) list of every matching value.
        """
        wanted = name.lower()
        matches = [v for k, v in self.headers if k.lower() == wanted]
        if all:
            return matches
        return matches[0] if matches else None

    def getQuery(self, key, default = None, all = False):
        """Return query-string values for *key*.

        Returns *default* when the key is absent or has no values; otherwise the
        first value, or the full list of values when ``all`` is true.
        """
        values = self.query.get(key)
        if not values:
            return default
        return values if all else values[0]

    def __repr__(self):
        return "<{cls} {method} \"{path}\" HTTP/{version}>".format(
            cls = self.__class__.__name__,
            method = self.method,
            path = self.path,
            version = "{}.{}".format(*self.version)
        )

    @staticmethod
    async def _readLine(reader, timeout):
        """Read one CRLF-terminated line and return it as ASCII text without the CRLF.

        Raises:
            HTTPException: if the line contains non-ASCII bytes.
        """
        raw = await asyncio.wait_for(reader.readuntil(separator = b"\r\n"), timeout = timeout)
        try:
            return raw[:-2].decode("ascii")
        except UnicodeDecodeError:
            raise HTTPException("Non-ASCII data in HTTP request head") from None

    @classmethod
    async def fromConnection(cls, reader, writer, timeout = None, maxHeaderLength = 1024 * 64):
        """Read and parse a request head (start line plus headers) from *reader*.

        *timeout* applies to each line read. Header lines, counting their CRLF,
        may total at most *maxHeaderLength* bytes. Any request body is left
        unread on *reader*.

        Raises:
            HTTPException: on a malformed, non-ASCII or oversized request head.
            asyncio.IncompleteReadError, asyncio.LimitOverrunError,
            asyncio.TimeoutError: propagated from the stream reads.
        """
        self = cls(reader = reader, writer = writer)
        self.parseStartLine(await cls._readLine(reader, timeout))
        headerLen = 0
        while True:
            line = await cls._readLine(reader, timeout)
            if not line:
                break  # An empty line terminates the header section.
            headerLen += len(line) + 2
            if headerLen > maxHeaderLength:
                raise HTTPException("Header is too large")
            self.parseHeader(line)
        # NOTE(review): the original code listed CGI/1.1 meta-variable names here
        # (AUTH_TYPE, CONTENT_LENGTH, CONTENT_TYPE, GATEWAY_INTERFACE, PATH_INFO,
        # PATH_TRANSLATED, QUERY_STRING, REMOTE_ADDR, REMOTE_HOST, REMOTE_IDENT,
        # REMOTE_USER, REQUEST_METHOD, SCRIPT_NAME, SERVER_NAME, SERVER_PORT,
        # SERVER_PROTOCOL, SERVER_SOFTWARE), presumably as a TODO; nothing uses them.
        self.remote_addr = self.writer.get_extra_info('peername')
        return self
class HTTPResponse:
    """Response to an ``HTTPRequest``, written directly to ``request.writer``.

    Headers are an ordered list of ``(name, value)`` tuples and are sent the
    first time ``write`` (or ``writeHeaders``) is awaited.
    """
    STATUS_MESSAGES = {
        100: "Continue",
        101: "Switching Protocols",
        102: "Processing",
        103: "Early Hints",
        200: "OK",
        201: "Created",
        202: "Accepted",
        203: "Non-Authoritative Information",
        204: "No Content",
        205: "Reset Content",
        206: "Partial Content",
        207: "Multi-Status",
        208: "Already Reported",
        226: "IM Used",
        300: "Multiple Choices",
        301: "Moved Permanently",
        302: "Found",
        303: "See Other",
        304: "Not Modified",
        307: "Temporary Redirect",
        308: "Permanent Redirect",
        400: "Bad Request",
        401: "Unauthorized",
        402: "Payment Required",
        403: "Forbidden",
        404: "Not Found",
        405: "Method Not Allowed",
        406: "Not Acceptable",
        407: "Proxy Authentication Required",
        408: "Request Timeout",
        409: "Conflict",
        410: "Gone",
        411: "Length Required",
        412: "Precondition Failed",
        413: "Content Too Large",
        414: "URI Too Long",
        415: "Unsupported Media Type",
        416: "Range Not Satisfiable",
        417: "Expectation Failed",
        418: "I'm a teapot",
        421: "Misdirected Request",
        422: "Unprocessable Content",
        423: "Locked",
        424: "Failed Dependency",
        425: "Too Early",
        426: "Upgrade Required",
        428: "Precondition Required",
        429: "Too Many Requests",
        431: "Request Header Fields Too Large",
        451: "Unavailable For Legal Reasons",
        500: "Internal Server Error",
        501: "Not Implemented",
        502: "Bad Gateway",
        503: "Service Unavailable",
        504: "Gateway Timeout",
        505: "HTTP Version Not Supported",
        506: "Variant Also Negotiates",
        507: "Insufficient Storage",
        508: "Loop Detected",
        510: "Not Extended",
        511: "Network Authentication Required",
    }

    def __init__(self, request, status = None, message = None,
                 headers = None, version = None):
        self.request = request
        self.status = status or 200
        self._message = message or None
        self.headers = headers or []
        self.version = version or (1, 1)
        self.headersSent = False

    @property
    def message(self):
        """Reason phrase: the explicit message if set, else the standard one, else "Unknown Status"."""
        if self._message:
            return self._message
        if self.status in self.STATUS_MESSAGES:
            return self.STATUS_MESSAGES[self.status]
        return "Unknown Status"

    @message.setter
    def message(self, value):
        self._message = value

    @message.deleter
    def message(self):
        self._message = None

    def deleteHeader(self, name):
        """Remove every header called *name* (case-insensitive) in a single pass."""
        wanted = name.lower()
        # Slice assignment keeps the same list object, like the in-place removal it replaces.
        self.headers[:] = [h for h in self.headers if h[0].lower() != wanted]

    def addHeader(self, name, value):
        """Append a header, keeping any existing headers with the same name."""
        self.headers.append((name, value))

    def setHeader(self, name, value):
        """Replace all headers called *name* with a single ``(name, value)``."""
        self.deleteHeader(name)
        self.addHeader(name, value)

    async def writeHeaders(self, timeout = None):
        """Send the status line and all headers, then drain the writer.

        Raises:
            HTTPException: if the headers were already sent.
        """
        if self.headersSent:
            raise HTTPException("Cannot resend headers")
        lines = ["HTTP/{} {} {}".format("{}.{}".format(*self.version), self.status, self.message)]
        lines.extend("{}: {}".format(k, v) for k, v in self.headers)
        # Every line ends in CRLF, and an extra CRLF terminates the header block.
        head = "".join(line + "\r\n" for line in lines) + "\r\n"
        self.request.writer.write(head.encode())
        await asyncio.wait_for(self.request.writer.drain(), timeout = timeout)
        self.headersSent = True

    async def write(self, data = None, drain = True, timeout = None):
        """Send *data* (str is UTF-8 encoded), sending the headers first if needed.

        Empty or None *data* only makes sure the headers are sent. *timeout*
        bounds both the header write and the final drain.
        """
        if not self.headersSent:
            await self.writeHeaders(timeout = timeout)
        if data:
            if isinstance(data, str):
                data = data.encode()
            self.request.writer.write(data)
        if drain:
            await asyncio.wait_for(self.request.writer.drain(), timeout = timeout)

    def __repr__(self):
        return "<{cls} {status} {message}>".format(
            cls = self.__class__.__name__,
            status = self.status,
            message = self.message
        )
class HTTPWebsocketException(Exception):
    """Websocket protocol error that carries the close code to report to the peer.

    The keyword-only ``code`` argument is stored as ``closeCode`` (None when
    omitted); every other argument goes to ``Exception``.
    """

    def __init__(self, *args, code = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.closeCode = code
class HTTPWebsocketFrame:
    """One websocket message, read from or written to an asyncio stream.

    ``opcode`` is one of the ``OP_*`` constants; ``body`` is ``str`` for text
    messages and ``bytes`` otherwise. ``mask`` only affects ``toWriter``:
    ``True`` picks a random 32-bit key, another truthy int is used as the key,
    and a falsy value sends the payload unmasked (RFC 6455 requires servers
    not to mask).
    """
    # First header byte
    BIT_FIN = 0x80
    BIT_RSV1 = 0x40
    BIT_RSV2 = 0x20
    BIT_RSV3 = 0x10
    BIT_OPCODE = 0x0F
    # Second header byte
    BIT_MASK = 0x80
    BIT_SIZE = 0x7F

    OP_CONTINUATION = 0x00
    OP_TEXT = 0x01
    OP_BINARY = 0x02
    OP_CLOSE = 0x08
    OP_PING = 0x09
    OP_PONG = 0x0A

    CLOSE_NORMAL = 1000
    CLOSE_GOING_AWAY = 1001
    CLOSE_PROTOCOL_ERROR = 1002
    CLOSE_INVALID_DATA = 1003
    CLOSE_INCONSISTENT_DATA = 1007
    CLOSE_POLICY_VIOLATION = 1008
    CLOSE_TOO_BIG = 1009
    CLOSE_INTERNAL_ERROR = 1011

    # Payloads are read in chunks of this size. It must be a multiple of 4 so
    # every chunk starts at byte 0 of the 4-byte masking key.
    CHUNK_SIZE = 4096

    STRUCT_HEADER = struct.Struct(">BB")
    STRUCT_UINT16 = struct.Struct(">H")
    STRUCT_UINT32 = struct.Struct(">I")
    STRUCT_UINT64 = struct.Struct(">Q")

    def __init__(self, opcode = None, body = None, mask = True):
        self.opcode = opcode
        self.body = body
        self.mask = mask

    @staticmethod
    def _applyMask(data, maskBytes):
        """Return *data* XOR-ed with the repeating 4-byte key *maskBytes*, as bytes."""
        size = len(data)
        if size == 0:
            return b""
        key = (bytes(maskBytes) * (size // 4 + 1))[:size]
        # One big-integer XOR instead of a per-byte Python loop.
        return (int.from_bytes(data, "big") ^ int.from_bytes(key, "big")).to_bytes(size, "big")

    @classmethod
    async def fromReader(cls, reader, maxLength = 0xF_FFFF, timeout = None):
        """Read one complete message from *reader*, joining continuation frames.

        *timeout* bounds each individual read; *maxLength* bounds the combined
        payload of all fragments.

        Returns:
            A frame whose ``opcode`` is the first fragment's opcode and whose
            ``body`` is ``str`` for text messages and ``bytes`` otherwise.

        Raises:
            HTTPWebsocketException: CLOSE_INVALID_DATA for an unknown (or
                orphan continuation) first opcode, CLOSE_TOO_BIG when the
                payload exceeds *maxLength*, CLOSE_INCONSISTENT_DATA when a text
                message is not valid UTF-8.
        """
        self = cls()
        payload = io.BytesIO()
        isText = None  # Decided by the first fragment's opcode.
        total = 0

        async def readExactly(count):
            return await asyncio.wait_for(reader.readexactly(count), timeout = timeout)

        while True:
            head1, head2 = cls.STRUCT_HEADER.unpack(await readExactly(2))
            opcode = head1 & cls.BIT_OPCODE
            if isText is None:
                if opcode == cls.OP_TEXT:
                    isText = True
                elif opcode in (cls.OP_BINARY, cls.OP_CLOSE, cls.OP_PING, cls.OP_PONG):
                    isText = False
                else:
                    raise HTTPWebsocketException("Invalid opcode", code = cls.CLOSE_INVALID_DATA)
                self.opcode = opcode

            size = head2 & cls.BIT_SIZE
            if size == 126:
                size, = cls.STRUCT_UINT16.unpack(await readExactly(2))
            elif size == 127:
                size, = cls.STRUCT_UINT64.unpack(await readExactly(8))
            maskBytes = await readExactly(4) if head2 & cls.BIT_MASK else None

            # Limit the whole message, not just each fragment.
            total += size
            if total > maxLength:
                raise HTTPWebsocketException("Payload is too big", code = cls.CLOSE_TOO_BIG)

            remaining = size
            while remaining > 0:
                chunk = await readExactly(min(remaining, cls.CHUNK_SIZE))
                if maskBytes is not None:
                    chunk = cls._applyMask(chunk, maskBytes)
                payload.write(chunk)
                remaining -= len(chunk)

            if head1 & cls.BIT_FIN:
                break

        body = payload.getvalue()
        if isText:
            # Decode once at the end: a multi-byte UTF-8 sequence may straddle a
            # chunk or fragment boundary.
            try:
                body = body.decode("utf-8")
            except UnicodeDecodeError:
                raise HTTPWebsocketException(
                    "Invalid UTF-8 in text message",
                    code = cls.CLOSE_INCONSISTENT_DATA
                ) from None
        self.body = body
        return self

    async def toWriter(self, writer, timeout = None):
        """Write this frame to *writer* as a single, final (FIN) frame.

        Returns:
            The payload length in bytes.

        Raises:
            ValueError: if ``body`` is not str/bytes/bytearray, or the payload
                is 2**63 bytes or larger.
        """
        if isinstance(self.body, (bytes, bytearray)):
            data = bytes(self.body)
            head1 = (self.opcode or self.OP_BINARY) | self.BIT_FIN
        elif isinstance(self.body, str):
            data = self.body.encode()
            head1 = (self.opcode or self.OP_TEXT) | self.BIT_FIN
        else:
            raise ValueError("Cannot send type {}".format(type(self.body)))

        if self.mask is True:
            import secrets
            # RFC 6455 requires an unpredictable key. 0 is excluded because a
            # zero key is treated as "unmasked" below.
            mask = secrets.randbelow(0xFFFF_FFFF) + 1
        else:
            mask = self.mask or 0

        size = len(data)
        if size < 126:
            lengthField, extended = size, b""
        elif size <= 0xFFFF:
            lengthField, extended = 126, self.STRUCT_UINT16.pack(size)
        elif size <= 0x7FFF_FFFF_FFFF_FFFF:
            lengthField, extended = 127, self.STRUCT_UINT64.pack(size)
        else:
            raise ValueError("Payload is too big!")

        head2 = lengthField | (self.BIT_MASK if mask else 0)
        header = self.STRUCT_HEADER.pack(head1, head2) + extended
        if mask:
            maskBytes = self.STRUCT_UINT32.pack(mask)
            header += maskBytes
            data = self._applyMask(data, maskBytes)

        writer.write(header)
        if data:
            writer.write(data)
        await asyncio.wait_for(writer.drain(), timeout = timeout)
        return size
class StreamMemoryWriter:
    """Minimal StreamWriter-like adapter that buffers bytes in memory.

    ``write`` only buffers; ``drain`` hands everything buffered so far to the
    async ``writeFunc`` in a single call. ``loop`` is accepted for API
    compatibility and ignored.
    """

    def __init__(self, writeFunc, loop = None):
        # A StreamReader is reused purely as a byte buffer with EOF tracking.
        self._buffer = asyncio.streams.StreamReader()
        self._writeFunc = writeFunc

    def write(self, data):
        self._buffer.feed_data(data)

    def write_eof(self):
        return self._buffer.feed_eof()

    def can_write_eof(self):
        return not self._buffer.at_eof()

    def close(self):
        return self.write_eof()

    def is_closing(self):
        return self._buffer.at_eof()

    async def drain(self):
        """Pass all buffered bytes to ``writeFunc``; do nothing when the buffer is empty."""
        # NOTE(review): relies on StreamReader's private ``_buffer`` to size the pending data.
        pending = len(self._buffer._buffer)
        if not pending:
            return
        await self._writeFunc(await self._buffer.read(pending))
class HTTPWebsocket:
    """Server side of a websocket upgraded from an ``HTTPRequest``.

    Frames are read from ``request.reader`` and written to ``request.writer``.
    Every frame sent here is unmasked: RFC 6455 forbids servers from masking,
    and clients drop the connection when they receive a masked frame.
    """
    STRUCT_UINT16 = struct.Struct(">H")
    # GUID that RFC 6455 appends to Sec-WebSocket-Key before hashing.
    ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    def __init__(self, request, response):
        self.request = request
        self.response = response
        self.closed = False
        self.accepted = False
        # Outstanding pings as (tag, future, sentAt) tuples; a pong carrying the
        # same tag completes the future with True.
        self.pings = []
        self._reader = None
        self._readerTask = None
        self._writer = None
        self._writerTask = None

    async def accept(self):
        """Complete the handshake with a 101 response carrying Sec-WebSocket-Accept."""
        self.request.method = "WS"
        requestKey = self.request.getHeader("Sec-Websocket-Key")
        digest = hashlib.sha1((requestKey + self.ACCEPT_GUID).encode()).digest()
        self.response.status = 101
        self.response.setHeader("Upgrade", "websocket")
        self.response.setHeader("Connection", "Upgrade")
        self.response.setHeader("Sec-WebSocket-Accept", base64.b64encode(digest).decode())
        await self.response.write()
        self.accepted = True

    async def reject(self, reason = None):
        """Refuse the upgrade with a 403 whose body is *reason* (str or bytes)."""
        self.request.method = "WS"
        self.closed = True
        self.response.status = 403
        await self.response.write(reason or b"")

    def _resolvePong(self, tag):
        """Complete every outstanding ping whose tag equals *tag* with True."""
        for entry in list(self.pings):
            pingTag, future, _ = entry
            if pingTag == tag:
                self.pings.remove(entry)
                if not future.done():
                    future.set_result(True)

    async def recv(self, maxLength = 0xFF_FFFF, timeout = 5):
        """Read the next message frame, or return None if the websocket is closed.

        A pong also completes any pending ``ping`` with the same tag and is
        still returned to the caller. *timeout* bounds each read from the peer.

        Raises:
            HTTPWebsocketException: on a protocol error, after closing the
                websocket with the exception's close code.
        """
        if self.closed:
            return None
        try:
            frame = await HTTPWebsocketFrame.fromReader(self.request.reader, maxLength, timeout)
        except HTTPWebsocketException as e:
            if not self.closed:
                await self.close(code = e.closeCode, reason = str(e))
            raise
        if frame.opcode == HTTPWebsocketFrame.OP_PONG:
            self._resolvePong(frame.body)
        return frame

    async def send(self, frame, timeout = 5):
        """Write *frame* unless the websocket is closed; return whether it was sent."""
        if self.closed:
            return False
        await frame.toWriter(self.request.writer, timeout)
        return True

    async def write(self, data, timeout = 5):
        """Send *data* as a text (str) or binary (bytes) message.

        Raises:
            ValueError: for any other type.
        """
        if type(data) is str:
            opcode = HTTPWebsocketFrame.OP_TEXT
        elif type(data) is bytes:
            opcode = HTTPWebsocketFrame.OP_BINARY
        else:
            raise ValueError("Don't know how to encode type {}".format(type(data)))
        frame = HTTPWebsocketFrame(opcode, data, mask = False)
        return await self.send(frame, timeout)

    async def ping(self, tag = None, timeout = 5):
        """Send a ping and wait up to *timeout* seconds for the matching pong.

        Returns True when the pong arrives, or False if the websocket is (or
        becomes) closed. Pongs are only matched while something is reading:
        ``recv()`` or the ``reader`` property's background task.

        Raises:
            TimeoutError: "Ping failed" when no pong arrives in time.
        """
        if not tag:
            tag = str(uuid.uuid4()).encode()
        future = asyncio.get_running_loop().create_future()
        entry = (tag, future, time.time())
        self.pings.append(entry)
        try:
            frame = HTTPWebsocketFrame(HTTPWebsocketFrame.OP_PING, tag, mask = False)
            if not await self.send(frame, timeout):
                return False
            try:
                return await asyncio.wait_for(future, timeout)
            except asyncio.TimeoutError:
                raise TimeoutError("Ping failed") from None
        finally:
            if entry in self.pings:
                self.pings.remove(entry)

    async def pong(self, tag = None, timeout = 5):
        """Send a pong echoing *tag*; return whether it was sent."""
        frame = HTTPWebsocketFrame(HTTPWebsocketFrame.OP_PONG, tag or b"", mask = False)
        return await self.send(frame, timeout)

    async def close(self, code = None, reason = None):
        """Close the websocket and the underlying connection.

        When the handshake was accepted, a close frame carrying *code* (default
        CLOSE_NORMAL) and *reason* is sent first. Pending pings complete with
        False. The writer is closed even if sending the close frame fails.

        Raises:
            HTTPException: if the websocket is already closed.
        """
        if self.closed:
            raise HTTPException("Trying to close already closed websocket")
        self.closed = True
        if isinstance(reason, str):
            reason = reason.encode()
        elif not reason:
            reason = b""
        for _, future, _ in self.pings:
            if not future.done():
                future.set_result(False)
        writer = self.request.writer
        try:
            if self.accepted:
                payload = self.STRUCT_UINT16.pack(code or HTTPWebsocketFrame.CLOSE_NORMAL) + reason
                frame = HTTPWebsocketFrame(HTTPWebsocketFrame.OP_CLOSE, payload, mask = False)
                # send() refuses to write once self.closed is set, so write directly.
                await frame.toWriter(writer, 5)
        finally:
            logger.debug("Connection close from {}".format(self.request.remote_addr))
            logger.info("{} - \"{}\" {}".format(
                self.request.remote_addr,
                self.request.getStartLine(),
                self.response.status
            ))
            writer.close()
            await writer.wait_closed()

    @property
    def reader(self):
        """StreamReader fed with incoming message payloads (text is UTF-8 encoded).

        The first access starts a background task that answers pings, lets
        ``recv()`` match pongs, closes on a close frame, and feeds EOF when the
        websocket ends for any reason.
        """
        if self._reader is not None:
            return self._reader
        self._reader = asyncio.streams.StreamReader()

        async def readerTask():
            try:
                while not self.closed:
                    # No timeout: an idle peer is not an error for a long-lived reader.
                    res = await self.recv(timeout = None)
                    if res is None:
                        break
                    if res.opcode == HTTPWebsocketFrame.OP_PING:
                        await self.pong(res.body)
                    elif res.opcode == HTTPWebsocketFrame.OP_PONG:
                        pass  # Already matched against pending pings by recv().
                    elif res.opcode == HTTPWebsocketFrame.OP_CLOSE:
                        self._reader.feed_eof()
                        if not self.closed:
                            await self.close()
                    elif res.opcode == HTTPWebsocketFrame.OP_TEXT:
                        self._reader.feed_data(res.body.encode())
                    elif res.opcode == HTTPWebsocketFrame.OP_BINARY:
                        self._reader.feed_data(res.body)
                    else:
                        raise HTTPWebsocketException("Reader doesn't implement handling for packet type {}!".format(res.opcode))
            except (asyncio.IncompleteReadError, ConnectionError, HTTPWebsocketException) as e:
                logger.debug("Websocket reader for {} stopped: {!r}".format(self.request.remote_addr, e))
            finally:
                # Never leave consumers of the reader waiting forever.
                self._reader.feed_eof()

        self._readerTask = asyncio.create_task(readerTask())
        return self._reader

    @property
    def writer(self):
        """StreamWriter-like object whose drained bytes are sent as binary messages."""
        if self._writer:
            return self._writer

        async def writerTask(data):
            await self.write(data)

        self._writer = StreamMemoryWriter(writerTask)
        return self._writer
class HTTPServer:
    """Minimal asyncio HTTP/1.x server that serves one request per connection.

    Subclasses override ``handle_request(request)``; the default replies 404.
    The connection is closed once ``handle_request`` returns.
    """

    def __init__(self):
        pass

    async def start_server(self, host = "0.0.0.0", port = 80, ssl = None):
        """Listen on *host*:*port* (TLS when *ssl* is an SSLContext) and return the asyncio Server."""
        server = await asyncio.start_server(self.create_connection, host, port, ssl = ssl)
        return server

    async def run(self, *args, **kwargs):
        """Start via ``start_server(*args, **kwargs)`` and serve until cancelled."""
        instance = await self.start_server(*args, **kwargs)
        return await instance.serve_forever()

    async def create_connection(self, reader, writer):
        """``asyncio.start_server`` callback invoked for each accepted client.

        asyncio already runs this coroutine in its own task, so the handler is
        awaited directly rather than spawning an extra task that nothing keeps
        a reference to (such tasks may be garbage-collected mid-run).
        """
        await self.handle_connection(reader, writer)

    async def handle_connection(self, reader, writer):
        """Parse one request, dispatch it to ``handle_request``, then close the connection.

        Unparseable request heads, and clients that disconnect or overflow the
        stream limit while sending one, are logged at debug level and dropped.
        Exceptions from ``handle_request`` are logged with their traceback.
        The writer is always closed.
        """
        peer = writer.get_extra_info('peername')
        logger.debug("Connection opened from {}".format(peer))
        try:
            try:
                request = await HTTPRequest.fromConnection(reader, writer)
            except (HTTPException, ValueError, asyncio.IncompleteReadError,
                    asyncio.LimitOverrunError, asyncio.TimeoutError, ConnectionError):
                # ValueError also covers UnicodeDecodeError from non-ASCII request heads.
                logger.debug("Invalid HTTP connection from {}:\n{}".format(peer, traceback.format_exc()))
                return
            logger.info("{} - \"{}\"".format(request.remote_addr, request.getStartLine()))
            try:
                await self.handle_request(request)
            except Exception:
                logger.error(
                    "{} - \"{}\" ERROR: {}".format(
                        request.remote_addr,
                        request.getStartLine(),
                        "\n".join([" "+line for line in traceback.format_exc().splitlines()])
                    )
                )
        finally:
            logger.debug("Connection close from {}".format(peer))
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                # The peer may already have reset the connection; nothing left to release.
                pass

    async def handle_request(self, request):
        """Default handler: reply 404 Not Found."""
        response = HTTPResponse(request, 404)
        await response.write()
class HTTPServerPaths(HTTPServer):
    """HTTPServer that routes requests by exact ``(method, path)`` match.

    Handlers are registered with ``register``. A handler registered for method
    ``"WS"`` is called as ``handler(request, websocket)`` for websocket upgrade
    requests to that path; every other handler gets ``(request, response)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps (method, path) to an async handler.
        self.paths = {}

    @staticmethod
    def _isWebsocketUpgrade(request):
        """Return True if *request* is a GET asking to upgrade to a websocket.

        ``Connection`` is a comma-separated token list (e.g. ``keep-alive,
        Upgrade``), so it is checked token by token. Missing headers count as empty.
        """
        if request.method != "GET":
            return False
        connection = request.getHeader("connection") or ""
        tokens = {token.strip().lower() for token in connection.split(",")}
        upgrade = (request.getHeader("upgrade") or "").strip().lower()
        return "upgrade" in tokens and upgrade == "websocket"

    async def handle_request(self, request):
        route = (request.method, request.path)
        if route in self.paths:
            response = HTTPResponse(request, 200)
            await self.paths[route](request, response)
            await response.write()
        elif self._isWebsocketUpgrade(request) and ("WS", request.path) in self.paths:
            response = HTTPResponse(request, 200)
            logger.debug("{} requesting websocket upgrade".format(request.remote_addr))
            if request.getHeader("Sec-Websocket-Key") is None:
                logger.debug("{} missing Sec-Websocket-Key".format(request.remote_addr))
                response.status = 400
                await response.write(b"Missing Sec-Websocket-Key header")
            else:
                logger.debug("Initializing websocket handler for {}".format(request.remote_addr))
                websocket = HTTPWebsocket(request, response)
                await self.paths[("WS", request.path)](request, websocket)
                return  # Closing of the socket is handled in HTTPWebsocket
        else:
            response = HTTPResponse(request, 404)
            await response.write()

    def register(self, method, path):
        """Decorator registering an async handler for *method* and *path*.

        The handler is returned unchanged, so the decorated name stays usable.
        """
        def decorator(func):
            self.paths[(method, path)] = func
            return func
        return decorator
class HTTPServerConstructor:
    """Dispatches each request to a ``handle_<METHOD>`` coroutine method on ``self``.

    It only provides ``handle_request``, so it is presumably meant to be
    combined with ``HTTPServer``. Methods without a handler get a 404 response.
    """

    async def handle_request(self, request):
        # Look the handler up first so an AttributeError raised *inside* a
        # handler propagates instead of being mistaken for "no such handler".
        handler = getattr(self, "handle_" + request.method, None)
        if handler is None:
            response = HTTPResponse(request, 404)
            await response.write()
            return
        await handler(request)
class WSServer(HTTPServer):
    """HTTPServer that only accepts websocket upgrade requests.

    Subclasses must define ``async handle_websocket(request, websocket)``;
    any request that is not a websocket upgrade gets a 400 response.
    """

    @staticmethod
    def _isWebsocketUpgrade(request):
        """Return True if *request* is a GET asking to upgrade to a websocket.

        ``Connection`` is a comma-separated token list (e.g. ``keep-alive,
        Upgrade``), so it is checked token by token. Missing headers count as empty.
        """
        if request.method != "GET":
            return False
        connection = request.getHeader("connection") or ""
        tokens = {token.strip().lower() for token in connection.split(",")}
        upgrade = (request.getHeader("upgrade") or "").strip().lower()
        return "upgrade" in tokens and upgrade == "websocket"

    async def handle_request(self, request):
        if self._isWebsocketUpgrade(request):
            response = HTTPResponse(request, 200)
            logger.debug("{} requesting websocket upgrade".format(request.remote_addr))
            if request.getHeader("Sec-Websocket-Key") is None:
                logger.debug("{} missing Sec-Websocket-Key".format(request.remote_addr))
                response.status = 400
                await response.write(b"Missing Sec-Websocket-Key header")
            else:
                logger.debug("Initializing websocket handler for {}".format(request.remote_addr))
                websocket = HTTPWebsocket(request, response)
                await self.handle_websocket(request, websocket)
                return  # Closing of the socket is handled in HTTPWebsocket
        else:
            response = HTTPResponse(request, 400)
            await response.write(b"Websocket only server")
def test():
    """Run a demo server on port 8081.

    Routes:
        GET /              -> "Hello, world!"
        GET /websocket     -> an HTML page containing a small websocket client
        WS  /websocket/ws  -> a handler running a string and a byte round-trip test
    """
    # Uncomment for verbose connection logging:
    #logging.basicConfig(level=logging.DEBUG)
    server = HTTPServerPaths()
    page = """<!DOCTYPE html>
<html>
<head>
<title>Websocket test</title>
<script>
let socket, input, output;
function log(msg){
if(output.value.length != 0)
output.value += "\\n";
output.value += msg;
}
function connect(){
socket = new WebSocket("ws://"+document.location.host+"/websocket/ws");
socket.binaryType = "arraybuffer";
socket.onopen = (event) => {
console.log("Open", event);
log("Connected");
};
socket.onclose = (event) => {
console.log("Close", event);
log(`Disconnected (${event.wasClean}): ${event.code} ${event.reason}`);
socket = null;
};
socket.onerror = (event) => {
console.log("Error", event);
log(`Error`);
};
socket.onmessage = (event) => {
console.log("Message", event);
if(typeof event.data === "string"){
if(event.data == "StringTest"){
socket.send("StringResponse");
}
}else if(event.data instanceof ArrayBuffer){
if(new TextDecoder('utf-8').decode(event.data) == "ByteTest"){
socket.send(new TextEncoder('utf-8').encode("ByteResponse"));
}
}
log("Message "+event.data);
};
}
function submit(){
socket.send(input.value);
}
function reconnect(){
if(socket){
socket.close();
socket.addEventListener("close", _=>{
connect();
});
socket = null;
}else{
connect();
}
}
document.addEventListener("DOMContentLoaded", _=>{
input = document.getElementById("input");
output = document.getElementById("output");
document.getElementById("submit").addEventListener("click", _=>{
submit();
});
document.getElementById("reconnect").addEventListener("click", _=>{
reconnect();
});
});
</script>
</head>
<body>
<button id="reconnect">Connect</button><br/>
<textarea id="output" rows="25" cols="80" style="font-family:monospace"></textarea>
<br/>
<input type="text" id="input" size="80" style="font-family:monospace"></input> <button id="submit">Send</button>
<br/>
</body>
</html>"""

    async def hello(request, response):
        await response.write("Hello, world!")

    async def websocketPage(request, response):
        response.setHeader("Content-Type", "text/html")
        await response.write(page)

    async def websocketEndpoint(request, websocket):
        print("Websocket begin")
        await websocket.accept()
        await websocket.write("Connected")
        await websocket.write("StringTest")
        # Wait until the client answers the string challenge.
        while (await websocket.recv()).body != "StringResponse":
            pass
        print("String test pass")
        await websocket.write("String test pass")
        # Repeat the byte challenge until the client answers it.
        while True:
            await websocket.write(b"ByteTest")
            reply = await websocket.recv()
            if reply.body == b"ByteResponse":
                break
        print("Byte test pass")
        await websocket.write("Byte test pass")
        await websocket.write("Disconnecting")
        await websocket.close()
        print("Websocket end")

    server.register("GET", "/")(hello)
    server.register("GET", "/websocket")(websocketPage)
    server.register("WS", "/websocket/ws")(websocketEndpoint)
    asyncio.run(server.run(port = 8081))


if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment