Skip to content

Instantly share code, notes, and snippets.

@GaretJax
Last active October 25, 2015 11:43
Show Gist options
  • Save GaretJax/124c523a62ba48c9eec1 to your computer and use it in GitHub Desktop.
Save GaretJax/124c523a62ba48c9eec1 to your computer and use it in GitHub Desktop.
from OpenSSL import SSL as ssl
from zope.interface import implementer
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.internet import defer
from twisted.internet.ssl import CertificateOptions
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.logger import Logger
from .utils import CachingDict
class DummyTransport(object):
    """
    Stand-in transport whose operations do nothing.

    Both ``write`` and ``loseConnection`` accept the call and discard it.
    """

    def write(self, bytes):
        # Outgoing data is intentionally dropped.
        return None

    def loseConnection(self):
        # There is nothing to close.
        return None
class TLSServerNameCallbackHelper(TLSMemoryBIOProtocol, object):
    """
    Fake TLSMemoryBIOProtocol to be used until the client hello is received
    and the SNI callback can be triggered by the underlying SSL implementation.

    Every chunk passed to ``dataReceived`` is recorded. Once the outcome of
    the server name lookup is known, the recorded data is handed to the SNI
    handler through ``gotContext(bytes, context)`` or
    ``gotError(bytes, failure)``.
    """

    def __init__(self, sniHandler, *args, **kwargs):
        """
        @param sniHandler: object providing ``gotContext(bytes, context)``
            and ``gotError(bytes, failure)``.

        All remaining arguments are forwarded unchanged to
        ``TLSMemoryBIOProtocol.__init__``.
        """
        super(TLSServerNameCallbackHelper, self).__init__(*args, **kwargs)
        self._receivedBytes = []
        self._sniHandlerCalled = False
        self._sniHandler = sniHandler
        # This helper works on a throwaway connection and a transport that
        # discards everything written to it.
        self._tlsConnection = self._buildDummyConnection()
        self.transport = DummyTransport()

    def _buildDummyConnection(self):
        """
        Create a server-side connection on a fresh default context with the
        servername and info callbacks of this helper installed.
        """
        context = CertificateOptions().getContext()
        context.set_tlsext_servername_callback(self._executeServernameCallback)
        context.set_info_callback(self._handover)
        connection = ssl.Connection(context, None)
        connection.set_accept_state()
        return connection

    def _bufferedBytes(self):
        """
        Return everything received so far as a single byte string.
        """
        # The chunks are byte strings, so the separator has to be one too;
        # a text separator raises TypeError on Python 3.
        return b''.join(self._receivedBytes)

    def _handover(self, connection, where, ret):
        """
        Info callback: on the first ``SSL_CB_EXIT`` event seen before the
        servername callback ran, hand over without a replacement context.
        """
        # NOTE(review): presumably this covers clients that send no server
        # name at all -- confirm against the OpenSSL info callback semantics.
        if self._sniHandlerCalled:
            return
        if where & ssl.SSL_CB_EXIT:
            self._sniHandlerCalled = True
            self._gotContext(None, connection)

    def _gotContext(self, context, connection):
        """
        Shut the dummy connection down and pass the recorded bytes and the
        context (possibly ``None``) to the SNI handler.
        """
        connection.shutdown()
        self._sniHandler.gotContext(self._bufferedBytes(), context)

    def _gotError(self, failure, connection):
        """
        Shut the dummy connection down and pass the recorded bytes and the
        failure to the SNI handler.
        """
        connection.shutdown()
        self._sniHandler.gotError(self._bufferedBytes(), failure)

    def _executeServernameCallback(self, connection):
        """
        Servername callback: ask the factory's connection creator for a
        context and route the (possibly deferred) result to the handler.
        """
        assert not self._sniHandlerCalled
        self._sniHandlerCalled = True
        # maybeDeferred lets serverContextForSNI return a plain value, return
        # a Deferred, or raise.
        d = defer.maybeDeferred(
            self.factory._connectionCreator.serverContextForSNI, connection)
        d.addCallback(self._gotContext, connection)
        d.addErrback(self._gotError, connection)

    def dataReceived(self, bytes):
        """
        Record the chunk for the later replay, then process it normally.
        """
        self._receivedBytes.append(bytes)
        super(TLSServerNameCallbackHelper, self).dataReceived(bytes)
class SNIEnabledTLSMemoryBIOProtocol(TLSMemoryBIOProtocol, object):
    """
    TLSMemoryBIOProtocol first sending the client hello to a helper to trigger
    the SNI callback returning a deferred, waiting for it to callback and then
    replaying the client hello on the real implementation, for which we already
    have a context.
    """

    log = Logger()

    def _replayHandshake(self, bytes):
        """
        Feed the bytes recorded by the helper into the real connection.
        """
        # Restore the original dataReceived method and replay the received
        # bytes on the original connection.
        self.dataReceived = self._originalDataReceived
        self.dataReceived(bytes)

    def gotContext(self, bytes, context):
        """
        Called by the helper once the server name lookup succeeded.

        @param bytes: everything received from the client so far.
        @param context: the context to install on the real connection, or
            ``None`` to keep the current one.
        """
        if context is not None:
            self.getHandle().set_context(context)
        self._replayHandshake(bytes)

    def gotError(self, bytes, failure):
        """
        Called by the helper when the server name lookup failed: log the
        failure and drop the connection. The recorded bytes are not used.
        """
        from twisted.logger import LogLevel

        # Logger.failure attaches the Failure to the event, so the reason and
        # traceback are reported. A plain ``failure=`` keyword on log.error is
        # never rendered because the format string does not reference it.
        self.log.failure(
            'failed to build context', failure, level=LogLevel.error)
        self.loseConnection()

    def makeConnection(self, transport):
        """
        Route incoming data through a handshake helper until the helper
        reports back through ``gotContext`` or ``gotError``.
        """
        # Hook up the dataReceived method from the handshake helper until
        # the client hello is received, the SSL implementation parsed the
        # SNI extension, and the deferred returned by the SNI calls back.
        handshakeHelper = TLSServerNameCallbackHelper(
            self, self.factory, self.wrappedProtocol, self._connectWrapped)
        self._originalDataReceived = self.dataReceived
        self.dataReceived = handshakeHelper.dataReceived
        super(SNIEnabledTLSMemoryBIOProtocol, self).makeConnection(transport)
class SNIEnabledTLSMemoryBIOFactory(TLSMemoryBIOFactory):
    """
    TLSMemoryBIOFactory whose protocol class is
    SNIEnabledTLSMemoryBIOProtocol.
    """

    protocol = SNIEnabledTLSMemoryBIOProtocol
class ISNIEnabledConnectionCreator(IOpenSSLServerConnectionCreator):
    """
    Server connection creator that can additionally choose the context of a
    connection once the server name indication is known.
    """

    # Interface methods are declared without ``self``: the signature
    # describes how the method is called on a provider.
    def serverContextForSNI(connection):
        """
        Called when the server name indication is received by the server
        (tlsext_servername_callback of pyOpenSSL).

        This method can return `None` to not alter the connection context,
        a new context instance to be used for the connection, or a deferred
        with any of the previous two return values.

        The returned context will be set as the context of the connection.
        Any context set directly on the `connection` argument (i.e. by using
        `Connection.set_context`) will be lost.
        """
class SNIEnabledTLSEndpoint(object):
    """
    TLS endpoint with support for returning deferreds from the server name
    indication callback.
    """

    def __init__(self, endpoint, contextFactory):
        """
        @param endpoint: the endpoint to listen on; its ``listen`` method is
            called with the TLS wrapping factory.
        @param contextFactory: a provider of ISNIEnabledConnectionCreator.

        @raise TypeError: if ``contextFactory`` does not provide
            ISNIEnabledConnectionCreator.
        """
        # Raised explicitly instead of asserted so the check also runs when
        # Python is started with -O.
        if not ISNIEnabledConnectionCreator.providedBy(contextFactory):
            raise TypeError(
                'contextFactory must provide ISNIEnabledConnectionCreator, '
                'got {!r}'.format(contextFactory))
        self.endpoint = endpoint
        self.contextFactory = contextFactory

    def listen(self, factory):
        """
        Wrap ``factory`` in an SNIEnabledTLSMemoryBIOFactory and listen with
        it on the wrapped endpoint, returning whatever that endpoint returns.
        """
        # The second argument is TLSMemoryBIOFactory's isClient flag: False
        # selects server mode.
        return self.endpoint.listen(SNIEnabledTLSMemoryBIOFactory(
            self.contextFactory, False, factory
        ))
# Example usage
from zope.interface import implementer
from twisted.internet import reactor, endpoints
from twisted.web import static, server
@implementer(ISNIEnabledConnectionCreator)
class SNICallbackSSLFactory(object):
    """
    Example connection creator that loads the key and certificate for the
    requested server name from ``certs/<hostname>/`` after a one second delay.
    """

    # Characters accepted in a server name before it is used in a file path.
    _HOSTNAME_CHARS = frozenset(
        'abcdefghijklmnopqrstuvwxyz'
        'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        '0123456789-_.')

    def __init__(self, certificate_options):
        """
        @param certificate_options: object whose ``getContext()`` builds the
            contexts handed out by this factory.
        """
        self.certificate_options = certificate_options

    def _makeContext(self):
        """
        Return a freshly built context instead of the cached one.
        """
        # NOTE/TODO: Somehow the connections are picky about sharing contexts
        # between them. This might not be an issue when different connection
        # instances are created for the same session, but it is here because
        # we would reuse the same context for connections initialized with
        # exactly the same client hello.
        self.certificate_options._context = None
        return self.certificate_options.getContext()

    def _safeHostname(self, hostname):
        """
        Return ``hostname`` as text if it is safe to use as a directory name.

        @raise ValueError: if the name is not ASCII, has an empty label
            (which covers ``''``, ``'.'`` and ``'..'``) or contains a
            character outside ``_HOSTNAME_CHARS`` (such as a path separator).
        """
        if isinstance(hostname, bytes):
            try:
                hostname = hostname.decode('ascii')
            except UnicodeDecodeError:
                raise ValueError('server name is not ASCII')
        labels = hostname.split('.')
        if not all(labels) or not set(hostname) <= self._HOSTNAME_CHARS:
            raise ValueError('invalid server name: {!r}'.format(hostname))
        return hostname

    def serverContextForSNI(self, connection):
        """
        Return a deferred firing with a context carrying the key and
        certificate of the requested server name.

        Returns ``None`` (context left unchanged) when no server name was
        indicated, and a failed deferred when the name is not acceptable or
        the key material cannot be loaded.
        """
        hostname = connection.get_servername()
        if hostname is None:
            return None
        try:
            # The name comes straight from the client and ends up in a file
            # path, so it has to be validated first.
            hostname = self._safeHostname(hostname)
        except ValueError:
            return defer.fail()

        def build(d):
            try:
                context = self._makeContext()
                context.use_privatekey_file(
                    'certs/{}/key.pem'.format(hostname))
                context.use_certificate_file(
                    'certs/{}/cert.pem'.format(hostname))
            except Exception:
                # Without this the deferred would never fire and the
                # connection would hang waiting for its context.
                d.errback()
            else:
                d.callback(context)

        d = defer.Deferred()
        reactor.callLater(1, build, d)
        return d

    def serverConnectionForTLS(self, tlsProtocol):
        """
        Return a new connection on a freshly built context.
        """
        return ssl.Connection(self._makeContext(), None)
# Serve a static plain-text page through the SNI-enabled TLS endpoint on
# TCP port 443. The response body is a byte string: static.Data writes it to
# the request as-is, which requires bytes on Python 3.
server_factory = server.Site(static.Data(b'Hello world!', 'text/plain'))
ssl_context_factory = SNICallbackSSLFactory(CertificateOptions())
tcp_endpoint = endpoints.TCP4ServerEndpoint(reactor, 443)
tls_endpoint = SNIEnabledTLSEndpoint(tcp_endpoint, ssl_context_factory)
tls_endpoint.listen(server_factory)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment