Skip to content

Instantly share code, notes, and snippets.

@adiroiban
Created December 11, 2020 16:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adiroiban/20f938db677f66da5de2c37a8b3a3fd9 to your computer and use it in GitHub Desktop.
Save adiroiban/20f938db677f66da5de2c37a8b3a3fd9 to your computer and use it in GitHub Desktop.
Some Twisted protocol helpers for tests
# Copyright (c) 2012 Adi Roiban.
# See LICENSE for details.
"""
Protocol to help with tests.
This comes in addition to standard twisted.test.proto_helpers
"""
from io import BytesIO
from StringIO import StringIO
from bunch import Bunch
from mock import patch
from OpenSSL import SSL
from twisted.internet import address, defer, protocol
from twisted.internet.abstract import _ConsumerMixin
from twisted.internet.error import ConnectionAborted, ConnectError
from twisted.internet.protocol import ServerFactory, Protocol
from twisted.internet.task import Clock
from twisted.internet.tcp import Connector, Port
from twisted.internet.ssl import (
DefaultOpenSSLContextFactory,
ClientContextFactory,
)
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.internet._newtls import (
ConnectionMixin,
ClientMixin,
ServerMixin,
startTLS,
)
from twisted.protocols import basic, loopback
from twisted.python.failure import Failure
from twisted.test.proto_helpers import (
StringTransportWithDisconnection as
TwistedStringTransportWithDisconnection,
)
class AuditingProtocol(Protocol, object):
    """
    Protocol which keeps a log of every connection event it sees.

    `connection_made`, `data_received` and `connection_lost` are lists
    which grow with each corresponding event. `lost_deferred` is fired
    with the text of the disconnection reason.
    """

    def __init__(self):
        self.connection_made = []
        self.connection_lost = []
        self.data_received = []
        # Fired with the reason text once the connection goes away.
        self.lost_deferred = defer.Deferred()

    def connectionMade(self):
        """
        See: Protocol.

        Record that the connection was established.
        """
        self.connection_made.append(True)

    def dataReceived(self, line):
        """
        See: Protocol.

        Record the received chunk as is.
        """
        self.data_received.append(line)

    def connectionLost(self, reason):
        """
        See: Protocol.

        Record the reason as text and fire `lost_deferred` with it.
        """
        reason_text = str(reason.value)
        self.connection_lost.append(reason_text)
        self.lost_deferred.callback(reason_text)
class EchoProtocol(Protocol):
    """
    Protocol which sends back everything it receives.

    All received chunks are kept, in arrival order, in `value`.
    """

    def __init__(self):
        self.value = []

    def dataReceived(self, data):
        """
        Store `data` and write it back on the same transport.
        """
        received = self.value
        received.append(data)
        self.transport.write(data)
class SSLInspectorProtocol(Protocol):
    """
    Force the handshake, get the remote certificate / chain and other SSL/TLS
    related info, and then close the connection.

    Call `triggerHandshake` once connected and wait on the returned
    deferred. It is fired with None after `peer_certificate` and
    `peer_chain` are set. It is errbacked with the disconnection reason
    when the connection is lost before any data is received.
    """

    # The peer certificate and chain as obtained after the handshake.
    peer_certificate = None
    peer_chain = None
    # Deferred which is fired when we got the peer details.
    # Stays None until triggerHandshake is called.
    _handshake_done = None

    def triggerHandshake(self):
        """
        Send some data to trigger the handshake finalization.

        Return a deferred which is fired when the peer details are
        available.
        """
        # Create the deferred before writing, so that a reply (or a
        # disconnection) delivered synchronously by the transport is not
        # lost.
        self._handshake_done = defer.Deferred()
        self.transport.write('Some data\r\n')
        return self._handshake_done

    def _isWaiting(self):
        """
        Return True when a handshake deferred exists and was not yet fired.
        """
        return (
            self._handshake_done is not None and
            not self._handshake_done.called
            )

    def connectionLost(self, reason):
        """
        Called when the connection is closed.

        Report `reason` as a handshake failure. Do nothing when the peer
        details were already reported or no handshake was requested.
        """
        if not self._isWaiting():
            return
        self._handshake_done.errback(reason)

    def dataReceived(self, data):
        """
        Called when data was received.

        By this time we should have a finalized handshake.
        """
        self.peer_certificate = self.transport.getPeerCertificate()
        self.peer_chain = self.transport._tlsConnection.get_peer_cert_chain()
        # Close the connection once we are done.
        self.transport.loseConnection()
        # Data can arrive more than once. A transport might also report the
        # connection lost synchronously from loseConnection. Only fire once.
        if self._isWaiting():
            self._handshake_done.callback(None)
class AccumulatingLineProtocol(basic.LineReceiver):
    """
    Store each received line and fire deferreds on connection and
    disconnection.

    * factory.protocol_open_deferred, when the factory has one which was
      not fired yet, is fired with the protocol on connection made.
    * protocol.close_deferred, when set, is fired on connection close.
    * protocol.line_received_deferred is fired with each new line. A new
      deferred is put in place before firing, so callbacks which read the
      attribute again will wait for the next line.

    The protocol does not register itself with the factory. To get hold of
    the instance, use a factory which tracks the protocols it builds, like
    ConnectionTrackingServerFactory.
    """

    # Protocol connection indicators.
    connection_made = False
    connection_close = False
    # The reason why connection was closed.
    close_reason = None
    # An optional deferred called on connection done.
    close_deferred = None
    line_received_deferred = None
    # Remote peer as available at connection time.
    peer = None
    # Lines received so far.
    lines = None
    factory = None

    def connectionMade(self):
        """
        See: Protocol.

        All connection state is set up before notifying the factory, so
        that callbacks of `protocol_open_deferred` can already use `peer`,
        `lines` and `line_received_deferred`.
        """
        self.lines = []
        self.connection_made = True
        self.line_received_deferred = defer.Deferred()
        self.peer = self.transport.getPeer()

        open_deferred = getattr(self.factory, 'protocol_open_deferred', None)
        # The factory deferred is one-shot. Only the first connection fires
        # it, as firing it again would raise AlreadyCalledError.
        if open_deferred is not None and not open_deferred.called:
            open_deferred.callback(self)

    def lineReceived(self, line):
        """
        See: LineReceiver.

        Store the line and fire the current `line_received_deferred`.
        """
        self.lines.append(line)
        # Swap in the new deferred before firing. Code resumed synchronously
        # by the callback then waits for the next line instead of getting
        # this one again.
        fired, self.line_received_deferred = (
            self.line_received_deferred, defer.Deferred())
        fired.callback(line)

    def connectionLost(self, reason):
        """
        See: Protocol.

        Record the close reason and fire `close_deferred`, if set.
        """
        self.connection_close = True
        self.close_reason = reason
        close_deferred, self.close_deferred = self.close_deferred, None
        if close_deferred is not None:
            close_deferred.callback(None)
class AccumulatingDatagramServerProtocol(protocol.DatagramProtocol):
    """
    A datagram protocol used for accumulating all received data as a server.

    It has a set of deferreds which can be used for waiting for various
    connection events:

    * start_deferred - fired with None when the protocol is started.
    * stop_deferred - fired with None when the protocol is stopped.
    * received_deferred - fired with the first received datagram.

    Each deferred is fired at most once. A protocol which is started or
    stopped again, or which receives more than one datagram, therefore
    does not fail with AlreadyCalledError. All datagrams are still
    accumulated in `received_data`.
    """

    def __init__(self):
        self.start_deferred = defer.Deferred()
        self.stop_deferred = defer.Deferred()
        self.received_deferred = defer.Deferred()
        self.started = False
        self.stopped = False
        # Address of the sender of the latest datagram.
        # False until a datagram is received.
        self.client_address = False
        self.received_data = []

    def stopProtocol(self):
        """
        See: DatagramProtocol.
        """
        self.stopped = True
        if not self.stop_deferred.called:
            self.stop_deferred.callback(None)

    def startProtocol(self):
        """
        See: DatagramProtocol.
        """
        self.started = True
        if not self.start_deferred.called:
            self.start_deferred.callback(None)

    def datagramReceived(self, data, addr):
        """
        See: DatagramProtocol.
        """
        self.client_address = addr
        self.received_data.append(data)
        if not self.received_deferred.called:
            self.received_deferred.callback(data)
class StringTransportWithDisconnection(
        TwistedStringTransportWithDisconnection, object):
    """
    String based transport which also supports aborting the connection.
    """

    def abortConnection(self):
        """
        Close the transport, reporting a ConnectionAborted reason.
        """
        # FIXME:1370:
        # Check if fix is included in latest Twisted release and remove this
        # patch.
        # https://twistedmatrix.com/trac/ticket/8161
        aborted = ConnectionAborted()
        return self._closeConnection(aborted)

    def _closeConnection(self, reason):
        """
        Mark the transport as disconnected and notify the protocol with
        `reason` wrapped in a Failure.

        Does nothing when the transport is already disconnected.
        """
        # The Twisted base class has no such helper, so it is provided here.
        if self.connected:  # noqa:cover
            self.connected = False
            self.protocol.connectionLost(Failure(reason))
class StringTLSTransport(StringTransportWithDisconnection):
    """
    FIXME:3600:
    DEPRECATED: String transport with TLS.

    Reports a fake peer certificate, and exposes an SSL connection created
    from the context given to `startTLS`.
    """

    context = None

    def __init__(self, certificate=None):
        super(StringTLSTransport, self).__init__()
        self.TLS = False
        self.test_peer_certificate = certificate
        # The transport also exposes itself as `transport`; presumably for
        # code which unwraps wrapped transports - TODO confirm.
        self.transport = self

    def startTLS(self, context):
        """
        Enter TLS mode, using an SSL connection built from `context`.
        """
        self.context = context
        self.TLS = True
        connection = SSL.Connection(context.getContext(), None)
        self._tlsConnection = connection
        self.protocol._tlsConnection = connection

    def stopTLS(self):
        """
        Leave TLS mode and detach the SSL connection from the protocol.
        """
        self.TLS = False
        self.protocol._tlsConnection = None

    def getPeerCertificate(self):
        """
        Return the certificate given at initialization.
        """
        return self.test_peer_certificate
class _StringSTARTTLSTransport(
        StringTransportWithDisconnection, ConnectionMixin):
    """
    String transport with TLS start/stop capabilities for both
    client and server side.

    Clear text writes go to `self.io`. Writes made while in TLS mode are
    recorded separately and are available via `tls_value()`.
    """

    def __init__(self, certificate=None):
        super(_StringSTARTTLSTransport, self).__init__()
        # A reference to the last context used.
        self.context = None
        # Data written over TLS.
        # Stays None until the first write in TLS mode, which is dropped.
        self._tls_io = None
        # Fake the peer certificate.
        self.test_peer_certificate = certificate
        # Reference to the last TLS protected protocol.
        self._tls_protocol = None

    def startTLS(self, context, normal=True):
        """
        Switch the transport from clear to secure mode.

        The switch itself is done by Twisted's private `startTLS` helper,
        with this class passed as its last (`bypass`) argument. The TLS
        connection object is then patched so that shutdown and peer
        certificate lookups work without a real remote peer.

        @param context: Context factory for the TLS connection.
        @param normal: Passed as-is to Twisted's `startTLS`.

        @raise AssertionError: When TLS was already started.
        """
        if self.context is not None:
            raise AssertionError('SSL/TLS already started.')
        self.context = context
        startTLS(self, context, normal, _StringSTARTTLSTransport)
        # Keep a copy of the tls protocol so that we can fake its
        # shutdown after stop tls.
        self._tls_protocol = self.protocol
        # We don't have a real peer, so shutdown will always fail.
        # Here we pretend that all is ok.
        try:
            self.protocol._tlsConnection.shutdown = lambda: None
            self.protocol._tlsConnection.get_peer_certificate = (
                lambda: self.test_peer_certificate)
        except AttributeError:
            # On PyOpenSSL 0.13 OpenSSL.SSL.Connection is a C object so we
            # do a more aggressive mocking.
            # The Twisted API is using both bio_* and non bio version in both
            # client and server side operations.
            # The whole connection object is replaced with a Bunch which
            # only provides the methods listed below.
            def recv(length):
                """
                Called when consumer want data from the connection.

                Always raises SSL.ZeroReturnError; `length` is ignored.
                """
                # Just signal that shutdown is complete.
                raise SSL.ZeroReturnError()

            self.protocol._tlsConnection = Bunch(
                shutdown=lambda: None,
                bio_shutdown=lambda: None,
                recv=recv,
                bio_read=lambda length: b'',
                get_peer_certificate=lambda: self.test_peer_certificate,
            )

    def finalizeTLSShutdown(self):
        """
        Fake the finalization of TLS shutdown as received from the
        remote peer.

        Must be called after `startTLS`, as it uses the TLS protocol saved
        there. It relies on the private `_tlsShutdownDeferred` attribute of
        that protocol.
        """
        self._tls_protocol._tlsShutdownDeferred.callback(None)

    def tls_clear(self):
        """
        Clear the data sent over TLS.

        Does nothing before the first write in TLS mode.
        """
        if not self._tls_io:
            return
        self._tls_io = BytesIO()

    def tls_value(self):
        """
        Return the clear text data as it would have been written
        over a TLS/SSL protected channel.

        Return b'' when nothing was recorded yet.
        """
        if not self._tls_io:
            return b''
        return self._tls_io.getvalue()

    def write(self, bytes):
        """
        Write the bytes.

        Outside TLS mode, data goes to `self.io`. In TLS mode, the first
        write is dropped as it is assumed to be the handshake. Later writes
        are recorded for `tls_value`.
        """
        if self.TLS:
            if self._tls_io is None:
                # First time, we write the handshake, but we ignore it
                # for the purpose of the test as we only care about the
                # payload data.
                self._tls_io = BytesIO()
            else:
                self._tls_io.write(bytes)
        else:
            self.io.write(bytes)

    def getPeerCertificate(self):
        """
        Return the certificate of the remote peer.

        This is the fake certificate given at initialization.
        """
        return self.test_peer_certificate
class StringSTARTTLSClientTransport(_StringSTARTTLSTransport, ClientMixin):
    """
    STARTTLS capable string transport for the client side of a
    connection.
    """
class StringSTARTTLSServerTransport(_StringSTARTTLSTransport, ServerMixin):
    """
    STARTTLS capable string transport for the server side of a
    connection.
    """
def _start_tls(klass, context_factory, protocol, certificate=None):
    """
    Return a `klass` transport which is already in TLS mode.

    The STARTTLS code path is reused to wrap `protocol`. The transport is
    then rigged to look like one which was not started with STARTTLS: the
    TLS connection is exposed on the transport and the protocol reference
    is dropped.
    """
    transport = klass(certificate=certificate)
    transport.protocol = protocol
    transport.startTLS(context_factory)

    tls_connection = transport.protocol._tlsConnection
    transport._tlsConnection = tls_connection
    transport.protocol = None
    return transport
def StringTLSClientTransport(context_factory, protocol, certificate=None):
    """
    A transport as used by the client-side connection which is already
    secured by TLS, without using STARTTLS.
    """
    return _start_tls(
        StringSTARTTLSClientTransport,
        context_factory,
        protocol,
        certificate=certificate,
    )
def StringTLSServerTransport(context_factory, protocol, certificate=None):
    """
    A transport as used by the server-side connection which is already
    secured by TLS, without using STARTTLS.
    """
    return _start_tls(
        StringSTARTTLSServerTransport,
        context_factory,
        protocol,
        certificate=certificate,
    )
class InMemoryConsumer(_ConsumerMixin):
    """
    Consumer which accumulates everything written to it in memory.

    A producer registered with it is resumed immediately.
    """

    connected = True
    disconnecting = False
    disconnected = False

    def __init__(self, data=None):
        # Any object with `write` and `getvalue` can be used as storage.
        self._data = StringIO() if data is None else data

    def registerProducer(self, producer, streaming):
        """
        Register `producer` via the mixin and then resume it right away.
        """
        outcome = super(InMemoryConsumer, self).registerProducer(
            producer, streaming)
        producer.resumeProducing()
        return outcome

    def write(self, data):
        """
        Append `data` to the accumulated content.
        """
        self._data.write(data)

    def value(self):
        """
        Return everything written so far.
        """
        return self._data.getvalue()

    @property
    def isConnected(self):
        """
        Mirror the `connected` flag.
        """
        return self.connected
class StreamPullProducer(object):
    """
    A pull producer to `consumer` for the content of `file`.

    Each `resumeProducing` call writes at most CHUNK_SIZE bytes. When the
    end of `file` is reached, or after `stopProducing`, the producer
    unregisters itself from the consumer. On EOF, `deferred` (if set) is
    fired with None; on `stopProducing` it is errbacked.
    """

    # The chunk is big enough so that it will read most data from one call.
    # NOTE(review): 8092 looks like a typo for 8192; kept as is.
    CHUNK_SIZE = 8092

    def __init__(self, consumer, file):
        self.consumer = consumer
        self.file = file
        # Optional deferred, presumably assigned by the caller after
        # creating the producer.
        self.deferred = None

    def resumeProducing(self):
        """
        Write the next chunk to the consumer, or finish on EOF.
        """
        chunk = ''
        if self.file:
            chunk = self.file.read(self.CHUNK_SIZE)

        if not chunk:
            # We are at EOF, or production was stopped.
            self.file = None
            self.consumer.unregisterProducer()
            # Clear the reference before firing, so that a reentrant call
            # from a callback does not fire the same deferred twice.
            deferred, self.deferred = self.deferred, None
            if deferred:
                deferred.callback(None)
            return

        self.consumer.write(chunk)

    def pauseProducing(self):
        """
        Nothing to do for a pull producer.
        """

    def stopProducing(self):
        """
        Stop producing and errback `deferred`, if set.

        Later `resumeProducing` calls no longer write data.
        """
        # Drop the file so that no more data reaches the consumer.
        self.file = None
        deferred, self.deferred = self.deferred, None
        if deferred:
            deferred.errback(
                Exception("Consumer asked us to stop producing"))
class ConnectionTrackingServerFactory(ServerFactory, object):
    """
    Server factory which remembers every protocol it builds.

    `protocol_instance` is the most recently built protocol, while
    `protocol_instances` holds all of them, in build order.
    """

    def __init__(self):
        self.protocol_instances = []
        self.protocol_instance = None
        # Deferred for protocols which report their connection to the
        # factory (AccumulatingLineProtocol fires it on connectionMade).
        self.protocol_open_deferred = defer.Deferred()

    def buildProtocol(self, addr):
        """
        Build the protocol via ServerFactory and keep track of it.
        """
        built = super(ConnectionTrackingServerFactory, self).buildProtocol(
            addr)
        self.protocol_instances.append(built)
        self.protocol_instance = built
        return built
def serverFactoryForProtocol(protocol_class):
    """
    Return a new ConnectionTrackingServerFactory building `protocol_class`
    instances.
    """
    tracking_factory = ConnectionTrackingServerFactory()
    tracking_factory.protocol = protocol_class
    return tracking_factory
# FIXME:1370:
# Our forked version of the loopback transport does not support
# abortConnection, so it is patched in here: an abort is approximated by a
# clean close. Producer pause/resume requests are accepted and ignored.
def _loopbackAbortConnection(self):
    """
    Abort the loopback connection by closing it cleanly.
    """
    return self.loseConnection()


def _loopbackIgnoreProducerCall(self):
    """
    Accept a producer pause/resume request without doing anything.
    """


loopback._LoopbackTransport.abortConnection = _loopbackAbortConnection
loopback._LoopbackTransport.pauseProducing = _loopbackIgnoreProducerCall
loopback._LoopbackTransport.resumeProducing = _loopbackIgnoreProducerCall
class InMemorySocket(Bunch):
    """
    Stand-in for a socket object, which never touches the network.
    """

    def __init__(self, host='127.0.0.1', port=4224):
        self._host, self._port = host, port
        # The port value doubles as the fake file descriptor.
        self.fileno = port

    def close(self):
        """
        See socket.close.

        Nothing to release, so this is a no-op.
        """

    def getsockname(self):
        """
        See socket.getsockname.

        Return the local (host, port) pair given at initialization.
        """
        return self._host, self._port

    def setblocking(self, flag):
        """
        See socket.setblocking.

        Blocking mode makes no difference here, so `flag` is ignored.
        """

    def recv(self, bufsize):
        """
        See socket.recv.

        Always return no data. Data should flow through the high level
        transport and Protocol.dataReceived instead.
        """
        return b''
class InMemoryReactorAbstract(Clock, object):
    """
    A simple reactor which connects without touching the network using
    a client side endpoint.

    It is initialized with a list of (address, port) tuple for which
    connections are allowed.

    For client connection the list of tuple can be a
    (address, port, make_connection) for which when `make_connection` is
    False it will not trigger the connection right away.

    At this point it can not be used for connecting a client and a server
    using the same reactor.

    Subclasses provide the connectTCP / connectSSL implementations.
    """

    def __init__(self, expected_addresses, client_transport_factory=None):
        super(InMemoryReactorAbstract, self).__init__()
        # Delay the import as mk is also importing proto_helpers.
        from chevah.server.testing import mk
        self._mk = mk
        # Copy, as the connect* methods of subclasses consume this list.
        self._expected_addresses = expected_addresses[:]
        self.latest_protocol = None
        # Transport of the latest client connection. Initialized here so
        # that it can be read before any connection is made.
        self.latest_transport = None
        self._clientTransportFactory = client_transport_factory
        # Ports listening for connections, keyed by port number.
        self._ports = {}

    def addReader(self, port):
        """
        Called when we are waiting for incoming connections.
        """
        # Keep a reference, in case we want to initiate a client connection.
        self._ports[port.port] = port

    def removeReader(self, port):
        """
        Called when we are no longer waiting for incoming connections.
        """
        try:
            del self._ports[port.port]
        except (AttributeError, KeyError):
            # Might be a client port, and we don't keep a record of these
            # ports... or it might be a port which is not listening yet.
            pass

    def removeWriter(self, port):
        """
        Called when we should remove a connection from the reactor loop.

        Does nothing as we don't keep a record of writers.
        """

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Do or reject the connection based on reactor configuration.
        """
        raise NotImplementedError('connectTCP not implemented.')

    def connectSSL(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Do or reject the connection based on reactor configuration.
        """
        raise NotImplementedError('connectSSL not implemented.')

    def listenTCP(self, port, factory, backlog=50, interface=''):
        """
        Set up a listening port.

        Code copied from Twisted, with the exception of the fake socket
        injection.
        """
        p = Port(
            port=port,
            factory=factory,
            backlog=backlog,
            interface=interface,
            reactor=self,
            )
        # We inject our own socket to not touch the network.
        p._preexistingSocket = InMemorySocket()
        p.startListening()
        return p

    def triggerConnectionWithClose(
        self,
        server_protocol, server_port,
        client_protocol, client_port=1234,
        server_host='127.0.0.1',
        client_host='1.2.3.4',
    ):
        """
        Connect `server_protocol` and `client_protocol` using in-memory
        loopback transports.

        The server side is at `server_host`:`server_port` and the client
        side at `client_host`:`client_port`. These values are reported by
        getHost / getPeer and by the fake sockets.

        Return a deferred which is fired when the connection is closed.
        """
        # Default loopbackAsync code does not allow injecting the peers into
        # the transport.
        server_to_client = loopback._LoopbackQueue()
        client_to_server = loopback._LoopbackQueue()

        server_address = address.IPv4Address('TCP', server_host, server_port)
        client_address = address.IPv4Address('TCP', client_host, client_port)

        server_transport = loopback._LoopbackTransport(server_to_client)
        server_transport.getPeer = lambda: client_address
        server_transport.getHost = lambda: server_address
        # The fake socket needs the port number, not the host.
        server_transport.socket = InMemorySocket(
            host=server_host, port=server_port)

        client_transport = loopback._LoopbackTransport(client_to_server)
        client_transport.getPeer = lambda: server_address
        client_transport.getHost = lambda: client_address
        client_transport.socket = InMemorySocket(
            host=client_host, port=client_port)

        server_protocol.makeConnection(server_transport)
        client_protocol.makeConnection(client_transport)

        deferred = loopback._loopbackAsyncBody(
            server=server_protocol,
            serverToClient=server_to_client,
            client=client_protocol,
            clientToServer=client_to_server,
            pumpPolicy=loopback.identityPumpPolicy,
            )
        self.latest_protocol = server_protocol
        return deferred

    def triggerClientConnectionWithClose(self, port, client_protocol):
        """
        Rig a client connection to the server listening on `port` and using
        the `client_protocol` to communicate on the client side.

        Will return a deferred which is fired when the connection is closed.

        @raise AssertionError: When no server is listening on `port`.
        """
        try:
            server_port = self._ports[port]
        except KeyError:  # noqa:cover
            raise AssertionError(
                'No server in this reactor is listening to %s.' % (port,))

        server_protocol = server_port.factory.buildProtocol(('1.2.3.4', 1234))
        # NOTE(review): the server side address is reported with port 0
        # rather than `port`. Kept as is since tests might depend on it.
        return self.triggerConnectionWithClose(
            server_protocol=server_protocol,
            server_port=0,
            client_protocol=client_protocol,
            client_port=1234,
            )

    def triggerTLSClientConnectionWithClose(self, port, client_protocol):
        """
        Rig a TLS client connection similar to
        triggerClientConnectionWithClose.

        `client_protocol` is wrapped as the client side of a TLS
        connection.
        """
        tls_factory = TLSMemoryBIOFactory(ClientContextFactory(), True, None)
        tls_protocol = TLSMemoryBIOProtocol(
            tls_factory, client_protocol, _connectWrapped=True)
        return self.triggerClientConnectionWithClose(
            port=port, client_protocol=tls_protocol)
def tls_wrap_server_protocol(protocol):
    """
    Return `protocol` wrapped in a TLSMemoryBIOProtocol, using the
    test_data/pki server certificate and key.

    The SSL context allows the 'ALL' cipher list, combined with the
    OPENSSL_SECLEVEL suffix from chevah.server.testing.
    """
    def get_context(method):
        # Delayed import, to avoid importing chevah.server.testing at
        # module load time.
        from chevah.server.testing import OPENSSL_SECLEVEL
        ssl_context = SSL.Context(method)
        ssl_context.set_cipher_list('ALL' + OPENSSL_SECLEVEL)
        return ssl_context

    pem_path = 'test_data/pki/server-cert-and-key-2048.pem'
    context_factory = DefaultOpenSSLContextFactory(
        privateKeyFileName=pem_path,
        certificateFileName=pem_path,
        _contextFactory=get_context,
        )
    # NOTE(review): isClient=True makes the TLS layer start in client
    # (connect) state even though this wraps a server protocol; listenSSL
    # uses False. Confirm this is intended.
    return TLSMemoryBIOProtocol(
        TLSMemoryBIOFactory(context_factory, True, None),
        protocol,
        _connectWrapped=True,
        )
class InMemoryTCPReactor(InMemoryReactorAbstract):
    """
    A simple reactor which implements only TCP connections.
    """

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Do or reject the connection based on reactor configuration.

        The next expected address is consumed. It is either a
        (host, port) or a (host, port, make_connection) tuple.

        When no connection is expected, or the address does not match,
        `factory.clientConnectionFailed` is called (with this reactor as
        the connector argument) and None is returned.

        Otherwise a protocol is built. When `make_connection` is true (the
        default for 2-element tuples), the protocol is connected right away
        to a string transport, or to one created by the client transport
        factory. A Connector, using a fake socket, is returned.
        """
        try:
            expected_address = self._expected_addresses.pop(0)
        except IndexError:  # noqa:cover
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! Unexpected connection'))
            return

        try:
            expected_host, expected_port, make_connection = expected_address
        except ValueError:
            make_connection = True
            expected_host, expected_port = expected_address

        if (host, port) != (expected_host, expected_port):
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! unknown address/port'))
            return

        protocol = factory.buildProtocol((host, port))
        if make_connection:
            # Use a simple string transport to fake a non-existent
            # server connection.
            if self._clientTransportFactory is None:
                transport = self._mk.makeStringTransportWithDisconnection()
            else:
                transport = self._clientTransportFactory(protocol)
            protocol.makeConnection(transport)
            transport.protocol = protocol
            self.latest_transport = transport
        self.latest_protocol = protocol

        # Make it similar to Twisted code, but use a fake socket.
        connector = Connector(
            host, port, factory, timeout, bindAddress, reactor=self)
        # Rig the socket creation so that we don't touch the network.
        with patch(
            'twisted.internet.tcp.Client.createInternetSocket',
            return_value=InMemorySocket(),
        ):
            connector.connect()
        return connector
class InMemorySTARTTLSReactor(InMemoryReactorAbstract):
    """
    A simple reactor which implements only TCP connections which can be later
    upgraded to TLS/SSL using STARTTLS.
    """

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """
        Do or reject the connection based on reactor configuration.

        The next expected (host, port) tuple is consumed. When no
        connection is expected, or the address does not match,
        `factory.clientConnectionFailed` is called.

        Otherwise the built protocol is connected to a STARTTLS capable
        string transport, or to one created by the client transport
        factory. Always returns None.
        """
        try:
            expected_address = self._expected_addresses.pop(0)
        except IndexError:  # noqa:cover
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! Unexpected connection'))
            return

        if (host, port) != expected_address:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! unknown address/port'))
            return

        protocol = factory.buildProtocol((host, port))
        if self._clientTransportFactory is None:
            # NOTE(review): this is a client connection, but the default
            # transport uses the server side STARTTLS mixin. Confirm this
            # is intended before changing it.
            transport = StringSTARTTLSServerTransport()
        else:
            transport = self._clientTransportFactory(protocol)
        protocol.makeConnection(transport)
        transport.protocol = protocol
        self.latest_protocol = protocol
        self.latest_transport = transport
class InMemorySSLReactor(InMemoryReactorAbstract):
    """
    A simple reactor which connects only with SSL.

    It is initialized with a list of (address, port) tuple for which
    connections are allowed.
    """

    def connectSSL(
        self, host, port, factory, contextFactory, timeout=30,
        bindAddress=None,
    ):
        """
        Do or reject the connection based on reactor configuration.

        The next expected (host, port) tuple is consumed. When no
        connection is expected, or the address does not match,
        `factory.clientConnectionFailed` is called.

        Otherwise the built protocol is connected to a string transport,
        or to one created by the client transport factory. The context
        factory is stored as `latest_context_factory`. Always returns None.
        """
        try:
            expected_address = self._expected_addresses.pop(0)
        except IndexError:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! Unexpected connection'))
            return

        if (host, port) != expected_address:
            factory.clientConnectionFailed(
                self, ConnectError(
                    osError=None, string='!!!FAIL!!! unknown address/port'))
            return

        protocol = factory.buildProtocol((host, port))
        if self._clientTransportFactory is None:
            transport = self._mk.makeStringTransportWithDisconnection()
        else:
            transport = self._clientTransportFactory(protocol)
        protocol.makeConnection(transport)
        transport.protocol = protocol
        self.latest_protocol = protocol
        self.latest_context_factory = contextFactory
        self.latest_transport = transport

    def listenSSL(
            self, port, factory, contextFactory, backlog=50, interface=''):
        """
        Listen for TLS connections on `port`.

        Taken from Twisted code: `factory` is wrapped in a server side
        TLSMemoryBIOFactory, which is then listened on using listenTCP.
        """
        tls_factory = TLSMemoryBIOFactory(contextFactory, False, factory)
        listening_port = self.listenTCP(port, tls_factory, backlog, interface)
        listening_port._type = 'TLS'
        # Key by port number, like addReader, so that the port can be found
        # by triggerClientConnectionWithClose and removed by removeReader.
        # Keying by the port object left an entry which was never removed.
        self._ports[listening_port.port] = listening_port
        return listening_port
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment