Skip to content

Instantly share code, notes, and snippets.

@adiroiban
Last active November 26, 2017 10:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adiroiban/edc0776e3337d0bd3f093aa0f2819deb to your computer and use it in GitHub Desktop.
Save adiroiban/edc0776e3337d0bd3f093aa0f2819deb to your computer and use it in GitHub Desktop.
Twisted Mail SMTPConnectError strange behaviour
# Copyright (c) 2015 Adi Roiban.
# MIT License
"""
(E)SMTP high level client.
"""
from email.header import Header
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formataddr, formatdate, make_msgid
from StringIO import StringIO
from time import time
from twisted.internet import defer, error, reactor
from twisted.internet.error import ConnectError
from twisted.internet.protocol import connectionDone
from twisted.mail import smtp
class ESMTPClient(object):
    """
    High level API for sending messages over the ESMTP protocol.

    Supports only ascii username/password combinations.

    FIXME:3067:
    Twisted does not allow unicode usernames/passwords.
    """

    _reactor = reactor
    # For testing purposes.
    _boundary = None

    def __init__(
        self,
        host,
        port=25,
        username=None,
        password=None,
        retries=3,
        timeout=15,
    ):
        """
        `host` and `port` identify the (E)SMTP server.

        `username` and `password` are optional credentials.

        `retries` is passed to the sender factory as the number of
        connection retries, and `timeout` is used both for the TCP
        connection and by the sender factory.
        """
        self._host = host
        self._port = port
        self._username = username
        self._password = password
        self._retries = retries
        self._timeout = timeout

    def _formatAddressesAsHeader(self, addresses):
        """
        Returns an email Header instance containing a comma separated list
        of mailboxes.

        `addresses` is a list of tuples in the form (name, email). ASCII
        only email addresses are supported; the header is built with
        strict `us-ascii` encoding.

        The mailboxes are separated by commas, with no trailing comma
        after the last one.

        https://tools.ietf.org/html/rfc788#page-25
        http://www.rfc-base.org/txt/rfc-2821.txt
        """
        # FIXME:3067:
        # Implement support for non ASCII email name/address.
        charset = 'us-ascii'
        header = Header(charset=charset)
        last_index = len(addresses) - 1
        for index, address in enumerate(addresses):
            item = formataddr(address)
            if index != last_index:
                # A comma is only a separator. Adding one after the last
                # mailbox would leave an empty element in the list.
                item += ','
            header.append(item, charset=charset, errors='strict')
        return header

    def _buildContent(
        self,
        sender,
        recipients,
        subject,
        body,
        id=None,
        attachments=None,
    ):
        """
        Returns the email headers and MIME multipart body, as produced by
        `Message.as_string()`.

        The first part holds the plain text `body`, followed by one part
        for each attachment.

        Unicode subject and body values are supported and will be encoded
        as utf-8.

        `id` is used as the Message-Id header. A new one is generated
        when it is not provided.

        `attachments` is a dictionary mapping file names to file-like
        objects. Each file is rewound and read in full.
        """
        if not id:
            id = make_msgid()
        if not attachments:
            attachments = {}

        # FIXME:3068:
        # Named email addresses not supported.
        destinations = [('', email) for email in recipients]

        # FIXME:3069:
        # Configurable sender name not supported.
        source = [('TwistedClient', sender)]

        charset = 'utf-8'
        message = MIMEMultipart(boundary=self._boundary)
        message['From'] = self._formatAddressesAsHeader(source)
        message['To'] = self._formatAddressesAsHeader(destinations)
        message['Subject'] = Header(subject, charset=charset)
        message['Date'] = formatdate(time(), localtime=True)
        message['Message-Id'] = id

        message_body = MIMEText(body.encode(charset), 'plain', charset)
        message.attach(message_body)

        for name, attachment in attachments.items():
            attachment.seek(0)
            part = MIMEApplication(attachment.read(), Name=name)
            # add_header quotes the parameter value, so file names
            # containing double quotes or backslashes still produce a
            # valid header.
            part.add_header(
                'Content-Disposition', 'attachment', filename=name)
            message.attach(part)

        return message.as_string()

    def send(
        self,
        sender,
        recipients,
        subject,
        body,
        id=None,
        attachments=None,
    ):
        """
        Sends email with `subject`, `body` and (optionally) `attachments`
        to `recipients`, from specified `sender`.

        `sender` is a string email address in the format
        `address@example.com`.

        `recipients` a list of string email addresses in the format
        `address@example.com`.

        `attachments` is a dictionary each key of which contains a
        file name and each value contains file-like object
        (e.g. `io.BytesIO`).

        Returns a deferred that will fire once the email is sent and the
        server connection terminated. If email was delivered to all specified
        recipients the result will be `True`, otherwise `False`.

        For each call a new connection to the (E)SMTP server is made and
        subsequently dropped.
        """
        # Assigned once the TCP connection is started, below. The closure
        # reads its current value when the deferred fires.
        connector = None

        def bb_disconnect(result_or_failure):
            """
            If not connected, return `result_or_failure` immediately;
            otherwise wait for disconnection from server to occur and return
            `result_or_failure`.
            """
            if not connector:
                return result_or_failure

            if not connector.transport:
                # Not yet connected.
                return result_or_failure

            if not connector.factory.currentProtocol:
                # Connected, but STMP was not started or was already stopped.
                return result_or_failure

            # Wait for current connection to close.
            deferred = connector.factory.currentProtocol.disconnected_deferred
            deferred.addCallback(lambda _: result_or_failure)
            return deferred

        send_deferred = defer.Deferred()
        send_deferred.addCallback(self._cbSend, recipients)
        send_deferred.addErrback(self._ebSend)
        send_deferred.addBoth(bb_disconnect)

        content = self._buildContent(
            sender, recipients, subject, body, id=id, attachments=attachments)

        factory = SenderFactory(
            username=self._username,
            password=self._password,
            fromEmail=sender,
            toEmail=recipients,
            file=StringIO(content),
            deferred=send_deferred,
            retries=self._retries,
            timeout=self._timeout,
            )
        connector = self._reactor.connectTCP(
            self._host, self._port, factory, timeout=self._timeout)

        return send_deferred

    def _cbSend(self, result, recipients):
        """
        Called when email was delivered successfully.

        `result` is a tuple whose first item is the number of recipients
        which accepted the email.

        Returns `True` if mail was delivered to all recipients, `False`
        otherwise.

        FIXME:3071:
        Return detailed success/failure results.
        """
        successful, details = result
        return len(recipients) == successful

    def _ebSend(self, failure):
        """
        Called when email delivery failed or there are connection errors.

        Only `ConnectError` and `smtp.SMTPClientError` failures are
        handled. Any other failure is passed on by `trap`.

        Raises a `RuntimeError` with a utf-8 encoded error description.
        """
        failure.trap(ConnectError, smtp.SMTPClientError)
        # FIXME:3070:
        # Create email client dedicated exceptions a la HTTPException.
        if isinstance(failure.value, smtp.SMTPClientError):
            message = u'%d %s' % (
                failure.value.code, self._toUnicode(failure.value.resp))
        else:
            message = self._toUnicode(failure.getErrorMessage())
        raise RuntimeError(message.encode('utf-8'))

    @staticmethod
    def _toUnicode(value):
        """
        Return `value` as unicode.

        Byte strings are decoded as utf-8, and invalid sequences are
        replaced. A non-ASCII server response or error message therefore
        does not raise UnicodeDecodeError and hide the original error.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8', 'replace')
        return unicode(value)
class ESMTPClientProtocol(smtp.ESMTPSender):
    """
    A ESMTP client protocol that fires an internal deferred once connection
    with server is terminated.
    """

    #: A deferred which is called when protocol is not connected.
    #: This class level default has already fired, so waiting on a
    #: protocol which never connected does not block. `connectionMade`
    #: replaces it with a new, unfired deferred for each instance.
    disconnected_deferred = defer.succeed(None)

    def connectionMade(self):
        """
        See: `SMTPSender`.

        Sets up a new `disconnected_deferred` which fires when the
        connection is lost.
        """
        smtp.ESMTPSender.connectionMade(self)
        self.disconnected_deferred = defer.Deferred()

    def connectionLost(self, reason=connectionDone):
        """
        See: `SMTPClient`.

        Fires `disconnected_deferred` with `None` after the parent
        cleanup.
        """
        # Go through the same parent as `connectionMade`, so that no
        # `connectionLost` implementation between ESMTPSender and
        # SMTPClient in the MRO is skipped.
        smtp.ESMTPSender.connectionLost(self, reason)
        self.disconnected_deferred.callback(None)
class SenderFactory(smtp.ESMTPSenderFactory):
    """
    ESMTP sender factory which builds `ESMTPClientProtocol` instances.

    Authentication and transport security are not required, and
    `heloFallback` is enabled.
    """

    protocol = ESMTPClientProtocol

    def __init__(
        self, username, password, fromEmail, toEmail, file, deferred,
        retries, timeout,
    ):
        smtp.ESMTPSenderFactory.__init__(
            self,
            username=self._encodeCredential(username),
            password=self._encodeCredential(password),
            fromEmail=fromEmail,
            toEmail=toEmail,
            file=file,
            deferred=deferred,
            retries=retries,
            timeout=timeout,
            requireAuthentication=False,
            requireTransportSecurity=False,
            heloFallback=1,
        )

    @staticmethod
    def _encodeCredential(value):
        """
        Return `value` encoded as utf-8, or `None` for an empty value.
        """
        return value.encode('utf-8') if value else None

    def _processConnectionError(self, connector, err):
        """
        Handle a failed or lost connection.

        Nothing is done once sending has finished. While `self.retries`
        is negative, the message file is rewound, a new connection is
        started and the counter is incremented. Otherwise, the result
        deferred is errbacked.
        """
        self.currentProtocol = None

        if self.sendFinished:
            return

        if self.retries < 0:
            # Part of the message might have been read while attempting
            # to send it, so start again from the beginning.
            self.file.seek(0, 0)
            connector.connect()
            self.retries += 1
            return

        # A ConnectionDone means no communication with the SMTP server was
        # possible, so report it with a clearer error message.
        if err.check(error.ConnectionDone):
            err.value = smtp.SMTPConnectError(
                -1, "Unable to connect to server.")
        self.result.errback(err.value)
# Copyright (c) 2014 Adi Roiban.
# MIT License
"""
Helpers for testing code related to HTTP protocol.
"""
from __future__ import absolute_import
from httplib import HTTPConnection, HTTPSConnection
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from select import error as SelectError
from StringIO import StringIO
from threading import Thread
import codecs
import errno
import json
from httplib import BadStatusLine
import socket
import ssl
import threading
import time
import urllib
from twisted.internet import defer, interfaces as internet_interfaces
from twisted.python.failure import Failure
from twisted.web import (
http,
server as web_server,
)
from twisted.web.client import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer
from twisted.web.server import Request
from zope.interface import implementer
# Marker used to signal that the expected request will have a streamed body.
# It is a unique sentinel object; the streamed data is collected by the
# response's `write` method.
STREAMED_REQUEST = object()
# Marker used as a response body to signal that the body is not delivered
# at once. Instead, `InMemoryResponse.deliverBody` stores its producer as
# `body_producer`, so the test can send the data later.
STREAMED_RESPONSE = object()
@implementer(internet_interfaces.IPushProducer)
class InMemoryBodyDeliverer(object):
    """
    Push producer which feeds a response body to a consumer protocol
    without a real transport.

    It registers itself as the consumer's transport and records pause and
    stop requests. Delivering data while paused or after being stopped is
    reported as an error.
    """

    def __init__(self, consumer):
        self._consumer = consumer
        consumer.transport = self
        self.stopped = self.paused = False

    def pauseProducing(self):
        """
        Record that the consumer asked for a pause.
        """
        self.paused = True

    def resumeProducing(self):
        """
        Record that the consumer is ready for more data.
        """
        self.paused = False

    def stopProducing(self):
        """
        Record that the consumer does not want any more data.
        """
        self.stopped = True

    def sendData(self, data):
        """
        Hand `data` to the consumer, as if it was received from the
        response.

        Raises AssertionError when the producer was stopped or is paused.
        """
        problem = None
        if self.stopped:
            problem = 'Data received after producer was stopped.'
        elif self.paused:
            problem = 'Data received while paused.'

        if problem is not None:
            raise AssertionError(problem)

        self._consumer.dataReceived(data)

    def stopSending(self):
        """
        Signal the end of the body, as a closed connection would, by
        passing a ResponseDone failure to the consumer.
        """
        done = Failure(ResponseDone())
        self._consumer.connectionLost(done)
class InMemoryResponse(object):
    """
    Simple response from in memory agent request.

    It holds both the request expected by the test (method, body, headers)
    and the response to deliver (code, phrase, body, headers).
    """

    def __init__(
        self, method='GET', code=404, phrase='Phrase not set',
        request_body=None, request_json=None, request_headers=None,
        response_body=b'', response_headers=None,
    ):
        """
        `request_body` takes precedence over `request_json`, which is
        serialized with `json.dumps`. When neither is given, an empty
        request body is expected.

        Use `STREAMED_REQUEST` as `request_body` to expect a streamed
        request, and `STREAMED_RESPONSE` as `response_body` to deliver the
        body later through `body_producer`.
        """
        self.method = method
        self.code = code
        self.phrase = phrase
        self.request_headers = Headers(request_headers)

        if request_body is None:
            if request_json is not None:
                request_body = json.dumps(request_json)
            else:
                request_body = ''
        self.request_body = request_body

        self.response_body = response_body
        # `None` default instead of a shared mutable `{}` default.
        self.headers = Headers(response_headers)
        self._stream_request = False
        self._streamed_body = None
        # Set by `deliverBody` when the response body is streamed.
        self.body_producer = None

    @property
    def streamed_body(self):
        """
        The content of the requested as it was streamed.
        """
        return self._streamed_body

    def write(self, data):
        """
        Called when received a streamed request.

        Raises AssertionError when the response does not expect a
        streamed request or when `data` is not bytes.
        """
        # STREAMED_REQUEST is a sentinel, so compare by identity.
        if self.request_body is not STREAMED_REQUEST:
            raise AssertionError('Write while not in stream mode.')

        if not isinstance(data, bytes):
            raise AssertionError('Write operation for non bytes.')

        if self._streamed_body is None:
            self._streamed_body = data
        else:
            self._streamed_body += data

    def deliverBody(self, consumer_protocol):
        """
        See: twisted.web.client.Response.deliverBody.

        For a STREAMED_RESPONSE the producer is stored as `body_producer`
        and no data is sent. Otherwise the whole body is sent, followed
        by a ResponseDone failure.
        """
        producer = InMemoryBodyDeliverer(consumer_protocol)

        if self.response_body is STREAMED_RESPONSE:
            self.body_producer = producer
            return

        consumer_protocol.dataReceived(self.response_body)
        consumer_protocol.connectionLost(Failure(ResponseDone()))
class InMemoryPersistentAgent(object):
    """
    Simple implementation of an HTTP Persistent Agent.

    Tests arrange responses in `responses`, a dictionary which maps each
    URL to a list of responses consumed in order. An item can be a
    response object, a Deferred firing with one, or an Exception.
    """

    def __init__(self):
        self._connections = {}
        self.responses = {}

    def _getResponse(self, method, url, headers, body):
        """
        Return the arranged response for our rigged agent.

        Raises AssertionError when no response is arranged for `url`.
        The returned deferred fails with AssertionError when the request
        does not match the method, body or headers of the response.
        """
        if body is None:
            body = ''

        self._cacheURL(url)
        headers = Headers(headers)

        try:
            response = self.responses[url].pop(0)
        except (KeyError, IndexError):
            raise AssertionError(
                'URL not defined %s for %s' % (method, url))

        if isinstance(response, Exception):
            return defer.fail(response)

        if isinstance(response, defer.Deferred):
            deferred = response
        else:
            deferred = defer.succeed(response)

        def cb_check_response(response):
            """
            Called when we got a response.
            """
            if response.method != method:
                raise AssertionError(
                    'No %s request was defined for %s' % (method, url))

            if IBodyProducer.providedBy(body):
                # We have a streamed request, so the body is checked later.
                # STREAMED_REQUEST is a sentinel, so compare by identity.
                if response.request_body is not STREAMED_REQUEST:
                    raise AssertionError(
                        'Unexpected streamed request. Expecting: %s' % (
                            response.request_body,))
                done_deferred = body.startProducing(response)
                response.start_producing_deferred = done_deferred
            else:
                # Just compare the whole body.
                if response.request_body != body:
                    raise AssertionError(
                        'Invalid request body:\ngot\n%s\nexpecting\n%s' % (
                            body, response.request_body))

            if response.request_headers != headers:
                raise AssertionError(
                    'Invalid headers:\ngot\n%s\nexpecting\n%s' % (
                        headers, response.request_headers))

            # All good.
            return response

        deferred.addCallback(cb_check_response)
        return deferred

    def _cacheURL(self, url):
        """
        Called to mark that URL is cached.
        """
        # For connection caching we cache only based on host.
        # 'scheme://host/path' splits into [scheme:, '', host, path].
        parts = url.split('/', 3)
        if len(parts) > 3:
            host = '/'.join(parts[:-1])
        else:
            host = url
        self._connections[host] = b'cached'

    def postJSON(self, url, body, headers=None):
        """
        See: PersistentAgent.

        `body` is serialized as utf-8 encoded JSON. The caller's `headers`
        dictionary is not modified.
        """
        if headers is None:
            headers = {}
        else:
            # Work on a copy so the Content-Type is not added to the
            # caller's dictionary.
            headers = dict(headers)
        # Make sure content type is always JSON.
        headers['Content-Type'] = [b'application/json; charset=utf-8']
        return self.post(url, headers, json.dumps(body).encode('utf-8'))

    def post(self, url, headers, body):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'POST', url, headers, body)

    def put(self, url, headers, body):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'PUT', url, headers, body)

    def get(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'GET', url, headers, body=None)

    def delete(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'DELETE', url, headers, body=None)

    def propfind(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'PROPFIND', url, headers, body=None)

    def mkcol(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'MKCOL', url, headers, body=None)

    def head(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'HEAD', url, headers, body=None)

    def move(self, url, headers):
        """
        See: PersistentAgent.
        """
        return self._getResponse(b'MOVE', url, headers, body=None)

    def readBody(self, response):
        """
        Return a deferred which fires with response body.
        """
        return defer.succeed(response.response_body)

    def closePersistentConnections(self):
        """
        See: PersistentAgent.
        """
        self._connections = {}

    @property
    def connections(self):
        """
        Return the list of persistent connections.
        """
        return self._connections
class DummyHTTPChannel(object):
    """
    Minimal HTTP channel used by test requests.

    Bytes written to the channel are collected in `written`, and
    registered producers in `producers`.
    """

    port = 80
    disconnected = False

    def __init__(self, site=None, peer=None, host=None, resource=None):
        from chevah.server.testing import mk
        self.written = StringIO()
        self.producers = []
        if peer is None:
            peer = mk.makeIPv4Address(host='42.0.0.1', port=1234)
        if host is None:
            host = mk.makeIPv4Address(host='10.0.0.1', port=self.port)
        if site is None:
            site = web_server.Site(resource)
        self.site = site
        self._peer = peer
        self._host = host

    def getPeer(self):
        return self._peer

    def write(self, data):
        """
        Store `data` in `written`.

        Raises AssertionError when `data` is not bytes.
        """
        if not isinstance(data, bytes):
            raise AssertionError('Non bytes write request.')
        # Store the received data itself, not the `bytes` type.
        self.written.write(data)

    def writeSequence(self, iovec):
        """
        Write each item of `iovec`, in order.
        """
        # An explicit loop: `map` is lazy on Python 3 and would write
        # nothing.
        for data in iovec:
            self.write(data)

    def getHost(self):
        return self._host

    def registerProducer(self, producer, streaming):
        self.producers.append((producer, streaming))

    def loseConnection(self):
        self.disconnected = True

    def requestDone(self, request):
        pass
@implementer(internet_interfaces.ISSLTransport)
class DummyHTTPSChannel(DummyHTTPChannel):
    """
    Same as `DummyHTTPChannel`, but the class declares `ISSLTransport`
    and uses the HTTPS port.
    """

    # Also used as the port of the default host address.
    port = 443
class DummyWebRequest(Request):
    """
    A dummy Twisted Web Request used in tests.

    uri - url encoded
    postpath, prepath - utf-8 encoded (not Unicode)

    Data written to the request is kept in `written` and is available
    as `test_response_content`.
    """

    def __init__(
        self,
        postpath=None, prepath=None, session=None, resource=None,
        data=None, peer=None, site=None,
        uri=None, clientproto=None, method=None, secured=False,
        path=None, host=None, channel=None
    ):
        """
        `uri` and `postpath` can not be used together. When `postpath` is
        given, the URI is built from `prepath` followed by `postpath`.
        When neither is given, the URI is '/uri-not-defined'.

        `resource` is used for the default site when neither `channel`
        nor `site` is given.

        `data` is the initial body content of the request.
        """
        from chevah.server.testing import mk
        if peer is None:
            peer = mk.makeIPv4Address(host='42.0.0.1')
        if channel is None:
            # Pass `resource` along; it was previously ignored.
            channel = DummyHTTPChannel(
                peer=peer, host=host, site=site, resource=resource)
        self.channel = channel
        self.site = channel.site

        self.content = StringIO()
        self.written = []
        if data:
            self.content.write(data)
            self.content.seek(0)

        # Full URL including arguments
        if uri is None:
            uri = '/uri-not-defined'
        elif postpath is not None:
            raise AssertionError('You can not define both URI and postpath.')
        self.uri = uri
        # HTTP URL arguments POST or GET
        self.args = {}
        # HTTP URL without arguments
        self.path = path

        if prepath is None:
            prepath = []
        if postpath is None:
            postpath = self.uri.split('/')[1:]
        else:
            # The path is made of the already traversed segments (prepath)
            # followed by the remaining ones (postpath).
            self.uri = urllib.quote('/%s' % '/'.join(prepath + postpath))
        if isinstance(uri, unicode):
            raise AssertionError('URI should be URL encoded.')

        self.sitepath = []
        self.prepath = prepath
        self.postpath = postpath
        self.client = peer
        self.secured = secured

        if clientproto is None:
            clientproto = 'HTTP/1.0'
        self.clientproto = clientproto

        if method is None:
            method = 'GET'
        self.method = method.upper()

        self.session = None
        self.protoSession = session or web_server.Session(0, self)

        self._code = http.OK
        self._code_message = 'OK'
        self.responseHeaders = Headers()
        self.requestHeaders = Headers()

        # This should be called after we have defined the request headers.
        if host is None:
            host = 'dummy.host.tld'
        self.setRequestHeader('host', host)

        self.received_cookies = {}
        self.cookies = []  # outgoing cookies

        # Finish notifications
        self.notifications = []
        self.finished = 0
        self.queued = False

    def __repr__(self):
        return (
            'DummyWebRequest for "%(uri)s", code: %(code)s\n'
            'response content: "%(response_content)s"\n'
            'response headers: "%(response_headers)s' % ({
                'uri': self.uri,
                'code': self.code,
                'response_content': self.test_response_content,
                'response_headers': dict(
                    self.responseHeaders.getAllRawHeaders()),
                })
            )

    @property
    def code(self):
        return self._code

    @property
    def code_message(self):
        return self._code_message

    @property
    def test_response_content(self):
        """
        Return the data written to the request during tests.
        """
        return ''.join(self.written)

    def _checkContentNotWritten(self, method_name):
        """
        Check that content was not sent to remote peer.

        `method_name` is there to help pinpoint where things went wrong.
        """
        assert not self.written, (
            "%s cannot be called after data has been written: %s." % (
                method_name, "@@@@".join(self.written)))

    def getSession(self, sessionInterface=None):
        """
        Return server sided stored session object.
        """
        if not self.session:
            self._checkContentNotWritten('getSession')
        return super(DummyWebRequest, self).getSession(sessionInterface)

    def write(self, data):
        if not isinstance(data, bytes):
            raise AssertionError("Non bytes write request.")
        self.written.append(data)

    def isSecure(self):
        return self.secured

    def getHeader(self, name):
        """
        Public method for a Request.

        Returns the last value of the request header, or None.
        """
        value = self.requestHeaders.getRawHeaders(name)
        if value is not None:
            return value[-1]

    def setHeader(self, name, value):
        """
        Public method for a Request.
        """
        self.responseHeaders.setRawHeaders(name, [value])

    def getRequestHeader(self, name):
        """
        Testing method to get request headers.

        This is here so that we can have clear/explicit tests.
        """
        return self.getHeader(name)

    def setRequestHeader(self, name, value):
        """
        Testing method to set request headers.

        This is here so that we can have clear/explicit tests.
        """
        self.requestHeaders.setRawHeaders(name.lower(), [value])

    def getResponseHeader(self, name):
        """
        Testing method to get response headers.

        This is here so that we can have clear/explicit tests.
        """
        value = self.responseHeaders.getRawHeaders(name)
        if value is not None:
            return value[-1]

    def setResponseHeader(self, name, value):
        """
        Testing method to set response headers.

        This is here so that we can have clear/explicit tests.
        """
        # `setHeader` is a bound method; passing `self` again raised a
        # TypeError.
        self.setHeader(name, value)

    def setLastModified(self, when):
        """
        See: twisted.web.http.Request.
        """
        self._checkContentNotWritten('setLastModified')
        return super(DummyWebRequest, self).setLastModified(when)

    def setETag(self, tag):
        """
        See: twisted.web.http.Request.

        Just checks that method is not called after response was sent.
        """
        self._checkContentNotWritten('setETag')

    def getCookie(self, key):
        """
        Get an incoming HTTP cookie.
        """
        return self.received_cookies.get(key, None)

    def setReceivedCookie(self, key, value):
        """
        Set cookie as it were received in the request.
        """
        self.received_cookies[key] = value

    def redirect(self, url):
        """
        Utility function that does a redirect.

        The request should have finish() called after this.
        """
        self.setResponseCode(http.FOUND)
        self.setHeader("location", url)

    def registerProducer(self, prod, s):
        """
        See: twisted.web.http.Request.

        Keeps calling `resumeProducing` until `unregisterProducer` is
        called, so the producer is expected to unregister itself.
        """
        self.go = 1
        while self.go:
            prod.resumeProducing()

    def unregisterProducer(self):
        """
        See: twisted.web.http.Request.
        """
        self.go = 0

    def processingFailed(self, reason):
        """
        Errback and L{Deferreds} waiting for finish notification.
        """
        if self.notifications is not None:
            observers = self.notifications
            self.notifications = None
            for obs in observers:
                obs.errback(reason)

    def setResponseCode(self, code, message=None):
        """
        Set the HTTP status response code, but takes care that this is called
        before any data is written.
        """
        self._checkContentNotWritten('setResponseCode')
        # Code and code_message are read-only for our test purpose.
        self._code = code
        self._code_message = message

    def render(self, resource):
        """
        Render the given resource as a response to this request.

        This implementation only handles a few of the most common behaviors of
        resources.  It can handle a render method that returns a string or
        C{NOT_DONE_YET}.  It doesn't know anything about the semantics of
        request methods (eg HEAD) nor how to set any particular headers.
        Basically, it's largely broken, but sufficient for some tests at
        least.
        It should B{not} be expanded to do all the same stuff L{Request} does.
        Instead, L{DummyRequest} should be phased out and L{Request} (or some
        other real code factored in a different way) used.
        """
        result = resource.render(self)
        if result is web_server.NOT_DONE_YET:
            return
        self.write(result)
        self.finish()

    def setRequestContent(self, data):
        """
        Set body content of the request.
        """
        self.content.write(data)
        self.content.seek(0)
class _StoppableHTTPServer(HTTPServer):
    """
    Single connection HTTP server used to answer HTTP requests in
    functional tests.

    Requests are handled one at a time until `stopped` becomes true.
    """

    server_version = 'ChevahTesting/0.1'
    stopped = False
    # Current connection served by the server.
    active_connection = None

    def serve_forever(self):
        """
        Handle requests one by one until the server is stopped.

        A select interrupted by a signal (EINTR) is retried and any other
        select error is raised. A socket error ends the loop; EBADF also
        marks the server as stopped.
        """
        self.stopped = False
        self.active_connection = None

        while not self.stopped:
            try:
                self.handle_request()
            except SelectError as select_failure:
                # See Python http://bugs.python.org/issue7978
                if select_failure.args[0] != errno.EINTR:
                    raise
            except socket.error as socket_failure:
                if socket_failure.errno != errno.EBADF:
                    break
                self.stopped = True
class _ThreadedHTTPServer(Thread):
    """
    HTTP or HTTPS Server that runs in a thread.

    This is actually a threaded wrapper around an HTTP server.

    Once the server is listening, or has failed to start, `ready` is set
    to True and the threads waiting on `cond` are notified.
    """

    # Number of extra attempts to bind the server while the address is
    # still in use, with one second of sleep between attempts.
    TIMEOUT = 1

    def __init__(
        self, handler_class, cond,
        responses=None,
        ip='127.0.0.1', port=0,
        server_certificate=None,
        debug=False,
    ):
        # `responses` and `debug` are accepted for API compatibility but
        # are not used here.
        Thread.__init__(self)
        self.ready = False
        self.cond = cond
        self._ip = ip
        self._port = port
        self._handler_class = handler_class
        self._server_certificate = server_certificate

    @staticmethod
    def _isAddressInUse(exception):
        """
        Return True when `exception` is a socket error for an address
        which is already in use.
        """
        if not isinstance(exception, socket.error):
            return False
        # Compare with the errno value directly. Looking up
        # errno.errorcode raises KeyError for socket errors without a known
        # errno as first argument (ex: timeouts, address resolution), which
        # would hide the original error.
        return bool(exception.args) and exception.args[0] == errno.EADDRINUSE

    def run(self):
        self.cond.acquire()
        attempts = 0
        self.httpd = None
        while self.httpd is None:
            try:
                self.httpd = _StoppableHTTPServer(
                    (self._ip, self._port), self._handler_class)
                if self._server_certificate:
                    self.httpd.socket = ssl.wrap_socket(
                        self.httpd.socket,
                        certfile=self._server_certificate,
                        server_side=True,
                        )
            except Exception as e:
                # Retry while the address is still in use, based on:
                # http://www.ianlewis.org/en/testing-using-mocked-server
                if self._isAddressInUse(e) and attempts < self.TIMEOUT:
                    attempts += 1
                    time.sleep(1)
                    continue

                # Set `ready` before notifying. Otherwise the woken waiter
                # can see `ready` still False and wait forever.
                self.ready = True
                self.cond.notifyAll()
                self.cond.release()
                # Bare raise keeps the original traceback.
                raise

        self.ready = True
        if self.cond:
            self.cond.notifyAll()
            self.cond.release()
        # Start the actual HTTP server.
        self.httpd.serve_forever()
class HTTPServerContext(object):
    """
    A context manager which runs a HTTP or HTTPS server for testing simple
    HTTP requests.

    After the server is started the ip and port are available in the
    context management instance.

    response = ResponseDefinition(url='/hello.html', response_content='Hello!')
    with HTTPServerContext([response]) as httpd:
        print 'Listening at %s:%d' % (httpd.ip, httpd.port)
        self.assertEqual('Hello!', your_get())

    response = ResponseDefinition(
        url='/hello.php', method='POST', request='user=John',
        response_content='Hello John!', response_code=202)
    with HTTPServerContext([response]) as httpd:
        self.assertEqual(
            'Hello John!',
            your_post(url='hello.php', data='user=John'))
    """

    def __init__(
        self, responses=None, ip='127.0.0.1', port=0,
        version='HTTP/1.1', cert=None, debug=False,
    ):
        """
        Initialize a new HTTPServerContext.

        * responses - A list of ResponseDefinition defining the behavior of
          this server.
        * ip - IP to listen. Leave empty to listen to any interface.
        * port - Port to listen. Leave 0 to pick a random port.
        * version - HTTP version used by server.
        * cert - certificate to be used by the HTTPS server.
        * debug - enable debug messages from the request handler.
        """
        self._previous_valid_responses = _DefinedRequestHandler.valid_responses
        self._previous_first_client = _DefinedRequestHandler.first_client

        # Since we can not pass an instance of _DefinedRequestHandler
        # we do on the fly patching here.
        # Servers might be nested, so a debug will trigger any server.
        # Also, this has the side effect, that once a debug is enabled in
        # one tests, it will be enabled in all the tests after it, so we
        # keep the state at start to see if we should revert.
        self._debug = debug
        self._previous_debug = _DefinedRequestHandler.debug
        _DefinedRequestHandler.debug = _DefinedRequestHandler.debug or debug

        if responses is None:
            _DefinedRequestHandler.valid_responses = []
        else:
            _DefinedRequestHandler.valid_responses = responses

        _DefinedRequestHandler.protocol_version = version
        self.cond = threading.Condition()
        self._server_certificate = cert
        self.server = _ThreadedHTTPServer(
            handler_class=_DefinedRequestHandler,
            cond=self.cond,
            ip=ip,
            port=port,
            server_certificate=cert,
            )

    def __enter__(self):
        self.cond.acquire()
        self.server.start()

        # Wait until the server is ready.
        while not self.server.ready:
            self.cond.wait()
        self.cond.release()
        # Even if the thread ready, it might still need some time
        # to be ready.
        time.sleep(0.02)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # _DefinedRequestHandler initialization is outside of control so
        # we share state as class members. To free memory we need to clean it.
        _DefinedRequestHandler.cleanGlobals()
        if self._debug and not self._previous_debug:
            _DefinedRequestHandler.debug = False

        self.stopServer()
        self.server.join(1)

        # The reverting should be done after the thread is closed so that
        # we don't have things inside the thread setting the class values.
        _DefinedRequestHandler.valid_responses = self._previous_valid_responses
        _DefinedRequestHandler.first_client = self._previous_first_client

        # `is_alive` is available since Python 2.6; `isAlive` was removed
        # in Python 3.9.
        if self.server.is_alive():
            raise AssertionError('Server still running')
        return False

    @property
    def port(self):
        return self.server.httpd.server_address[1]

    @property
    def ip(self):
        return self.server.httpd.server_address[0]

    def stopServer(self):
        """
        Stop the server loop and close the listening socket.

        An active persistent connection is shut down directly. Otherwise
        a QUIT request is sent, which the handler uses to stop the server.
        Socket errors and bad status lines during shutdown are ignored.
        """
        connection = self.server.httpd.active_connection
        try:
            if connection and connection.rfile._sock:
                # Stop waiting for data from persistent connection.
                self.server.httpd.stopped = True
                sock = connection.rfile._sock
                sock.shutdown(socket.SHUT_RDWR)
                sock.close()
            else:
                self._sendQuitRequest()
        except (socket.error, BadStatusLine):
            # Ignore socket errors at shutdown as the connection
            # might be already closed.
            pass

        self.server.httpd.server_close()

    def _sendQuitRequest(self):
        """
        Stop waiting for data from a new connection by sending a special
        QUIT request, and close the client connection afterwards.
        """
        address = '%s:%d' % (self.ip, self.port)
        if self._server_certificate:
            try:
                context = ssl._create_unverified_context()
            except Exception:
                # Python < 2.7.9 uses the old HTTPS code.
                context = None

            if context:
                # Use an HTTPS connection since we are in HTTPS mode.
                conn = HTTPSConnection(address, context=context, timeout=5)
            else:
                conn = HTTPSConnection(address, timeout=5)
        else:
            conn = HTTPConnection(address, timeout=5)

        try:
            conn.request("QUIT", "/")
            conn.getresponse()
        finally:
            # Do not leak the client socket used for the QUIT request.
            conn.close()
class _DefinedRequestHandler(BaseHTTPRequestHandler, object):
    """
    A request handler which act based on pre-defined responses.

    This should only be used for test together with HTTPServerContext.
    """

    valid_responses = []
    debug = False
    # Keep a record of first client which connects to the request
    # series to have a better check for persisted connections.
    first_client = None

    def __init__(self, request, client_address, server):
        if self.debug:
            print('New connection %s.' % (client_address,))
        # Register current connection on server.
        server.active_connection = self
        try:
            super(_DefinedRequestHandler, self).__init__(
                request, client_address, server)
        except socket.error:
            # Best effort: the connection might be closed by the test or
            # by HTTPServerContext.stopServer.
            pass
        server.active_connection = None

    @classmethod
    def cleanGlobals(cls):
        """
        Clean all class members used to share info between different
        requests.
        """
        cls.valid_responses = []
        cls.first_client = None

    def log_message(self, *args):
        """
        We don't want any logs from the server.
        """
        pass

    def do_QUIT(self):
        """
        Called by HTTPServerContext to trigger server stop.
        """
        self.server.stopped = True
        self._debug('QUIT')
        self.send_response(200)
        self.end_headers()
        # Force closing the connection.
        self.close_connection = 1

    def do_GET(self):
        self._handleRequest()

    def do_HEAD(self):
        self._handleRequest()

    def do_POST(self):
        self._handleRequest()

    def do_PUT(self):
        self._handleRequest()

    def do_DELETE(self):
        self._handleRequest()

    def do_PROPFIND(self):
        self._handleRequest()

    def _handleRequest(self):
        """
        Check if we can handle the request and send response.

        Sends a 404 error when no response matches the request.
        """
        if not self.first_client:
            # Looks like this is the first request so we save the client
            # address to compare it later.
            self._debug('First: %r' % (self.client_address,))
            self.__class__.first_client = self.client_address

        response = self._matchResponse()
        if response:
            self._debug(response)
            self._sendResponse(response)
            self._debug(
                'After: Close-connection: %s' % (self.close_connection,))
            return

        self.send_error(404)

    def _matchResponse(self):
        """
        Return the ResponseDefinition for the current request, or None.

        A response matches on path and method. For requests other than
        GET, the request body must also be equal to `response.request`.
        """
        content = None
        for response in self.__class__.valid_responses:
            if self.path != response.url or self.command != response.method:
                self._debug()
                continue

            # For POST request we read content.
            if self.command != 'GET':
                if content is None:
                    # The body can only be read once from the socket, so
                    # keep it for the next candidate responses.
                    length = int(self.headers.getheader('content-length', 0))
                    content = self.rfile.read(length)
                if content != response.request:
                    self._debug(content)
                    continue

            # We have a match.
            return response

        return None

    def _debug(self, message=''):
        """
        Print to stdout a debug message.
        """
        if not self.debug:
            return

        print('\nGot %s:%s - %s\n' % (
            self.command, self.path, message))

    def _sendResponse(self, response):
        """
        Send response to client.

        When the connection persistence does not match
        `response.persistent`, only a 400 error is sent.
        """
        connection_header = self.headers.getheader('connection')
        if connection_header:
            connection_header = connection_header.lower()

        if self.protocol_version == 'HTTP/1.1':
            # For HTTP/1.1 connections are persistent by default.
            if not connection_header:
                connection_header = 'keep-alive'
        else:
            # For HTTP/1.0 connections are not persistent by default.
            if not connection_header:
                connection_header = 'close'

        # Each error is a complete response, so stop after sending it.
        # Continuing would write a second response on the same stream.
        if response.persistent is None:
            # Ignore persistent flag.
            pass
        elif response.persistent:
            if connection_header == 'close':
                self.send_error(400, 'Headers do not persist the connection')
                return

            if self.first_client != self.client_address:
                self.send_error(
                    400,
                    'Persistent connection not reused. First %r. Now %r' % (
                        self.first_client, self.client_address))
                return
        elif connection_header == 'keep-alive':
            self.send_error(400, 'Connection was persistent')
            return

        if isinstance(response.response_code, Exception):
            # Just close the connection without a valid HTTP response.
            self.wfile.write(response.response_code.message)
            self.close_connection = 1
            return

        self.send_response(
            response.response_code, response.response_message)
        self.send_header("Content-Type", response.content_type)
        if response.response_length:
            self.send_header("Content-Length", response.response_length)
        self.end_headers()
        self.wfile.write(response.test_response_content)

        if not response.response_persistent:
            # Force closing the connection as requested
            # by response.
            self.close_connection = 1
class ResponseDefinition(object):
    """
    A class encapsulating the required data for configuring a response
    generated by the HTTPServerContext.

    It contains the following data:

    * url - url that will trigger this response.
    * method - HTTP method that will trigger this response.
    * request - request body that will trigger the response once the url
      and the method are matched.
    * response_content - content of the response. Text is encoded as UTF-8
      and the bytes are stored as `test_response_content`.
    * response_code - HTTP code of the response or an Exception().
      If the Exception has any text, it is written to the response as raw
      data.
    * response_message - Message sent together with HTTP code.
    * content_type - Content type of the HTTP response.
    * response_length - Length of the response body content, stored as a
      string.
      `None` to calculate automatically the length.
      `` (empty string) to ignore content-length header.
    * persistent - whether the request should persist the connection.
      Set to None to ignore persistent checking.
    * response_persistent - when false, the connection is closed after the
      response is sent. `None` to use the same value as `persistent`.
    """

    def __init__(
        self, url='', request='', method='GET',
        response_content=b'', response_code=200, response_message=None,
        content_type=b'text/html', response_length=None,
        persistent=True, response_persistent=None,
    ):
        self.url = url
        self.method = method
        self.request = request
        self.test_response_content = self._toBytes(response_content)
        self.response_code = response_code
        self.response_message = response_message
        self.content_type = content_type
        if response_length is None:
            response_length = len(self.test_response_content)
        self.response_length = str(response_length)
        self.persistent = persistent
        if response_persistent is None:
            response_persistent = persistent
        self.response_persistent = response_persistent

    @staticmethod
    def _toBytes(content):
        """
        Return `content` as bytes, encoding text as UTF-8.
        """
        if isinstance(content, bytes):
            return content
        return codecs.encode(content, 'utf-8')

    def __repr__(self):
        return 'ResponseDefinition:%s:%s:%s %s:pers-%s' % (
            self.url,
            self.method,
            self.response_code, self.response_message,
            self.persistent,
        )

    def updateResponseContent(self, content):
        """
        Replace the content sent by the server for this response.

        Text is encoded as UTF-8. `response_length` is always recalculated
        from the new content, even if it was previously set to an explicit
        value or to `` (empty string).
        """
        self.test_response_content = self._toBytes(content)
        self.response_length = str(len(self.test_response_content))
chevah-compat==0.45.2
twisted
nose
bunch
scandir
remote_pdb
unidecode
mock
ld
"""
Tests demonstrating a strange behaviour of SMTPConnectError.
To try them:
$ virtualenv build
$ . build/bin/activate
$ pip install -r requirements.txt
$ nosetests test.py
OR
$ trial --force-gc ./test.py
"""
from __future__ import absolute_import, unicode_literals
from contextlib import contextmanager
import os
import sys
# This is so that we can run as trial ./test.py
sys.path.append(os.getcwd())
from bunch import Bunch
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectionClosed
from twisted.internet.protocol import ServerFactory, Protocol
from twisted.web.client import (
Agent,
HTTPConnectionPool,
ResponseNeverReceived,
)
from chevah.compat.testing import ChevahTestCase
from esmtp import ESMTPClient
from http import HTTPServerContext, ResponseDefinition
def serverFactoryForProtocol(protocol_class):
    """
    Build a `ServerFactory` which creates `protocol_class` instances.

    The returned factory also carries extra book-keeping attributes:
    `protocol_instance` (initially None), `protocol_instances` (initially an
    empty list) and `protocol_open_deferred` (a new, not yet fired Deferred).
    """
    server_factory = ServerFactory()
    server_factory.protocol = protocol_class
    server_factory.protocol_open_deferred = defer.Deferred()
    server_factory.protocol_instances = []
    server_factory.protocol_instance = None
    return server_factory
class RejectPeerProtocol(Protocol):
    """
    Protocol which drops every incoming connection as soon as it is
    established, without sending any data to the peer.
    """

    def connectionMade(self):
        # Close immediately; this protocol never writes anything.
        transport = self.transport
        transport.loseConnection()
class TestCase(ChevahTestCase):
    """
    There are 3 tests here. If the SMTP test is executed first, the HTTP
    tests will fail as they get the SMTP error.
    If the HTTP tests are executed first, all is fine.
    """

    @contextmanager
    def listenTCP(self, protocol_class, address='127.0.0.1', port=0):
        """
        Return a context which opens a TCP/IP listening socket.

        The context value is a `Bunch` with:
        * port - port on which the socket is listening.
        * address - address on which the socket is listening.
        * connector - the listening port object returned by the reactor.
        * factory - the server factory handling accepted connections.

        At the context exit, the connections recorded in
        `factory.protocol_instances` are closed, then the listening socket
        is stopped and its deferred result is waited for.
        """
        factory = serverFactoryForProtocol(protocol_class)
        # Listen before entering `try`: if listening fails, the error is
        # raised unchanged instead of being hidden by the cleanup code.
        listening_port = reactor.listenTCP(
            interface=address,
            port=port,
            factory=factory,
        )
        try:
            local_peer = listening_port.getHost()
            context = Bunch(
                port=local_peer.port,
                address=local_peer.host,
                connector=listening_port,
                factory=factory,
            )
            yield context
        finally:
            # NOTE(review): nothing in this file appends to
            # `protocol_instances`; its items are assumed to provide
            # `loseConnection()` - confirm before relying on this.
            for connection in factory.protocol_instances:
                connection.loseConnection()
            # Wait for the listening port to close.
            deferred = listening_port.stopListening()
            self.getDeferredResult(deferred)

    def getAgent(self):
        """
        Return a new `Agent` using a persistent `HTTPConnectionPool`.

        The agent is closed at test cleanup via `closeAgent`.
        """
        agent = Agent(
            reactor=reactor,
            pool=HTTPConnectionPool(reactor=reactor, persistent=True),
        )
        self.addCleanup(self.closeAgent, agent)
        return agent

    def closeAgent(self, agent):
        """
        Close the cached connections of the pool used by `agent`, waiting
        for the result with `getDeferredResult`.

        `executeReactor()` is called even if closing the connections fails.
        """
        try:
            deferred = agent._pool.closeCachedConnections()
            self.getDeferredResult(deferred)
        finally:
            self.executeReactor()

    def test_1_send_server_rejection(self):
        """
        This will trigger the SMTPConnectError.
        """
        with self.listenTCP(protocol_class=RejectPeerProtocol) as server:
            sut = ESMTPClient(
                host=server.address, port=server.port, retries=0)
            deferred = sut.send(
                sender='sender@ingore.com',
                recipients=['dst@ignore.com'],
                subject='[SUBJ] Ignored',
                body='some content ignored',
            )
            failure = self.getDeferredFailure(deferred, prevent_stop=True)
            self.assertStartsWith(
                '-1 Unable to connect to server.',
                failure.value.message
            )

    def _checkAgentWrapsConnectionClosed(self):
        """
        Request a URL from a server which closes the connection without
        sending an HTTP response.

        Check that the agent fails with `ResponseNeverReceived` wrapping a
        `ConnectionClosed` failure.
        """
        sut = self.getAgent()
        response = ResponseDefinition(
            method=b'GET',
            url='/endpoint',
            request=b'',
            # We don't write anything on the response so that the agent
            # never receives a response.
            response_code=Exception(b''),
            persistent=True,
        )
        with HTTPServerContext([response]) as server:
            url = 'http://%s:%s/endpoint' % (server.ip, server.port)
            deferred = sut.request(method=b'GET', uri=url.encode('utf-8'))
            failure = self.getDeferredFailure(deferred, prevent_stop=True)
            self.assertIsInstance(ResponseNeverReceived, failure.value)
            # `reasons` is the list of wrapped failures. It is the same list
            # as the exception's first argument, but unlike indexing the
            # exception itself it also works on Python 3.
            inner_failure = failure.value.reasons[0]
            self.assertIsInstance(ConnectionClosed, inner_failure.value)

    def test_2_makeRequest_agent_wrap(self):
        """
        This is a test in which SMTPConnectError is a side effect.

        HTTP Agent will wrap the error in a ResponseNeverReceived with the
        original reason passed as a list.
        The failure is unwrapped and the request will get the direct failure.
        """
        self._checkAgentWrapsConnectionClosed()

    def test_3_makeRequest_agent_wrap(self):
        """
        This is another test in which SMTPConnectError is a side effect.

        HTTP Agent will wrap the error in a ResponseNeverReceived with the
        original reason passed as a list.
        The failure is unwrapped and the request will get the direct failure.
        """
        self._checkAgentWrapsConnectionClosed()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment