Skip to content

Instantly share code, notes, and snippets.

@denik
Forked from progrium/upgrade_example.py
Created August 18, 2011 06:24
Show Gist options
  • Save denik/1153485 to your computer and use it in GitHub Desktop.
Save denik/1153485 to your computer and use it in GitHub Desktop.
Upgradable WSGI server websocket example
import gevent.pywsgi
from websocket import WebSocketUpgrader
class UpgradableWSGIHandler(gevent.pywsgi.WSGIHandler):
    """WSGI request handler that can hand the raw socket to an upgrade handler.

    When the request asks for a connection upgrade and the server exposes an
    ``upgrade_handler`` callable, that callable is invoked as
    ``upgrade_handler(protocol, environ)``. If it returns a truthy handler, the
    handler is called as ``handler(socket, environ)`` and takes over the
    connection. Otherwise the request is served as ordinary WSGI.
    """

    def handle_one_response(self):
        """Dispatch the current request to an upgrade handler or to WSGI."""
        # getattr keeps this handler usable with a server class that never
        # defined ``upgrade_handler``.
        upgrade_handler = getattr(self.server, 'upgrade_handler', None)
        if upgrade_handler and self._wants_upgrade():
            upgrade_header = self.environ.get('HTTP_UPGRADE', '').lower()
            handler = upgrade_handler(upgrade_header, self.environ)
            if handler:
                try:
                    handler(self.socket, self.environ)
                finally:
                    # make sure WSGIHandler stops processing requests, even
                    # when the upgrade handler raised
                    self.rfile.close()
                return
        gevent.pywsgi.WSGIHandler.handle_one_response(self)

    def _wants_upgrade(self):
        """Return True if the Connection header lists the ``upgrade`` token."""
        # Connection is a comma-separated token list; some clients send
        # "keep-alive, Upgrade", so an exact string comparison is too strict.
        connection_header = self.environ.get('HTTP_CONNECTION', '').lower()
        tokens = [token.strip() for token in connection_header.split(',')]
        return 'upgrade' in tokens
class UpgradableWSGIServer(gevent.pywsgi.WSGIServer):
    """WSGIServer that carries an optional connection-upgrade callback.

    ``upgrade_handler`` is stored on the server instance. Request handlers
    read it to decide whether a request should leave the WSGI pipeline.
    All other arguments are forwarded to ``gevent.pywsgi.WSGIServer``.
    """

    handler_class = UpgradableWSGIHandler

    def __init__(self, listener, application=None, backlog=None, spawn='default', log='default',
                 handler_class=None, environ=None, upgrade_handler=None, **ssl_args):
        # Forward by keyword: the positional order of the parent constructor
        # is not stable across gevent releases (newer ones take an extra
        # ``error_log`` parameter after ``log``).
        gevent.pywsgi.WSGIServer.__init__(
            self,
            listener,
            application=application,
            backlog=backlog,
            spawn=spawn,
            log=log,
            handler_class=handler_class,
            environ=environ,
            **ssl_args
        )
        self.upgrade_handler = upgrade_handler
def wsgi_app(env, start_response):
    """Fallback WSGI application for requests that are not upgraded.

    Starts a ``200 OK`` response with no headers and returns a one-item body.
    """
    status, response_headers = "200 OK", []
    start_response(status, response_headers)
    body = ["regular http"]
    return body
def upgrade(protocol, environ):
    """Pick the upgrade handler for ``protocol``.

    Returns a ``WebSocketUpgrader`` wrapping ``websocket_app`` when the
    requested protocol is ``'websocket'``; returns None for anything else.
    ``environ`` is accepted but not used.
    """
    if protocol != 'websocket':
        return None
    return WebSocketUpgrader(websocket_app)
def websocket_app(websocket):
    """Echo a single message back to the peer, then close the websocket."""
    websocket.send(websocket.receive())
    websocket.close()
# Serve the example on localhost:9099: ordinary requests go to wsgi_app,
# upgrade requests are offered to upgrade().
server = UpgradableWSGIServer(
    listener=('127.0.0.1', 9099),
    application=wsgi_app,
    upgrade_handler=upgrade,
)
server.serve_forever()
import re
import struct
from hashlib import md5
from socket import error
from gevent.pywsgi import WSGIHandler
from gevent.event import Event
from gevent.coros import Semaphore
# This class implements the WebSocket protocol draft version as of May 23, 2010.
# The version as of August 6, 2010 will be implemented once Firefox or
# WebKit-trunk supports this version.
class WebSocketError(error):
    """Raised for WebSocket handshake failures; a subclass of socket.error."""
class WebSocketUpgrader(object):
    """Automatically upgrades the connection to websockets.

    An instance wraps ``handler``, a callable taking one ``WebSocket``.
    Calling the instance with ``(socket, environ)`` writes the 101 handshake
    response, runs ``handler`` and then blocks until the websocket's
    ``finished`` event is set.

    The per-connection state (``socket``, ``environ``, ``websocket``) is kept
    on the instance, so one instance serves one connection at a time.
    """

    def __init__(self, handler):
        self.handler = handler

    def __call__(self, socket, environ):
        """Perform the handshake on ``socket`` and run the wrapped handler.

        Raises WebSocketError if the handshake keys are missing or malformed.
        """
        self.socket = socket
        self.environ = environ
        self.websocket = WebSocket(socket, environ)

        # Mirror the request scheme: a client connected over TLS expects a
        # wss:// location. ``wsgi.url_scheme`` is the standard WSGI key; when
        # it is absent the original ws:// is used.
        if environ.get('wsgi.url_scheme') == 'https':
            scheme = 'wss'
        else:
            scheme = 'ws'
        location = "%s://%s%s" % (scheme, environ.get('HTTP_HOST'), self.websocket.path)

        headers = [
            ("Upgrade", "WebSocket"),
            ("Connection", "Upgrade"),
        ]

        # Detect the Websocket protocol: only the newer draft (76) sends the
        # Sec-WebSocket-Key1 header; anything else is treated as draft 75.
        if "HTTP_SEC_WEBSOCKET_KEY1" in environ:
            # Compute the challenge first so a bad key fails before any
            # response bytes have been written.
            challenge = self._get_challenge()
            headers.extend([
                ("Sec-WebSocket-Origin", self.websocket.origin),
                ("Sec-WebSocket-Protocol", self.websocket.protocol),
                ("Sec-WebSocket-Location", location),
            ])
            self.start_response("101 Web Socket Protocol Handshake", headers)
            self.socket.sendall(challenge)
        else:
            headers.extend([
                ("WebSocket-Origin", self.websocket.origin),
                ("WebSocket-Protocol", self.websocket.protocol),
                ("WebSocket-Location", location),
            ])
            self.start_response("101 Web Socket Protocol Handshake", headers)

        self.handler(self.websocket)
        # The handler may return before the conversation is over; keep the
        # connection until the websocket is closed.
        self.websocket.finished.wait()

    def start_response(self, status, headers):
        """Write an HTTP/1.1 status line and ``headers`` to the socket.

        ``headers`` is an iterable of ``(name, value)`` pairs.
        """
        towrite = ['HTTP/1.1 %s\r\n' % status]
        for header in headers:
            towrite.append("%s: %s\r\n" % header)
        towrite.append("\r\n")
        self.socket.sendall(''.join(towrite))

    def _get_key_value(self, key_value):
        """Decode one Sec-WebSocket-Key header into its integer value.

        The value is the number formed by the digits of the header divided
        by the number of spaces in it.

        Raises WebSocketError if the header has no digits or no spaces, if
        the number is not a multiple of the space count, or if the result
        does not fit in an unsigned 32-bit integer.
        """
        digits = re.sub(r"\D", "", key_value)
        spaces = key_value.count(" ")
        if not digits:
            raise WebSocketError("key %r contains no digits" % (key_value,))
        if spaces == 0:
            # Without this check the modulo below divides by zero.
            raise WebSocketError("key %r contains no spaces" % (key_value,))

        key_number = int(digits)
        if key_number % spaces != 0:
            raise WebSocketError("key_number %d is not an intergral multiple of"
                                 " spaces %d" % (key_number, spaces))

        # Floor division keeps the result an integer on every Python version.
        value = key_number // spaces
        if value > 0xFFFFFFFF:
            # struct.pack("!I", ...) would fail with struct.error otherwise.
            raise WebSocketError("key value %d does not fit in 32 bits" % value)
        return value

    def _get_challenge(self):
        """Return the 16-byte MD5 challenge response for a draft-76 handshake.

        Raises WebSocketError if either key header is missing or invalid.
        """
        key1 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY1')
        key2 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY2')
        if not key1:
            raise WebSocketError("SEC-WEBSOCKET-KEY1 header is missing")
        if not key2:
            raise WebSocketError("SEC-WEBSOCKET-KEY2 header is missing")

        part1 = self._get_key_value(key1)
        part2 = self._get_key_value(key2)
        # This request should have 8 bytes of data in the body
        key3 = self.environ.get('wsgi.input').rfile.read(8)
        return md5(struct.pack("!II", part1, part2) + key3).digest()
class WebSocket(object):
    """One accepted websocket connection using 0x00 ... 0xFF text framing.

    Attributes set by the constructor:
      socket    -- the underlying socket (None once closed or detached)
      rfile     -- buffered binary reader created from the socket
      origin    -- value of the Origin request header, or None
      protocol  -- value of Sec-WebSocket-Protocol, 'unknown' if absent
      path      -- PATH_INFO of the request
      finished  -- Event that is set when close() has run
    """

    def __init__(self, sock, environ):
        self.rfile = sock.makefile('rb', -1)
        self.socket = sock
        self.origin = environ.get('HTTP_ORIGIN')
        self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'unknown')
        self.path = environ.get('PATH_INFO')
        self._writelock = Semaphore(1)
        self.finished = Event()

    def send(self, message):
        """Send ``message`` as one text frame.

        ``message`` must be a unicode or str object; any other type raises
        Exception("Invalid message encoding"). A str is passed through
        unicode() before being encoded as UTF-8.
        """
        if isinstance(message, unicode):
            message = message.encode('utf-8')
        elif isinstance(message, str):
            message = unicode(message).encode('utf-8')
        else:
            raise Exception("Invalid message encoding")

        # Serialize writers so frames from concurrent senders do not interleave.
        with self._writelock:
            self.socket.sendall("\x00" + message + "\xFF")

    def detach(self):
        """Drop the references to the socket and reader without closing them."""
        self.socket = None
        self.rfile = None
        self.handler = None

    def close(self):
        """Close the connection (best effort) and set ``finished``.

        Safe to call more than once.
        """
        # TODO implement graceful close with 0xFF frame
        if self.socket is not None:
            # Close the reader as well, so it is released explicitly instead
            # of merely dropped by detach().
            for resource in (self.rfile, self.socket):
                if resource is None:
                    continue
                try:
                    resource.close()
                except Exception:
                    # Best effort: the peer may already have gone away.
                    pass
            self.detach()
        self.finished.set()

    def _message_length(self):
        """Read a length prefix made of 7-bit groups, most significant first.

        A byte with the high bit clear ends the prefix. Returns 0 if the
        stream ends before the prefix is complete.
        """
        # TODO: buildin security agains lengths greater than 2**31 or 2**32
        length = 0
        while True:
            byte_str = self.rfile.read(1)
            if not byte_str:
                return 0
            byte = ord(byte_str)
            # Every byte contributes its low 7 bits, including 0x00; skipping
            # zero bytes would turn 0x81 0x00 (128) into a wrong length.
            length = length * 128 + (byte & 0x7f)
            if (byte & 0x80) != 0x80:
                return length

    def _read_until(self):
        """Read up to the next 0xFF terminator and return the bytes before it.

        Returns None if the stream ends before a terminator is seen.
        """
        chunks = []
        while True:
            byte = self.rfile.read(1)
            if not byte:
                # EOF in the middle of a frame; ord('') would raise here.
                return None
            if byte == b'\xff':
                return b''.join(chunks)
            chunks.append(byte)

    def receive(self):
        """Return the next text message as unicode, or None once closed.

        Length-prefixed 0xFF frames are read and discarded. A zero length,
        an unsupported frame type or a lost connection closes the websocket.
        """
        while self.socket is not None:
            frame_str = self.rfile.read(1)
            if not frame_str:
                # Connection lost?
                self.close()
                break

            frame_type = ord(frame_str)
            if frame_type == 0x00:
                payload = self._read_until()
                if payload is None:
                    # Connection lost in the middle of a frame.
                    self.close()
                    break
                return payload.decode("utf-8", "replace")

            if frame_type != 0xff:
                # Neither a text frame nor a 0xFF frame: not supported.
                self.close()
                break

            # Read binary data (forward-compatibility)
            length = self._message_length()
            if length == 0:
                self.close()
                break
            self.rfile.read(length)  # discard the bytes
        return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment