Skip to content

Instantly share code, notes, and snippets.

@adiroiban
Created October 31, 2018 09:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adiroiban/a3dae64d14a56e8c0dbfea4bc6b4b827 to your computer and use it in GitHub Desktop.
Save adiroiban/a3dae64d14a56e8c0dbfea4bc6b4b827 to your computer and use it in GitHub Desktop.
Code to help write tests for the Twisted HTTP client.
# Copyright (c) 2014 Adi Roiban.
# See LICENSE for details.
"""
Helpers for testing code related to HTTP protocol.
"""
from __future__ import absolute_import

from httplib import HTTPConnection, HTTPSConnection
from httplib import BadStatusLine
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from select import error as SelectError
from StringIO import StringIO
from threading import Thread
import codecs
import errno
import json
import socket
import ssl
import threading
import time
import urllib

from twisted.internet import defer, interfaces as internet_interfaces
from twisted.python.failure import Failure
from twisted.web import (
    http,
    server as web_server,
)
from twisted.web.client import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer
from twisted.web.server import Request
from zope.interface import implementer
# Unique marker objects.
# STREAMED_REQUEST: the expected request will have a streamed body.
# STREAMED_RESPONSE: the response body is not sent automatically; it is
# delivered by the test through the response's `body_producer`.
STREAMED_REQUEST, STREAMED_RESPONSE = object(), object()
@implementer(internet_interfaces.IPushProducer)
class InMemoryBodyDeliverer(object):
    """
    Push producer which lets test code feed a response body to a consumer
    protocol, one chunk at a time.

    It installs itself as the consumer's `transport`. It refuses to
    deliver data while paused or after it was stopped.
    """

    def __init__(self, consumer):
        self._consumer = consumer
        self._consumer.transport = self
        self.stopped = False
        self.paused = False

    def pauseProducing(self):
        """
        Mark the producer as paused.
        """
        self.paused = True

    def resumeProducing(self):
        """
        Clear the paused flag.
        """
        self.paused = False

    def stopProducing(self):
        """
        Mark the producer as stopped.
        """
        self.stopped = True

    def sendData(self, data):
        """
        Pass `data` to the consumer as if it was received from the peer.

        Raise AssertionError if the producer is stopped or paused.
        """
        problem = None
        if self.stopped:
            problem = 'Data received after producer was stopped.'
        elif self.paused:
            problem = 'Data received while paused.'
        if problem is not None:
            raise AssertionError(problem)
        self._consumer.dataReceived(data)

    def stopSending(self):
        """
        Tell the consumer the whole body was delivered, as signaled by a
        closed connection.
        """
        self._consumer.connectionLost(Failure(ResponseDone()))
class InMemoryResponse(object):
    """
    Simple response from in memory agent request.

    It holds both the expectations for the request (method, body, headers)
    and the response returned to the caller.

    `request_body` can be STREAMED_REQUEST to expect a streamed request
    body. `response_body` can be STREAMED_RESPONSE so that the body is not
    delivered automatically, but through `body_producer`.
    """

    def __init__(
        self, method='GET', code=404, phrase='Phrase not set',
        request_body=None, request_json=None, request_headers=None,
        response_body=b'', response_headers=None,
    ):
        self.method = method
        self.code = code
        self.phrase = phrase
        self.request_headers = Headers(request_headers)
        if request_body is None:
            if request_json is not None:
                request_body = json.dumps(request_json)
            else:
                request_body = ''
        self.request_body = request_body
        self.response_body = response_body
        # None was used instead of a shared mutable `{}` default.
        self.headers = Headers(response_headers)
        self._stream_request = False
        self._streamed_body = None
        # Set by deliverBody() when the response body is STREAMED_RESPONSE.
        self.body_producer = None
        # Set by the agent to the deferred of body.startProducing().
        self.start_producing_deferred = None

    @property
    def streamed_body(self):
        """
        The content of the request as it was streamed.

        `None` if nothing was streamed.
        """
        return self._streamed_body

    def write(self, data):
        """
        Called when received a streamed request.

        Raise AssertionError if this response does not expect a streamed
        request, or if `data` is not bytes.
        """
        if self.request_body is not STREAMED_REQUEST:
            raise AssertionError('Write while not in stream mode.')
        if not isinstance(data, bytes):
            raise AssertionError('Write operation for non bytes.')
        if self._streamed_body is None:
            self._streamed_body = data
        else:
            self._streamed_body += data

    def deliverBody(self, consumer_protocol):
        """
        See: twisted.web.client.Response.deliverBody.

        For STREAMED_RESPONSE bodies, the producer is stored in
        `body_producer` and the test delivers the data. Otherwise the
        whole body is delivered at once and the connection is closed.
        """
        producer = InMemoryBodyDeliverer(consumer_protocol)
        if self.response_body is STREAMED_RESPONSE:
            self.body_producer = producer
            return
        consumer_protocol.dataReceived(self.response_body)
        consumer_protocol.connectionLost(Failure(ResponseDone()))
class InMemoryPersistentAgent(object):
    """
    Simple implementation of an HTTP Persistent Agent.

    It will not emit any event and does not check SSL/TLS.

    Tests arrange the answers in `responses`: a dict mapping each URL to a
    list of InMemoryResponse, Exception or Deferred instances. Each request
    consumes the first entry for its URL.
    """

    def __init__(self, event_emitter=None, ssl_context=None, credentials=None):
        self._connections = {}
        # Stored for inspection; not used by the visible code.
        self._credentials = credentials
        self.responses = {}

    def _getResponse(self, method, url, headers, body):
        """
        Return the arranged response for our rigged agent.

        Raise AssertionError right away when no response is arranged for
        `url`. Otherwise return a deferred which fails when the request
        does not match the arranged expectations.
        """
        if body is None:
            body = ''
        self._cacheURL(url)
        headers = Headers(headers)
        try:
            response = self.responses[url].pop(0)
        except (KeyError, IndexError):
            raise AssertionError(
                'URL not defined %s for %s' % (method, url))

        if isinstance(response, Exception):
            return defer.fail(response)

        if isinstance(response, defer.Deferred):
            deferred = response
        else:
            deferred = defer.succeed(response)

        def cb_check_response(response):
            """
            Called when we got a response.
            """
            if response.method != method:
                raise AssertionError(
                    'No %s request was defined for %s' % (method, url))

            if IBodyProducer.providedBy(body):
                # We have a streamed request, so the body is checked later.
                if response.request_body is not STREAMED_REQUEST:
                    raise AssertionError(
                        'Unexpected streamed request. Expecting: %s' % (
                            response.request_body,))
                done_deferred = body.startProducing(response)
                response.start_producing_deferred = done_deferred
            else:
                if response.request_body is STREAMED_REQUEST:
                    raise AssertionError(
                        'Expecting streamed request. Got: %r' % (body,))
                # Just compare the whole body.
                if response.request_body != body:
                    raise AssertionError(
                        'Invalid request body:\ngot\n%s\nexpecting\n%s' % (
                            body, response.request_body))

            if response.request_headers != headers:
                raise AssertionError(
                    'Invalid headers:\ngot\n%s\nexpecting\n%s' % (
                        headers, response.request_headers))

            # All good.
            return response

        deferred.addCallback(cb_check_response)
        return deferred

    def _cacheURL(self, url):
        """
        Called to mark that URL is cached.
        """
        # For connection caching we cache only based on host.
        parts = url.split('/', 3)
        if len(parts) > 3:
            host = '/'.join(parts[:-1])
        else:
            host = url
        self._connections[host] = b'cached'

    def postJSON(self, url, body, headers=None):
        """
        See: PersistentAgent.

        The caller's `headers` dict is not modified.
        """
        # Work on a copy so the Content-Type is not added to the caller's dict.
        headers = {} if headers is None else dict(headers)
        # Make sure content type is always JSON.
        headers['Content-Type'] = [b'application/json; charset=utf-8']
        return self.post(url, headers, json.dumps(body).encode('utf-8'))

    def post(self, url, headers, body):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'POST', url, headers, body)

    def put(self, url, headers, body):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'PUT', url, headers, body)

    def get(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'GET', url, headers, body=None)

    def delete(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'DELETE', url, headers, body=None)

    def propfind(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'PROPFIND', url, headers, body=None)

    def mkcol(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'MKCOL', url, headers, body=None)

    def head(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'HEAD', url, headers, body=None)

    def move(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'MOVE', url, headers, body=None)

    def readBody(self, response):
        """
        Return a deferred which fires with response body.
        """
        return defer.succeed(response.response_body)

    def closePersistentConnections(self):
        """
        See: PersistentAgent.
        """
        self._connections = {}

    @property
    def connections(self):
        """
        Return the persistent connections, as a dict keyed by host.
        """
        return self._connections
class DummyHTTPChannel(object):
    """
    Minimal stand-in for an HTTP channel/transport used in tests.

    Data written to the channel is accumulated in `written`. Registered
    producers are recorded in `producers` as (producer, streaming) tuples.
    """
    # Port used for the default host address.
    port = 80
    disconnected = False

    def __init__(self, site=None, peer=None, host=None, resource=None):
        from chevah.server.testing import mk
        self.written = StringIO()
        self.producers = []
        if peer is None:
            peer = mk.makeIPv4Address(host='42.0.0.1', port=1234)
        if host is None:
            host = mk.makeIPv4Address(host='10.0.0.1', port=self.port)
        if site is None:
            site = web_server.Site(resource)
        self.site = site
        self._peer = peer
        self._host = host

    def getPeer(self):
        return self._peer

    def write(self, data):
        """
        Record `data` as written to the transport.

        Raise AssertionError when `data` is not bytes.
        """
        if not isinstance(data, bytes):
            raise AssertionError('Non bytes write request.')
        self.written.write(data)

    def writeSequence(self, iovec):
        """
        Write each chunk from `iovec`, in order.
        """
        # Explicit loop: `map` is lazy on Python 3 and would write nothing.
        for data in iovec:
            self.write(data)

    def getHost(self):
        return self._host

    def registerProducer(self, producer, streaming):
        self.producers.append((producer, streaming))

    def loseConnection(self):
        self.disconnected = True

    def requestDone(self, request):
        pass
@implementer(internet_interfaces.ISSLTransport)
class DummyHTTPSChannel(DummyHTTPChannel):
    """
    DummyHTTPChannel variant which declares ISSLTransport and defaults to
    port 443.
    """

    # Read by DummyHTTPChannel.__init__ for the default host address.
    port = 443
class DummyWebRequest(Request):
    """
    A dummy Twisted Web Request used in tests.

    uri - url encoded
    postpath, prepath - url decoded, utf-8 encoded (not Unicode)

    `prepath` is what was already traversed.
    `postpath` is what is left to be traversed.

    Set 'prepath' to [] to prevent automatic traversal.
    """

    def __init__(
        self,
        postpath=None, prepath=None, session=None, resource=None,
        data=None, peer=None, site=None,
        uri=None, clientproto=None, method=b'GET', secured=False,
        path=None, host=None, channel=None
    ):
        from chevah.server.testing import mk
        if peer is None:
            peer = mk.makeIPv4Address(host='42.0.0.1')
        if channel is None:
            channel = DummyHTTPChannel(peer=peer, host=host, site=site)
        self.channel = channel
        self.site = channel.site
        self.content = StringIO()
        self.written = []
        if data:
            self.content.write(data)
            self.content.seek(0)

        # Full URL including arguments
        if uri is None:
            uri = b'/uri-not-defined'
        elif postpath is not None:
            raise AssertionError('You can not define both URI and postpath.')

        if isinstance(uri, unicode):
            # For testing we accept Unicode URI but internally we have to
            # encode it.
            self.uri = urllib.quote(uri.encode('utf-8'))
        else:
            self.uri = uri

        # HTTP URL arguments POST or GET
        self.args = {}
        # HTTP URL without arguments
        self.path = path

        if prepath is None:
            if self.uri:
                prepath = self.uri.split('/')
            else:
                prepath = []
        self.prepath = []
        for part in prepath:
            if isinstance(part, unicode):
                part = part.encode('utf-8')
            self.prepath.append(part)

        if postpath is None:
            # FIXME:4568:
            # Here we should not ignore the first segment of the URI.
            # The code started only for RESTFolder which is accessed as /home/
            # and then it ignores the home.
            self.postpath = [p.encode('utf-8') for p in uri.split('/')[1:]]
        else:
            self.postpath = []
            for part in postpath:
                if isinstance(part, unicode):
                    part = part.encode('utf-8')
                self.postpath.append(part)
            # NOTE(review): the URI is built as postpath + prepath, which
            # looks reversed compared to the traversal order; confirm.
            self.uri = urllib.quote(
                '/%s' % '/'.join(self.postpath + self.prepath))

        if isinstance(self.uri, unicode):
            raise AssertionError('URI should be URL encoded.')

        self.sitepath = []
        self.client = peer
        self.secured = secured
        if clientproto is None:
            clientproto = 'HTTP/1.0'
        self.clientproto = clientproto
        self.method = method.upper().encode('ascii')
        self.session = None
        self.protoSession = session or web_server.Session(0, self)
        self._code = http.OK
        self._code_message = b'OK'
        self.responseHeaders = Headers()
        self.requestHeaders = Headers()

        # This should be called after we have defined the request headers.
        if host is None:
            host = b'dummy.host.tld'
        self.setRequestHeader(b'host', host)

        self.received_cookies = {}
        self.cookies = []  # outgoing cookies
        # Finish notifications
        self.notifications = []
        self.finished = 0
        self.queued = False

    def __repr__(self):
        return (
            'DummyWebRequest for "%(uri)s", code: %(code)s\n'
            'response content: "%(response_content)s"\n'
            'response headers: "%(response_headers)s' % ({
                'uri': self.uri,
                'code': self.code,
                'response_content': self.test_response_content,
                'response_headers': dict(
                    self.responseHeaders.getAllRawHeaders()),
            })
        )

    @property
    def code(self):
        return self._code

    @property
    def code_message(self):
        return self._code_message

    @property
    def test_response_content(self):
        """
        Return the data written to the request during tests.
        """
        return ''.join(self.written)

    @property
    def written_content(self):
        """
        All the data written to the request.
        """
        return b''.join(self.written)

    def _checkContentNotWritten(self, method_name):
        """
        Check that content was not sent to remote peer.

        `method_name` is there to help pinpoint where things went wrong.

        Raise AssertionError if data was already written.
        """
        # Explicit raise, so the check is not stripped when running with -O.
        if self.written:
            raise AssertionError(
                "%s cannot be called after data has been written: %s." % (
                    method_name, "@@@@".join(self.written)))

    def getSession(self, sessionInterface=None):
        """
        Return server sided stored session object.
        """
        if not self.session:
            self._checkContentNotWritten('getSession')
        return super(DummyWebRequest, self).getSession(sessionInterface)

    def write(self, data):
        if not isinstance(data, bytes):
            raise AssertionError("Non bytes write request.")
        self.written.append(data)

    def isSecure(self):
        return self.secured

    def getHeader(self, name):
        """
        Public method for a Request.

        Return the last value of the request header, or None.
        """
        value = self.requestHeaders.getRawHeaders(name)
        if value is not None:
            return value[-1]

    def setHeader(self, name, value):
        """
        Public method for a Request.
        """
        self.responseHeaders.setRawHeaders(name, [value])

    def getRequestHeader(self, name):
        """
        Testing method to get request headers.

        This is here so that we can have clear/explicit tests.
        """
        return self.getHeader(name)

    def setRequestHeader(self, name, value):
        """
        Testing method to set request headers.

        This is here so that we can have clear/explicit tests.
        """
        self.requestHeaders.setRawHeaders(name.lower(), [value])

    def getResponseHeader(self, name):
        """
        Testing method to get response headers.

        This is here so that we can have clear/explicit tests.
        """
        value = self.responseHeaders.getRawHeaders(name)
        if value is not None:
            return value[-1]

    def setResponseHeader(self, name, value):
        """
        Testing method to set response headers.

        This is here so that we can have clear/explicit tests.
        """
        self.setHeader(name, value)

    def setLastModified(self, when):
        """
        See: twisted.web.http.Request.
        """
        self._checkContentNotWritten('setLastModified')
        return super(DummyWebRequest, self).setLastModified(when)

    def setETag(self, tag):
        """
        See: twisted.web.http.Request.

        Just checks that method is not called after response was sent.
        """
        self._checkContentNotWritten('setETag')

    def getCookie(self, key):
        """
        Get an incoming HTTP cookie.
        """
        return self.received_cookies.get(key, None)

    def setReceivedCookie(self, key, value):
        """
        Set cookie as it were received in the request.
        """
        self.received_cookies[key] = value

    def redirect(self, url):
        """
        Utility function that does a redirect.

        The request should have finish() called after this.
        """
        self.setResponseCode(http.FOUND)
        self.setHeader("location", url)

    def registerProducer(self, prod, s):
        """
        See: twisted.web.http.Request.

        Keep calling resumeProducing() until unregisterProducer() is called.
        """
        self.go = 1
        while self.go:
            prod.resumeProducing()

    def unregisterProducer(self):
        """
        See: twisted.web.http.Request.
        """
        self.go = 0

    def processingFailed(self, reason):
        """
        Errback and L{Deferreds} waiting for finish notification.
        """
        if self.notifications is not None:
            observers = self.notifications
            self.notifications = None
            for obs in observers:
                obs.errback(reason)

    def setResponseCode(self, code, message=None):
        """
        Set the HTTP status response code, but takes care that this is called
        before any data is written.
        """
        self._checkContentNotWritten('setResponseCode')
        # Code and code_message are read-only for our test purpose.
        self._code = code
        self._code_message = message

    def render(self, resource):
        """
        Render the given resource as a response to this request.

        This implementation only handles a few of the most common behaviors of
        resources. It can handle a render method that returns a string or
        C{NOT_DONE_YET}. It doesn't know anything about the semantics of
        request methods (eg HEAD) nor how to set any particular headers.
        Basically, it's largely broken, but sufficient for some tests at
        least.
        It should B{not} be expanded to do all the same stuff L{Request} does.
        Instead, L{DummyRequest} should be phased out and L{Request} (or some
        other real code factored in a different way) used.
        """
        result = resource.render(self)
        if result is web_server.NOT_DONE_YET:
            return
        self.write(result)
        self.finish()

    def setRequestContent(self, data):
        """
        Set body content of the request.
        """
        self.content.write(data)
        self.content.seek(0)
class _StoppableHTTPServer(HTTPServer):
    """
    HTTP server which answers requests one by one until `stopped` is set.

    Only one connection is served at a time. It is meant to respond to HTTP
    requests made by functional tests.
    """
    # NOTE(review): the Server header is normally taken from the request
    # handler's `server_version`, so this value may be unused; confirm.
    server_version = 'ChevahTesting/0.1'
    stopped = False
    # Current connection served by the server.
    active_connection = None

    def serve_forever(self):
        """
        Handle requests until `stopped` becomes True or a socket error
        occurs.

        Interrupted system calls (EINTR) are retried. Any other select
        error is re-raised.
        """
        self.stopped = False
        self.active_connection = None
        while True:
            if self.stopped:
                return
            try:
                self.handle_request()
            except SelectError as select_failure:
                # See Python http://bugs.python.org/issue7978
                if select_failure.args[0] != errno.EINTR:
                    raise
            except socket.error as socket_failure:
                if socket_failure.errno != errno.EBADF:
                    return
                self.stopped = True
class _ThreadedHTTPServer(Thread):
    """
    HTTP or HTTPS Server that runs in a thread.

    This is actually a threaded wrapper around an HTTP server.

    `ready` is set and `cond` is notified once the server is listening, or
    once starting it has failed.
    """
    # Number of retries (one second apart) while the address is in use.
    TIMEOUT = 1

    def __init__(
        self, handler_class, cond,
        responses=None,
        ip='127.0.0.1', port=0,
        server_certificate=None,
        debug=False,
    ):
        Thread.__init__(self)
        self.ready = False
        self.cond = cond
        self._ip = ip
        self._port = port
        self._handler_class = handler_class
        self._server_certificate = server_certificate

    def run(self):
        self.cond.acquire()
        attempts = 0
        self.httpd = None
        while self.httpd is None:
            try:
                self.httpd = _StoppableHTTPServer(
                    (self._ip, self._port), self._handler_class)
                if self._server_certificate:
                    self.httpd.socket = ssl.wrap_socket(
                        self.httpd.socket,
                        certfile=self._server_certificate,
                        server_side=True,
                    )
            except Exception as error:
                # Retry binding while the address is still in use, as in:
                # http://www.ianlewis.org/en/testing-using-mocked-server
                address_in_use = (
                    isinstance(error, socket.error) and
                    getattr(error, 'errno', None) == errno.EADDRINUSE
                )
                if address_in_use and attempts < self.TIMEOUT:
                    attempts += 1
                    time.sleep(1)
                    continue
                # `ready` must be set before notifying. Otherwise the waiter
                # can wake up, still see ready == False and wait forever.
                self.ready = True
                self.cond.notifyAll()
                self.cond.release()
                raise

        self.ready = True
        if self.cond:
            self.cond.notifyAll()
            self.cond.release()

        # Start the actual HTTP server.
        self.httpd.serve_forever()
class HTTPServerContext(object):
    """
    A context manager which runs a HTTP or HTTPS server for testing simple
    HTTP requests.

    After the server is started the ip and port are available in the
    context management instance.

        response = ResponseDefinition(
            url='/hello.html', response_content='Hello!')
        with HTTPServerContext([response]) as httpd:
            print('Listening at %s:%d' % (httpd.ip, httpd.port))
            self.assertEqual('Hello!', your_get())

        response = ResponseDefinition(
            url='/hello.php', method='POST', request='user=John',
            response_content='Hello John!', response_code=202)
        with HTTPServerContext([response]) as httpd:
            self.assertEqual(
                'Hello John!',
                get_you_post(url='hello.php', data='user=John'))
    """

    def __init__(
        self, responses=None, ip='127.0.0.1', port=0,
        version='HTTP/1.1', cert=None, debug=False,
        auth=None,
    ):
        """
        Initialize a new HTTPServerContext.

        * responses - A list of ResponseDefinition defining the behavior of
          this server.
        * ip - IP to listen. Leave empty to listen to any interface.
        * port - Port to listen. Leave 0 to pick a random port.
        * version - HTTP version used by server.
        * cert - certificate to be used by the HTTPS server.
        * debug - print debug messages from the request handler.
        * auth - value of the Authorization header required for requests.
        """
        self._previous_valid_responses = _DefinedRequestHandler.valid_responses
        self._previous_first_client = _DefinedRequestHandler.first_client
        self._previous_authorization = _DefinedRequestHandler.authorization

        # Since we can not pass an instance of _DefinedRequestHandler
        # we do on the fly patching here.
        # Servers might be nested, so a debug will trigger any server.
        # Also, this has the side effect, that once a debug is enabled in
        # one tests, it will be enabled in all the tests after it, so we
        # keep the state at start to see if we should revert.
        self._debug = debug
        self._previous_debug = _DefinedRequestHandler.debug
        _DefinedRequestHandler.debug = _DefinedRequestHandler.debug or debug

        if responses is None:
            _DefinedRequestHandler.valid_responses = []
        else:
            _DefinedRequestHandler.valid_responses = responses

        _DefinedRequestHandler.authorization = auth
        _DefinedRequestHandler.protocol_version = version
        self.cond = threading.Condition()
        self._server_certificate = cert
        self.server = _ThreadedHTTPServer(
            handler_class=_DefinedRequestHandler,
            cond=self.cond,
            ip=ip,
            port=port,
            server_certificate=cert,
        )

    def __enter__(self):
        self.cond.acquire()
        try:
            self.server.start()
            # Wait until the server is ready.
            while not self.server.ready:
                self.cond.wait()
        finally:
            # Release even if the thread could not be started.
            self.cond.release()
        # Even if the thread ready, it might still need some time
        # to be ready.
        time.sleep(0.03)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # _DefinedRequestHandler initialization is outside of control so
        # we share state as class members. To free memory we need to clean it.
        _DefinedRequestHandler.cleanGlobals()
        if self._debug and not self._previous_debug:
            _DefinedRequestHandler.debug = False

        try:
            self.stopServer()
            self.server.join(1)
        finally:
            # The reverting should be done after the thread is closed so that
            # we don't have things inside the thread setting the class values.
            # It is done even if stopping failed, so later tests are not
            # affected.
            _DefinedRequestHandler.valid_responses = (
                self._previous_valid_responses)
            _DefinedRequestHandler.first_client = self._previous_first_client
            _DefinedRequestHandler.authorization = (
                self._previous_authorization)

        # is_alive() exists since Python 2.6; isAlive() was removed in 3.9.
        if self.server.is_alive():
            raise AssertionError('Server still running')
        return False

    @property
    def port(self):
        return self.server.httpd.server_address[1]

    @property
    def ip(self):
        return self.server.httpd.server_address[0]

    def stopServer(self):
        """
        Stop the server thread, unblocking it if it waits for data.

        Socket errors and bad status lines are ignored, as the connection
        might be already closed.
        """
        connection = self.server.httpd.active_connection
        try:
            if connection and connection.rfile._sock:
                # Stop waiting for data from persistent connection.
                self.server.httpd.stopped = True
                sock = connection.rfile._sock
                sock.shutdown(socket.SHUT_RDWR)
                sock.close()
            else:
                # Stop waiting for data from new connection.
                self._sendQuitRequest()
        except (socket.error, BadStatusLine):
            # Ignore socket errors at shutdown as the connection
            # might be already closed.
            pass
        self.server.httpd.server_close()

    def _sendQuitRequest(self):
        """
        Send the special QUIT request which makes the server stop, and
        close the client connection afterwards.
        """
        address = '%s:%d' % (self.ip, self.port)
        if self._server_certificate:
            try:
                context = ssl._create_unverified_context()
            except Exception:
                # Python < 2.7.9 uses the old HTTPS code.
                context = None
            if context:
                # Use an HTTPS connection since we are in HTTPS mode.
                conn = HTTPSConnection(address, context=context, timeout=5)
            else:
                conn = HTTPSConnection(address, timeout=5)
        else:
            conn = HTTPConnection(address, timeout=5)

        try:
            conn.request("QUIT", "/")
            conn.getresponse()
        finally:
            conn.close()
class _DefinedRequestHandler(BaseHTTPRequestHandler, object):
    """
    A request handler which acts based on pre-defined responses.

    This should only be used for test together with HTTPServerContext.

    Configuration is shared through class members, since the server creates
    the handler instances itself.
    """
    valid_responses = []
    debug = False
    # Keep a record of first client which connects to the request
    # series to have a better check for persisted connections.
    first_client = None
    # Value of the authorization header required for requests.
    authorization = None

    def __init__(self, request, client_address, server):
        if self.debug:
            print('New connection %s.' % (client_address,))
        # Register current connection on server.
        server.active_connection = self
        try:
            super(_DefinedRequestHandler, self).__init__(
                request, client_address, server)
        except socket.error:
            pass
        server.active_connection = None

    @classmethod
    def cleanGlobals(cls):
        """
        Reset the class members used to share info between different
        requests.
        """
        cls.valid_responses = []
        cls.first_client = None

    def log_message(self, *args):
        """
        We don't want any logs from the server.
        """
        pass

    def do_QUIT(self):
        """
        Called by HTTPServerContext to trigger server stop.
        """
        self.server.stopped = True
        self._debug('QUIT')
        self.send_response(200)
        self.end_headers()
        # Force closing the connection.
        self.close_connection = 1

    def do_GET(self):
        self._handleRequest()

    def do_HEAD(self):
        self._handleRequest()

    def do_POST(self):
        self._handleRequest()

    def do_PUT(self):
        self._handleRequest()

    def do_DELETE(self):
        self._handleRequest()

    def do_PROPFIND(self):
        self._handleRequest()

    def _handleRequest(self):
        """
        Check if we can handle the request and send response.

        A 404 error is sent when no ResponseDefinition matches.
        """
        if not self.first_client:
            # Looks like this is the first request so we save the client
            # address to compare it later.
            self._debug('First: %r' % (self.client_address,))
            self.__class__.first_client = self.client_address

        response = self._matchResponse()
        if response:
            self._debug(response)
            self._sendResponse(response)
            self._debug(
                'After: Close-connection: %s' % (self.close_connection,))
            return

        self.send_error(404)

    def _matchResponse(self):
        """
        Return the ResponseDefinition for the current request.

        Return `None` if no match was found. A 401 ResponseDefinition is
        returned when the Authorization header does not match.
        """
        length = int(self.headers.getheader('content-length', 0))
        content = self.rfile.read(length)
        self._debug(content)

        auth_request = self.headers.getheader('authorization', None)
        if self.authorization != auth_request:  # noqa:cover
            error = 'Request not authorized. Expect %s got %s' % (
                self.authorization, auth_request)
            self._debug(error)
            return ResponseDefinition(
                response_code=401,
                response_message='Not authorized',
                response_content=error,
            )

        any_matching = None
        for response in self.__class__.valid_responses:
            if self.path != response.url or self.command != response.method:
                continue

            # For GET we have the response and don't check further.
            if self.command == 'GET':
                return response

            # We don't need exact content matching
            if response.request is ResponseDefinition.ANY:
                self._debug('ANY matching.')
                # Keep the response, but see if we have a more specific
                # match.
                any_matching = response
                continue

            if content == response.request:
                return response

        return any_matching

    def _debug(self, message=''):
        """
        Print to stdout a debug message.
        """
        if not self.debug:
            return
        print('\nGot %s:%s - %s\n' % (
            self.command, self.path, message))

    def _sendResponse(self, response):
        """
        Send response to client.

        When the persistent connection checks fail, only a 400 error is
        sent and the defined response is not sent.
        """
        connection_header = self.headers.getheader('connection')
        if connection_header:
            connection_header = connection_header.lower()
        elif self.protocol_version == 'HTTP/1.1':
            # For HTTP/1.1 connections are persistent by default.
            connection_header = 'keep-alive'
        else:
            # For HTTP/1.0 connections are not persistent by default.
            connection_header = 'close'

        if response.persistent is None:
            # Ignore persistent flag.
            pass
        elif response.persistent:
            if connection_header == 'close':
                self.send_error(400, 'Headers do not persist the connection')
                return
            if self.first_client != self.client_address:
                self.send_error(
                    400,
                    'Persistent connection not reused. First %r. Now %r' % (
                        self.first_client, self.client_address))
                return
        elif connection_header == 'keep-alive':
            self.send_error(400, 'Connection was persistent')
            return

        if isinstance(response.response_code, Exception):
            # Just close the connection without a valid HTTP response.
            # The exception text, if any, is written as raw data.
            failure = response.response_code
            self.wfile.write(failure.args[0] if failure.args else '')
            self.close_connection = 1
            return

        self.send_response(
            response.response_code, response.response_message)
        self.send_header("Content-Type", response.content_type)
        if response.response_length:
            self.send_header("Content-Length", response.response_length)
        self.end_headers()
        self.wfile.write(response.test_response_content)

        if not response.response_persistent:
            # Force closing the connection as requested
            # by response.
            self.close_connection = 1
class ResponseDefinition(object):
    """
    A class encapsulating the required data for configuring a response
    generated by the HTTPServerContext.

    It contains the following data:
    * url - url that will trigger this response
    * method - HTTP method (command) that will trigger this response.
    * request - Content of the request that will trigger the response once
      the url is matched. The content is not checked for GET requests.
      - Set to ResponseDefinition.ANY to trigger on any content.
    * response_content - content of the response. Non-bytes values are
      encoded as UTF-8.
    * response_code - HTTP code of the response or an Exception().
      If the Exception has any text, it is written to the response as raw
      data.
    * response_message - Message sent together with HTTP code.
    * content_type - Content type of the HTTP response
    * response_length - Length of the response body content.
      `None` to calculate automatically the length.
      `` (empty string) to ignore content-length header.
    * persistent: whether the request should persist the connection.
      Set to None to ignore persistent checking.
    * response_persistent: whether the connection is kept open after the
      response is sent. `None` to use the `persistent` value.
    """
    ANY = object()

    def __init__(
        self, url='', request='', method='GET',
        response_content=b'', response_code=200, response_message=None,
        content_type=b'text/html', response_length=None,
        persistent=True, response_persistent=None,
    ):
        self.url = url
        self.method = method
        self.request = request
        self.response_code = response_code
        self.response_message = response_message
        self.content_type = content_type

        # Sets the body and its automatically computed length.
        self._setResponseContent(response_content)
        if response_length is not None:
            # An explicit length (or '') overrides the computed one.
            self.response_length = str(response_length)

        self.persistent = persistent
        if response_persistent is None:
            response_persistent = persistent
        self.response_persistent = response_persistent

    def __repr__(self):
        return 'ResponseDefinition:%s:%s:%s %s:pers-%s' % (
            self.url,
            self.method,
            self.response_code, self.response_message,
            self.persistent,
            )

    def _setResponseContent(self, content):
        """
        Store `content` as bytes and set `response_length` to its length.
        """
        if not isinstance(content, bytes):
            content = codecs.encode(content, 'utf-8')
        self.test_response_content = content
        self.response_length = str(len(content))

    def updateResponseContent(self, content):
        """
        Will update the content returned to the server.

        The Content-Length is recalculated from the new content.
        """
        self._setResponseContent(content)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment