Created
October 31, 2018 09:36
-
-
Save adiroiban/a3dae64d14a56e8c0dbfea4bc6b4b827 to your computer and use it in GitHub Desktop.
Code to help write tests for the Twisted HTTP client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2014 Adi Roiban. | |
# See LICENSE for details. | |
""" | |
Helpers for testing code related to HTTP protocol. | |
""" | |
from __future__ import absolute_import

from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from httplib import BadStatusLine
from httplib import HTTPConnection, HTTPSConnection
from select import error as SelectError
from StringIO import StringIO
from threading import Thread
import codecs
import errno
import json
import socket
import ssl
import threading
import time
import urllib

from twisted.internet import defer, interfaces as internet_interfaces
from twisted.python.failure import Failure
from twisted.web import (
    http,
    server as web_server,
    )
from twisted.web.client import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer
from zope.interface import implementer
# Marker used to signal that the expected request will have a streamed body.
# The agent then requires the request body to be an IBodyProducer.
STREAMED_REQUEST = object()
# Marker used as `response_body` to signal that the response body is not
# delivered at once; the body producer is kept as `body_producer` on the
# response so that data can be sent to the consumer later.
STREAMED_RESPONSE = object()
@implementer(internet_interfaces.IPushProducer)
class InMemoryBodyDeliverer(object):
    """
    Push producer which lets a test feed response body data to a consumer
    protocol on demand.

    The consumer receives this instance as its `transport`, so it can pause,
    resume or stop the delivery the same way it would on a real connection.
    """

    def __init__(self, consumer):
        self._consumer = consumer
        consumer.transport = self
        self.paused = False
        self.stopped = False

    def pauseProducing(self):
        """
        Flag the delivery as paused.
        """
        self.paused = True

    def resumeProducing(self):
        """
        Clear the paused flag.
        """
        self.paused = False

    def stopProducing(self):
        """
        Flag the delivery as stopped.
        """
        self.stopped = True

    def sendData(self, data):
        """
        Hand `data` to the consumer, as if it just arrived with the response.

        Raises AssertionError if the consumer stopped or paused the producer.
        """
        failure_reason = None
        if self.stopped:
            failure_reason = 'Data received after producer was stopped.'
        elif self.paused:
            failure_reason = 'Data received while paused.'
        if failure_reason is not None:
            raise AssertionError(failure_reason)
        self._consumer.dataReceived(data)

    def stopSending(self):
        """
        Tell the consumer that the whole body was delivered, the same way a
        closed connection does: with a ResponseDone failure.
        """
        end_of_body = Failure(ResponseDone())
        self._consumer.connectionLost(end_of_body)
class InMemoryResponse(object):
    """
    Simple response from in memory agent request.

    It holds both the expectations for the request (method, headers and
    body) and the arranged response (code, phrase, headers and body).
    """

    def __init__(
        self, method='GET', code=404, phrase='Phrase not set',
        request_body=None, request_json=None, request_headers=None,
        response_body=b'', response_headers=None,
    ):
        """
        * method - request method expected for this response.
        * request_body - exact expected request body, or STREAMED_REQUEST
          when the request body is expected to come from a body producer.
        * request_json - used when `request_body` is not set; it is
          serialized with `json.dumps` and used as the expected body.
        * request_headers - expected request headers.
        * response_body - bytes delivered as the response body, or
          STREAMED_RESPONSE to keep the producer in `body_producer`.
        * response_headers - headers of the response.
        """
        self.method = method
        self.code = code
        self.phrase = phrase
        self.request_headers = Headers(request_headers)
        if request_body is None:
            if request_json is not None:
                request_body = json.dumps(request_json)
            else:
                request_body = ''
        self.request_body = request_body
        self.response_body = response_body
        # `None` gives empty headers, without sharing a mutable default.
        self.headers = Headers(response_headers)
        self._stream_request = False
        self._streamed_body = None
        # Set by deliverBody() when the response body is streamed.
        self.body_producer = None

    @property
    def streamed_body(self):
        """
        The content of the request as it was streamed.

        `None` when nothing was streamed.
        """
        return self._streamed_body

    def write(self, data):
        """
        Collect a chunk of a streamed request body.

        The agent passes this response to the request body producer's
        `startProducing`, so the producer writes into it.
        """
        if self.request_body is not STREAMED_REQUEST:
            raise AssertionError('Write while not in stream mode.')
        if not isinstance(data, bytes):
            raise AssertionError('Write operation for non bytes.')
        if self._streamed_body is None:
            self._streamed_body = data
        else:
            self._streamed_body += data

    def deliverBody(self, consumer_protocol):
        """
        See: twisted.web.client.Response.deliverBody.

        For STREAMED_RESPONSE the data is not delivered here; the producer
        is stored in `body_producer` so that data can be sent later.
        Otherwise the whole body is delivered followed by ResponseDone.
        """
        producer = InMemoryBodyDeliverer(consumer_protocol)
        if self.response_body is STREAMED_RESPONSE:
            self.body_producer = producer
            return
        consumer_protocol.dataReceived(self.response_body)
        consumer_protocol.connectionLost(Failure(ResponseDone()))
class InMemoryPersistentAgent(object):
    """
    Simple implementation of an HTTP Persistent Agent.

    It will not emit any event and does not check SSL/TLS.

    Responses are arranged in `responses`, a dict mapping a URL to a list
    of responses (typically InMemoryResponse instances), Exception
    instances or Deferreds firing with a response. Each request to a URL
    consumes the first item of its list.
    """

    def __init__(self, event_emitter=None, ssl_context=None, credentials=None):
        # `event_emitter` and `ssl_context` are accepted for interface
        # compatibility and are not used.
        self._connections = {}
        self._credentials = credentials
        self.responses = {}

    def _getResponse(self, method, url, headers, body):
        """
        Return a deferred firing with the arranged response for `url`.

        Raises AssertionError right away when no response was arranged for
        `url`. The returned deferred fails with AssertionError when the
        request method, body or headers do not match the arranged response.
        """
        if body is None:
            body = ''
        self._cacheURL(url)
        headers = Headers(headers)
        try:
            arranged = self.responses[url].pop(0)
        except (KeyError, IndexError):
            raise AssertionError(
                'URL not defined %s for %s' % (method, url))

        if isinstance(arranged, Exception):
            return defer.fail(arranged)

        if isinstance(arranged, defer.Deferred):
            deferred = arranged
        else:
            deferred = defer.succeed(arranged)

        def cb_check_response(response):
            """
            Called when we got a response, to check the request against it.
            """
            if response.method != method:
                raise AssertionError(
                    'No %s request was defined for %s' % (method, url))

            if IBodyProducer.providedBy(body):
                # We have a streamed request, so the body is checked later.
                if response.request_body is not STREAMED_REQUEST:
                    raise AssertionError(
                        'Unexpected streamed request. Expecting: %s' % (
                            response.request_body,))
                done_deferred = body.startProducing(response)
                response.start_producing_deferred = done_deferred
            else:
                if response.request_body is STREAMED_REQUEST:
                    raise AssertionError(
                        'Expecting streamed request. Got: %r' % (body,))
                # Just compare the whole body.
                if response.request_body != body:
                    raise AssertionError(
                        'Invalid request body:\ngot\n%s\nexpecting\n%s' % (
                            body, response.request_body))

            if response.request_headers != headers:
                raise AssertionError(
                    'Invalid headers:\ngot\n%s\nexpecting\n%s' % (
                        headers, response.request_headers))

            # All good.
            return response

        deferred.addCallback(cb_check_response)
        return deferred

    def _cacheURL(self, url):
        """
        Called to mark that URL is cached.
        """
        # For connection caching we cache only based on the
        # scheme://host part of the URL.
        parts = url.split('/', 3)
        if len(parts) > 3:
            host = '/'.join(parts[:-1])
        else:
            host = url
        self._connections[host] = b'cached'

    def postJSON(self, url, body, headers=None):
        """
        See: PersistentAgent.

        `body` is sent as UTF-8 encoded JSON and Content-Type is always
        set to JSON. The caller's `headers` dict is not modified.
        """
        # Work on a copy so the Content-Type does not leak into the caller's
        # dict, which might be reused for other requests.
        headers = dict(headers or {})
        headers['Content-Type'] = [b'application/json; charset=utf-8']
        return self.post(url, headers, json.dumps(body).encode('utf-8'))

    def post(self, url, headers, body):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'POST', url, headers, body)

    def put(self, url, headers, body):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'PUT', url, headers, body)

    def get(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'GET', url, headers, body=None)

    def delete(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'DELETE', url, headers, body=None)

    def propfind(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'PROPFIND', url, headers, body=None)

    def mkcol(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'MKCOL', url, headers, body=None)

    def head(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'HEAD', url, headers, body=None)

    def move(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'MOVE', url, headers, body=None)

    def readBody(self, response):
        """
        Return a deferred which fires with response body.
        """
        return defer.succeed(response.response_body)

    def closePersistentConnections(self):
        """
        See: PersistentAgent.
        """
        self._connections = {}

    @property
    def connections(self):
        """
        Return the persistent connections, as a dict keyed by host.
        """
        return self._connections
class DummyHTTPChannel(object):
    """
    In-memory stand-in for the channel/transport of a Twisted Web request.

    Data written to the channel is collected in `written` and registered
    producers in `producers`.
    """
    port = 80
    disconnected = False

    def __init__(self, site=None, peer=None, host=None, resource=None):
        """
        * site - the web site; when missing a new `Site` is created for
          `resource`.
        * peer - remote address. Defaults to 42.0.0.1 port 1234.
        * host - local address. Defaults to 10.0.0.1 on the class `port`.
        """
        from chevah.server.testing import mk
        self.written = StringIO()
        self.producers = []
        if peer is None:
            peer = mk.makeIPv4Address(host='42.0.0.1', port=1234)
        if host is None:
            host = mk.makeIPv4Address(host='10.0.0.1', port=self.port)
        if site is None:
            site = web_server.Site(resource)
        self.site = site
        self._peer = peer
        self._host = host

    def getPeer(self):
        return self._peer

    def getHost(self):
        return self._host

    def write(self, data):
        """
        Record `data` as written to the peer. Only bytes are accepted.
        """
        if not isinstance(data, bytes):
            raise AssertionError('Non bytes write request.')
        # Store the payload itself (previously the `bytes` type was written).
        self.written.write(data)

    def writeSequence(self, iovec):
        """
        Write each chunk from `iovec`, in order.
        """
        # Explicit loop: `map` is lazy on Python 3 and would write nothing.
        for data in iovec:
            self.write(data)

    def registerProducer(self, producer, streaming):
        self.producers.append((producer, streaming))

    def loseConnection(self):
        self.disconnected = True

    def requestDone(self, request):
        """
        No-op hook, presumably called by the request when it is done.
        """
        pass
@implementer(internet_interfaces.ISSLTransport)
class DummyHTTPSChannel(DummyHTTPChannel):
    """
    DummyHTTPChannel declared as an SSL transport, using 443 as its port.
    """
    port = 443
class DummyWebRequest(web_server.Request):
    """
    A dummy Twisted Web Request used in tests.

    uri - url encoded
    postpath, prepath - url decoded, utf-8 encoded (not Unicode)
    NOTE(review): when `postpath` is derived from a bytes `uri`, its
    segments are not URL-decoded.

    `prepath` is what was already traversed.
    `postpath` is what is left to be traversed.
    Set 'prepath' to [] to prevent automatic traversal.
    """

    def __init__(
        self,
        postpath=None, prepath=None, session=None, resource=None,
        data=None, peer=None, site=None,
        uri=None, clientproto=None, method=b'GET', secured=False,
        path=None, host=None, channel=None
    ):
        """
        `resource` is used for the site of the new channel, when neither
        `channel` nor `site` are provided.
        """
        from chevah.server.testing import mk
        if peer is None:
            peer = mk.makeIPv4Address(host='42.0.0.1')
        if channel is None:
            channel = DummyHTTPChannel(
                peer=peer, host=host, site=site, resource=resource)
        self.channel = channel
        self.site = channel.site

        self.content = StringIO()
        self.written = []
        if data:
            self.content.write(data)
            self.content.seek(0)

        # Full URL including arguments
        if uri is None:
            uri = b'/uri-not-defined'
        elif postpath is not None:
            raise AssertionError('You can not define both URI and postpath.')

        if isinstance(uri, unicode):
            # For testing we accept Unicode URI but internally we have to
            # encode it.
            self.uri = urllib.quote(uri.encode('utf-8'))
        else:
            self.uri = uri

        # HTTP URL arguments POST or GET
        self.args = {}
        # HTTP URL without arguments
        self.path = path

        if prepath is None:
            if self.uri:
                prepath = self.uri.split('/')
            else:
                prepath = []
        self.prepath = [self._encodePathPart(part) for part in prepath]

        if postpath is None:
            # FIXME:4568:
            # Here we should not ignore the first segment of the URI.
            # The code started only for RESTFolder which is accessed as /home/
            # and then it ignores the home.
            self.postpath = [p.encode('utf-8') for p in uri.split('/')[1:]]
        else:
            self.postpath = [self._encodePathPart(part) for part in postpath]
            self.uri = urllib.quote(
                '/%s' % '/'.join(self.postpath + self.prepath))

        if isinstance(self.uri, unicode):
            raise AssertionError('URI should be URL encoded.')

        self.sitepath = []
        self.client = peer
        self.secured = secured
        if clientproto is None:
            clientproto = 'HTTP/1.0'
        self.clientproto = clientproto
        self.method = method.upper().encode('ascii')
        self.session = None
        self.protoSession = session or web_server.Session(0, self)
        self._code = http.OK
        self._code_message = b'OK'
        self.responseHeaders = Headers()
        self.requestHeaders = Headers()
        # This should be called after we have defined the request headers.
        if host is None:
            host = b'dummy.host.tld'
        self.setRequestHeader(b'host', host)
        self.received_cookies = {}
        self.cookies = []  # outgoing cookies
        # Finish notifications
        self.notifications = []
        self.finished = 0
        self.queued = False

    @staticmethod
    def _encodePathPart(part):
        """
        Return a path segment as UTF-8 bytes, leaving bytes unchanged.
        """
        if isinstance(part, unicode):
            return part.encode('utf-8')
        return part

    def __repr__(self):
        return (
            'DummyWebRequest for "%(uri)s", code: %(code)s\n'
            'response content: "%(response_content)s"\n'
            'response headers: "%(response_headers)s' % ({
                'uri': self.uri,
                'code': self.code,
                'response_content': self.test_response_content,
                'response_headers': dict(
                    self.responseHeaders.getAllRawHeaders()),
            })
        )

    @property
    def code(self):
        return self._code

    @property
    def code_message(self):
        return self._code_message

    @property
    def test_response_content(self):
        """
        Return the data written to the request during tests.
        """
        return ''.join(self.written)

    @property
    def written_content(self):
        """
        All the data written to the request.
        """
        return b''.join(self.written)

    def _checkContentNotWritten(self, method_name):
        """
        Check that content was not sent to remote peer.

        `method_name` is there to help pinpoint where things went wrong.

        Raises AssertionError when data was already written. An explicit
        raise is used so that the check is not stripped under `python -O`.
        """
        if self.written:
            raise AssertionError(
                "%s cannot be called after data has been written: %s." % (
                    method_name, "@@@@".join(self.written)))

    def getSession(self, sessionInterface=None):
        """
        Return server sided stored session object.
        """
        if not self.session:
            self._checkContentNotWritten('getSession')
        return super(DummyWebRequest, self).getSession(sessionInterface)

    def write(self, data):
        if not isinstance(data, bytes):
            raise AssertionError("Non bytes write request.")
        self.written.append(data)

    def isSecure(self):
        return self.secured

    def getHeader(self, name):
        """
        Public method for a Request.

        Return the last value of the request header `name` or `None`.
        """
        value = self.requestHeaders.getRawHeaders(name)
        if value is not None:
            return value[-1]

    def setHeader(self, name, value):
        """
        Public method for a Request.

        Set the response header `name` to the single `value`.
        """
        self.responseHeaders.setRawHeaders(name, [value])

    def getRequestHeader(self, name):
        """
        Testing method to get request headers.

        This is here so that we can have clear/explicit tests.
        """
        return self.getHeader(name)

    def setRequestHeader(self, name, value):
        """
        Testing method to set request headers.

        This is here so that we can have clear/explicit tests.
        """
        self.requestHeaders.setRawHeaders(name.lower(), [value])

    def getResponseHeader(self, name):
        """
        Testing method to get response headers.

        This is here so that we can have clear/explicit tests.
        """
        value = self.responseHeaders.getRawHeaders(name)
        if value is not None:
            return value[-1]

    def setResponseHeader(self, name, value):
        """
        Testing method to set response headers.

        This is here so that we can have clear/explicit tests.
        """
        # `self` is bound already; passing it again raised TypeError.
        self.setHeader(name, value)

    def setLastModified(self, when):
        """
        See: twisted.web.http.Request.
        """
        self._checkContentNotWritten('setLastModified')
        return super(DummyWebRequest, self).setLastModified(when)

    def setETag(self, tag):
        """
        See: twisted.web.http.Request.

        Just checks that method is not called after response was sent.
        """
        self._checkContentNotWritten('setETag')

    def getCookie(self, key):
        """
        Get an incoming HTTP cookie.
        """
        return self.received_cookies.get(key, None)

    def setReceivedCookie(self, key, value):
        """
        Set cookie as it were received in the request.
        """
        self.received_cookies[key] = value

    def redirect(self, url):
        """
        Utility function that does a redirect.

        The request should have finish() called after this.
        """
        self.setResponseCode(http.FOUND)
        self.setHeader("location", url)

    def registerProducer(self, prod, s):
        """
        See: twisted.web.http.Request.

        Keeps resuming the producer until it calls unregisterProducer().
        """
        self.go = 1
        while self.go:
            prod.resumeProducing()

    def unregisterProducer(self):
        """
        See: twisted.web.http.Request.
        """
        self.go = 0

    def processingFailed(self, reason):
        """
        Errback any L{Deferreds} waiting for finish notification.
        """
        if self.notifications is not None:
            observers = self.notifications
            self.notifications = None
            for obs in observers:
                obs.errback(reason)

    def setResponseCode(self, code, message=None):
        """
        Set the HTTP status response code, but takes care that this is called
        before any data is written.

        When `message` is not given, the standard phrase for `code` is used,
        as twisted.web.http.Request does.
        """
        self._checkContentNotWritten('setResponseCode')
        if message is None:
            message = http.RESPONSES.get(code, b'Unknown Status')
        # Code and code_message are read-only for our test purpose.
        self._code = code
        self._code_message = message

    def render(self, resource):
        """
        Render the given resource as a response to this request.

        This implementation only handles a few of the most common behaviors of
        resources. It can handle a render method that returns a string or
        C{NOT_DONE_YET}. It doesn't know anything about the semantics of
        request methods (eg HEAD) nor how to set any particular headers.
        Basically, it's largely broken, but sufficient for some tests at
        least.
        It should B{not} be expanded to do all the same stuff L{Request} does.
        Instead, L{DummyRequest} should be phased out and L{Request} (or some
        other real code factored in a different way) used.
        """
        result = resource.render(self)
        if result is web_server.NOT_DONE_YET:
            return
        self.write(result)
        self.finish()

    def setRequestContent(self, data):
        """
        Set body content of the request.
        """
        self.content.write(data)
        self.content.seek(0)
class _StoppableHTTPServer(HTTPServer):
    """
    HTTP server for functional tests which serves a single connection at
    a time and can be stopped via its `stopped` flag.
    """
    server_version = 'ChevahTesting/0.1'
    stopped = False
    # Connection which is being served right now.
    active_connection = None

    def serve_forever(self):
        """
        Keep handling requests, one by one, until `stopped` is set.

        Interrupted select calls are retried (Python issue 7978). A closed
        server socket (EBADF) stops the server and any other socket error
        ends the loop.
        """
        self.active_connection = None
        self.stopped = False
        while not self.stopped:
            try:
                self.handle_request()
            except SelectError as select_failure:
                # See Python http://bugs.python.org/issue7978
                if select_failure.args[0] != errno.EINTR:
                    raise
            except socket.error as socket_failure:
                if socket_failure.errno != errno.EBADF:
                    break
                self.stopped = True
class _ThreadedHTTPServer(Thread):
    """
    HTTP or HTTPS Server that runs in a thread.

    This is actually a threaded wrapper around an HTTP server.
    """
    # How many times binding is retried, one second apart, when the
    # address is still in use.
    TIMEOUT = 1

    def __init__(
        self, handler_class, cond,
        responses=None,
        ip='127.0.0.1', port=0,
        server_certificate=None,
        debug=False,
    ):
        """
        `cond` is the condition notified once the server is ready (or has
        failed to start). `responses` and `debug` are accepted but unused.
        """
        Thread.__init__(self)
        self.ready = False
        self.cond = cond
        self._ip = ip
        self._port = port
        self._handler_class = handler_class
        self._server_certificate = server_certificate

    def run(self):
        """
        Create the server, signal `cond` and then serve until stopped.

        If the server cannot be created, `ready` is still set and `cond`
        notified so that the waiting thread is not blocked, and the error
        is re-raised.
        """
        self.cond.acquire()
        retries = 0
        self.httpd = None
        while self.httpd is None:
            try:
                self.httpd = _StoppableHTTPServer(
                    (self._ip, self._port), self._handler_class)
                if self._server_certificate:
                    self.httpd.socket = ssl.wrap_socket(
                        self.httpd.socket,
                        certfile=self._server_certificate,
                        server_side=True,
                    )
            except Exception as error:
                # Compare the errno directly: looking it up in
                # errno.errorcode could raise KeyError here and leave
                # `cond` acquired, blocking the waiting thread forever.
                address_in_use = (
                    isinstance(error, socket.error) and
                    getattr(error, 'errno', None) == errno.EADDRINUSE)
                if address_in_use and retries < self.TIMEOUT:
                    retries += 1
                    time.sleep(1)
                    continue
                # Set `ready` before notifying, otherwise the waiting thread
                # can see `ready` as False, wait again and never wake up.
                self.ready = True
                self.cond.notifyAll()
                self.cond.release()
                # Bare raise to keep the original traceback.
                raise

        self.ready = True
        self.cond.notifyAll()
        self.cond.release()

        # Start the actual HTTP server.
        self.httpd.serve_forever()
class HTTPServerContext(object):
    """
    A context manager which runs a HTTP or HTTPS server for testing simple
    HTTP requests.

    After the server is started the ip and port are available in the
    context management instance.

    response = ResponseDefinition(url='/hello.html', response_content='Hello!')
    with HTTPServerContext([response]) as httpd:
        print 'Listening at %s:%d' % (httpd.ip, httpd.port)
        self.assertEqual('Hello!', your_get())

    response = ResponseDefinition(
        url='/hello.php', method='POST', request='user=John',
        response_content='Hello John!', response_code=202)
    with HTTPServerContext([response]) as httpd:
        self.assertEqual(
            'Hello John!',
            your_post(url='hello.php', data='user=John'))
    """

    def __init__(
        self, responses=None, ip='127.0.0.1', port=0,
        version='HTTP/1.1', cert=None, debug=False,
        auth=None,
    ):
        """
        Initialize a new HTTPServerContext.

        * responses - A list of ResponseDefinition defining the behavior of
          this server.
        * ip - IP to listen. Leave empty to listen to any interface.
        * port - Port to listen. Leave 0 to pick a random port.
        * version - HTTP protocol version used by server.
        * cert - certificate to be used by the HTTPS server.
        * debug - print debug messages from the request handler.
        * auth - value required for the Authorization header. Requests
          with a different value get a 401 response.
        """
        self._previous_valid_responses = _DefinedRequestHandler.valid_responses
        self._previous_first_client = _DefinedRequestHandler.first_client
        self._previous_authorization = _DefinedRequestHandler.authorization
        self._previous_protocol_version = (
            _DefinedRequestHandler.protocol_version)

        # Since we can not pass an instance of _DefinedRequestHandler
        # we do on the fly patching here.

        # Servers might be nested, so a debug will trigger any server.
        # Also, this has the side effect, that once a debug is enabled in
        # one tests, it will be enabled in all the tests after it, so we
        # keep the state at start to see if we should revert.
        self._debug = debug
        self._previous_debug = _DefinedRequestHandler.debug
        _DefinedRequestHandler.debug = _DefinedRequestHandler.debug or debug

        if responses is None:
            _DefinedRequestHandler.valid_responses = []
        else:
            _DefinedRequestHandler.valid_responses = responses

        _DefinedRequestHandler.authorization = auth
        _DefinedRequestHandler.protocol_version = version

        self.cond = threading.Condition()
        self._server_certificate = cert
        self.server = _ThreadedHTTPServer(
            handler_class=_DefinedRequestHandler,
            cond=self.cond,
            ip=ip,
            port=port,
            server_certificate=cert,
        )

    def __enter__(self):
        self.cond.acquire()
        try:
            self.server.start()
            # Wait until the server thread signals that it is ready.
            while not self.server.ready:
                self.cond.wait()
        finally:
            self.cond.release()
        # Even if the thread ready, it might still need some time
        # to be ready.
        time.sleep(0.03)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # _DefinedRequestHandler initialization is outside of control so
        # we share state as class members. To free memory we need to clean it.
        _DefinedRequestHandler.cleanGlobals()
        if self._debug and not self._previous_debug:
            _DefinedRequestHandler.debug = False

        self.stopServer()
        self.server.join(1)

        # The reverting should be done after the thread is closed so that
        # we don't have things inside the thread setting the class values.
        _DefinedRequestHandler.valid_responses = self._previous_valid_responses
        _DefinedRequestHandler.first_client = self._previous_first_client
        _DefinedRequestHandler.authorization = self._previous_authorization
        _DefinedRequestHandler.protocol_version = (
            self._previous_protocol_version)

        # is_alive() is available on all versions; isAlive() was removed.
        if self.server.is_alive():
            raise AssertionError('Server still running')
        return False

    @property
    def port(self):
        return self.server.httpd.server_address[1]

    @property
    def ip(self):
        return self.server.httpd.server_address[0]

    def stopServer(self):
        """
        Stop the server and close its listening socket.

        An active persistent connection is shut down directly. Otherwise
        a QUIT request is sent to wake up the server. Socket errors are
        ignored as the connection might already be closed.
        """
        connection = self.server.httpd.active_connection
        try:
            if connection and connection.rfile._sock:
                # Stop waiting for data from persistent connection.
                self.server.httpd.stopped = True
                sock = connection.rfile._sock
                sock.shutdown(socket.SHUT_RDWR)
                sock.close()
            else:
                # Stop waiting for data from new connection.
                self._sendQuit()
        except (socket.error, BadStatusLine):
            # Ignore socket errors at shutdown as the connection
            # might be already closed.
            pass

        self.server.httpd.server_close()

    def _sendQuit(self):
        """
        Send the special QUIT request to the server and close the client
        connection afterwards.
        """
        address = '%s:%d' % (self.ip, self.port)
        if self._server_certificate:
            try:
                context = ssl._create_unverified_context()
            except Exception:
                # Python < 2.7.9 uses the old HTTPS code.
                context = None

            if context:
                # Use an HTTPS connection since we are in HTTPS mode.
                conn = HTTPSConnection(address, context=context, timeout=5)
            else:
                conn = HTTPSConnection(address, timeout=5)
        else:
            conn = HTTPConnection(address, timeout=5)

        try:
            conn.request("QUIT", "/")
            conn.getresponse()
        finally:
            # Don't leak the client socket.
            conn.close()
class _DefinedRequestHandler(BaseHTTPRequestHandler, object):
    """
    A request handler which acts based on pre-defined responses.

    This should only be used for tests together with HTTPServerContext.
    Its configuration is shared via class members.
    """
    # ResponseDefinition instances which can be served.
    valid_responses = []
    debug = False
    # Keep a record of first client which connects to the request
    # series to have a better check for persisted connections.
    first_client = None
    # Value of the authorization header required for requests.
    authorization = None

    def __init__(self, request, client_address, server):
        if self.debug:
            print('New connection %s.' % (client_address,))
        # Register current connection on server, so that the server can be
        # stopped while this connection is active.
        server.active_connection = self
        try:
            super(_DefinedRequestHandler, self).__init__(
                request, client_address, server)
        except socket.error:
            # Presumably the connection was closed while being served.
            pass
        server.active_connection = None

    @classmethod
    def cleanGlobals(cls):
        """
        Clean all class members used to share info between requests.
        """
        cls.valid_responses = []
        cls.first_client = None

    def log_message(self, *args):
        """
        We don't want any logs from the server.
        """
        pass

    def do_QUIT(self):
        """
        Called by HTTPServerContext to trigger server stop.
        """
        self.server.stopped = True
        self._debug('QUIT')
        self.send_response(200)
        self.end_headers()
        # Force closing the connection.
        self.close_connection = 1

    def do_GET(self):
        self._handleRequest()

    def do_HEAD(self):
        self._handleRequest()

    def do_POST(self):
        self._handleRequest()

    def do_PUT(self):
        self._handleRequest()

    def do_DELETE(self):
        self._handleRequest()

    def do_PROPFIND(self):
        self._handleRequest()

    def _handleRequest(self):
        """
        Check if we can handle the request and send response.

        Unmatched requests get a 404 error.
        """
        if not self.first_client:
            # Looks like this is the first request so we save the client
            # address to compare it later.
            self._debug('First: %r' % (self.client_address,))
            self.__class__.first_client = self.client_address

        response = self._matchResponse()
        if response:
            self._debug(response)
            self._sendResponse(response)
            self._debug(
                'After: Close-connection: %s' % (self.close_connection,))
            return

        self.send_error(404)

    def _matchResponse(self):
        """
        Return the ResponseDefinition for the current request.

        Return `None` if no match was found.
        A 401 response is returned when the Authorization header does not
        match the expected value.
        """
        length = int(self.headers.getheader('content-length', 0))
        content = self.rfile.read(length)
        self._debug(content)

        auth_request = self.headers.getheader('authorization', None)
        if self.authorization != auth_request:  # noqa:cover
            error = 'Request not authorized. Expect %s got %s' % (
                self.authorization, auth_request)
            self._debug(error)
            return ResponseDefinition(
                response_code=401,
                response_message='Not authorized',
                response_content=error,
            )

        any_matching = None
        for response in self.__class__.valid_responses:
            if self.path != response.url or self.command != response.method:
                continue

            # For GET we have the response and don't check further.
            if self.command == 'GET':
                return response

            # We don't need exact content matching
            if response.request is ResponseDefinition.ANY:
                self._debug('ANY matching.')
                # Keep the response, but see if we have a more specific
                # match.
                any_matching = response
                continue

            if content == response.request:
                return response

        return any_matching

    def _debug(self, message=''):
        """
        Print to stdout a debug message.
        """
        if not self.debug:
            return
        print('\nGot %s:%s - %s\n' % (
            self.command, self.path, message))

    def _getPersistenceError(self, response):
        """
        Return an error message when the connection persistence of the
        current request does not match `response.persistent`.

        Return `None` when it matches or when `response.persistent` is None.
        """
        connection_header = self.headers.getheader('connection')
        if connection_header:
            connection_header = connection_header.lower()

        if not connection_header:
            if self.protocol_version == 'HTTP/1.1':
                # For HTTP/1.1 connections are persistent by default.
                connection_header = 'keep-alive'
            else:
                # For HTTP/1.0 connections are not persistent by default.
                connection_header = 'close'

        if response.persistent is None:
            # Ignore persistent flag.
            return None

        if response.persistent:
            if connection_header == 'close':
                return 'Headers do not persist the connection'
            if self.first_client != self.client_address:
                return 'Persistent connection not reused. First %r. Now %r' % (
                    self.first_client, self.client_address)
            return None

        if connection_header == 'keep-alive':
            return 'Connection was persistent'
        return None

    def _sendResponse(self, response):
        """
        Send response to client.

        When the connection persistence does not match the response
        definition, only a 400 error is sent.
        """
        persistence_error = self._getPersistenceError(response)
        if persistence_error is not None:
            # Stop here: sending the defined response after the error
            # would write a second response on the same connection.
            self.send_error(400, persistence_error)
            return

        if isinstance(response.response_code, Exception):
            # Just close the connection without a valid HTTP response.
            self.wfile.write(response.response_code.message)
            self.close_connection = 1
            return

        self.send_response(
            response.response_code, response.response_message)
        self.send_header("Content-Type", response.content_type)
        if response.response_length:
            self.send_header("Content-Length", response.response_length)
        self.end_headers()
        self.wfile.write(response.test_response_content)

        if not response.response_persistent:
            # Force closing the connection as requested
            # by response.
            self.close_connection = 1
class ResponseDefinition(object):
    """
    A class encapsulating the required data for configuring a response
    generated by the HTTPServerContext.

    It contains the following data:
        * url - url that will trigger this response
        * method - HTTP method that will trigger this response
        * request - Content of the request that will trigger the response once
          the url is matched. The content is not checked for GET requests.
          - Set to ResponseDefinition.ANY to trigger on any content.
        * response_content - content of the response. Unicode is encoded
          as UTF-8.
        * response_code - HTTP code of the response or an Exception().
          If the Exception has any text, it is written to the response as raw
          data.
        * response_message - Message sent together with HTTP code.
        * content_type - Content type of the HTTP response
        * response_length - Length of the response body content.
          `None` to calculate automatically the length.
          `` (empty string) to ignore content-length header.
        * persistent: whether the request should persist the connection.
          Set to None to ignore persistent checking.
        * response_persistent: when false, the connection is closed after
          sending this response. Defaults to the `persistent` value.
    """
    # Marker for `request` which matches any request content.
    ANY = object()

    def __init__(
        self, url='', request='', method='GET',
        response_content=b'', response_code=200, response_message=None,
        content_type=b'text/html', response_length=None,
        persistent=True, response_persistent=None,
    ):
        self.url = url
        self.method = method
        self.request = request
        self.response_code = response_code
        self.response_message = response_message
        self.content_type = content_type

        # Sets the content and its automatic length.
        self.updateResponseContent(response_content)
        if response_length is not None:
            # Explicit length, or '' to skip the Content-Length header.
            self.response_length = str(response_length)

        self.persistent = persistent
        if response_persistent is None:
            response_persistent = persistent
        self.response_persistent = response_persistent

    def __repr__(self):
        return 'ResponseDefinition:%s:%s:%s %s:pers-%s' % (
            self.url,
            self.method,
            self.response_code, self.response_message,
            self.persistent,
            )

    def updateResponseContent(self, content):
        """
        Will update the content returned to the server, together with the
        length sent in the Content-Length header.
        """
        if not isinstance(content, bytes):
            content = codecs.encode(content, 'utf-8')
        self.test_response_content = content
        self.response_length = str(len(self.test_response_content))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment