Skip to content

Instantly share code, notes, and snippets.

@pitrou
Created October 5, 2014 22:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pitrou/f04fa9cbfec88cc37050 to your computer and use it in GitHub Desktop.
Save pitrou/f04fa9cbfec88cc37050 to your computer and use it in GitHub Desktop.
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -530,14 +530,19 @@ class _SelectorTransport(transports._Flo
try:
self._protocol.connection_lost(exc)
finally:
- self._sock.close()
- self._sock = None
- self._protocol = None
- self._loop = None
- server = self._server
- if server is not None:
- server._detach()
- self._server = None
+ self._finalize_connection()
+
+    def _finalize_connection(self):
+        """Tear down the low-level state of this transport.
+
+        Closes the socket, drops the references to the socket, protocol and
+        event loop, and detaches the transport from the server that owns it
+        (if any).
+        """
+        self._sock.close()
+        self._sock = None
+        self._protocol = None
+        self._loop = None
+        owner = self._server
+        if owner is None:
+            return
+        owner._detach()
+        self._server = None
def get_write_buffer_size(self):
return len(self._buffer)
@@ -667,7 +672,228 @@ class _SelectorSocketTransport(_Selector
return True
-class _SelectorSslTransport(_SelectorTransport):
+
+class SSLPipe(object):
+    """An SSL "Pipe".
+
+    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
+    through memory buffers. It can be used to implement a security layer for
+    an existing connection where you don't have access to the connection's
+    file descriptor, or for some reason you don't want to use it.
+
+    An SSL pipe can be in "wrapped" or "unwrapped" mode. In unwrapped mode,
+    data is passed through untransformed. In wrapped mode, application level
+    data is encrypted to SSL record level data and vice versa. The SSL record
+    level is the lowest level in the SSL protocol suite and is what travels
+    as-is over the wire.
+
+    An SSLPipe initially is in "unwrapped" mode. To start SSL, call
+    :meth:`do_handshake`. To shutdown SSL again, call :meth:`shutdown`.
+    """
+
+    bufsize = 65536
+
+    # The SSL protocol instance is driven through two ssl.MemoryBIO objects
+    # instead of a socketpair (which earlier versions of this class used);
+    # memory BIOs are simpler and more reliable on Windows.
+
+    S_UNWRAPPED, S_DO_HANDSHAKE, S_WRAPPED, S_SHUTDOWN = range(4)
+
+    def __init__(self, context, server_side, server_hostname=None):
+        """
+        The *context* argument specifies the :class:`ssl.SSLContext` to use,
+        for example one created with :func:`ssl.create_default_context`.
+
+        The *server_side* argument indicates whether this is a server side or
+        client side transport.
+
+        The optional *server_hostname* argument can be used to specify the
+        hostname you are connecting to. You may only specify this parameter if
+        the _ssl module supports Server Name Indication (SNI).
+        """
+        self._context = context
+        self._server_side = server_side
+        self._server_hostname = server_hostname
+        self._state = self.S_UNWRAPPED
+        # Record level data received from the peer is written to _incoming;
+        # record level data to be sent to the peer is read from _outgoing.
+        self._incoming = ssl.MemoryBIO()
+        self._outgoing = ssl.MemoryBIO()
+        self._sslobj = None
+        self._need_ssldata = False
+        # Completion callbacks installed by do_handshake() and shutdown().
+        self._on_handshake_complete = None
+        self._on_shutdown_complete = None
+
+    @property
+    def context(self):
+        """The SSL context passed to the constructor."""
+        return self._context
+
+    @property
+    def ssl_object(self):
+        """The internal :class:`ssl.SSLObject` instance."""
+        return self._sslobj
+
+    @property
+    def need_ssldata(self):
+        """Whether the SSL layer needs more record level data to make
+        progress (a handshake in progress, or a write blocked on a read)."""
+        return self._need_ssldata
+
+    @property
+    def wrapped(self):
+        """Whether a security layer is currently in effect."""
+        return self._state == self.S_WRAPPED
+
+    def do_handshake(self, callback=None):
+        """Start the SSL handshake. Return a list of ssldata.
+
+        The optional *callback* argument can be used to install a callback that
+        will be called when the handshake is complete. The callback will be
+        called with None if successful, else an exception instance.
+        """
+        if self._state != self.S_UNWRAPPED:
+            raise RuntimeError('handshake in progress or completed')
+        self._sslobj = self._context.wrap_bio(
+            self._incoming, self._outgoing,
+            server_side=self._server_side,
+            server_hostname=self._server_hostname)
+        self._state = self.S_DO_HANDSHAKE
+        self._on_handshake_complete = callback
+        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
+        assert len(appdata) == 0
+        return ssldata
+
+    def shutdown(self, callback=None):
+        """Start the SSL shutdown sequence. Return a list of ssldata.
+
+        The optional *callback* argument can be used to install a callback that
+        will be called when the shutdown is complete. The callback will be
+        called without arguments.
+        """
+        if self._state == self.S_UNWRAPPED:
+            raise RuntimeError('no security layer present')
+        self._state = self.S_SHUTDOWN
+        self._on_shutdown_complete = callback
+        ssldata, appdata = self.feed_ssldata(b'')
+        assert appdata == [] or appdata == [b'']
+        return ssldata
+
+    def feed_eof(self):
+        """Send a potentially "ragged" EOF.
+
+        This method will raise an SSL_ERROR_EOF exception if the EOF is
+        unexpected.
+        """
+        self._incoming.write_eof()
+        # NOTE(review): record level data produced here is neither returned
+        # nor sent anywhere; confirm that dropping it is intended.
+        ssldata, appdata = self.feed_ssldata(b'')
+        assert appdata == [] or appdata == [b'']
+
+    def feed_ssldata(self, data, only_handshake=False):
+        """Feed SSL record level data into the pipe.
+
+        The data must be a bytes instance. It is OK to send an empty bytes
+        instance. This can be used to get ssldata for a handshake initiated by
+        this endpoint.
+
+        Return a (ssldata, appdata) tuple. The ssldata element is a list of
+        buffers containing SSL data that needs to be sent to the remote SSL.
+
+        The appdata element is a list of buffers containing plaintext data that
+        needs to be forwarded to the application. The appdata list may contain
+        an empty buffer indicating an SSL "close_notify" alert. This alert must
+        be acknowledged by calling :meth:`shutdown`.
+        """
+        if self._state == self.S_UNWRAPPED:
+            # If unwrapped, pass plaintext data straight through.
+            return ([], [data] if data else [])
+        ssldata = []
+        appdata = []
+        self._need_ssldata = False
+        if data:
+            self._incoming.write(data)
+        try:
+            if self._state == self.S_DO_HANDSHAKE:
+                # Call do_handshake() until it doesn't raise anymore.
+                self._sslobj.do_handshake()
+                self._state = self.S_WRAPPED
+                if self._on_handshake_complete:
+                    self._on_handshake_complete(None)
+                if only_handshake:
+                    return (ssldata, appdata)
+                # Handshake done: fall through to the wrapped state.
+            if self._state == self.S_WRAPPED:
+                # Main state: read data from SSL until close_notify.
+                while True:
+                    chunk = self._sslobj.read(self.bufsize)
+                    appdata.append(chunk)
+                    if not chunk:  # close_notify
+                        break
+            if self._state == self.S_SHUTDOWN:
+                # Call unwrap() until it doesn't raise anymore.
+                self._sslobj.unwrap()
+                self._sslobj = None
+                self._state = self.S_UNWRAPPED
+                if self._on_shutdown_complete:
+                    self._on_shutdown_complete()
+            if self._state == self.S_UNWRAPPED:
+                # Drain possible plaintext data after close_notify.
+                appdata.append(self._incoming.read())
+        except (ssl.SSLError, ssl.CertificateError) as e:
+            # CertificateError carries no errno, so it is always fatal here.
+            exc_errno = getattr(e, 'errno', None)
+            if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
+                                 ssl.SSL_ERROR_WANT_WRITE,
+                                 ssl.SSL_ERROR_SYSCALL):
+                if (self._state == self.S_DO_HANDSHAKE and
+                        self._on_handshake_complete):
+                    self._on_handshake_complete(e)
+                raise
+            self._need_ssldata = exc_errno == ssl.SSL_ERROR_WANT_READ
+        # Check for record level data that needs to be sent back.
+        # Happens for the initial handshake and renegotiations.
+        if self._outgoing.pending:
+            ssldata.append(self._outgoing.read())
+        return (ssldata, appdata)
+
+    def feed_appdata(self, data, offset=0):
+        """Feed plaintext data into the pipe.
+
+        Return an (ssldata, offset) tuple. The ssldata element is a list of
+        buffers containing record level data that needs to be sent to the
+        remote SSL instance. The offset is the number of plaintext bytes that
+        were processed, which may be less than the length of data.
+
+        NOTE: In case of short writes, this call MUST be retried with the SAME
+        buffer passed into the *data* argument (i.e. the ``id()`` must be the
+        same). This is an OpenSSL requirement. A further particularity is that
+        a short write will always have offset == 0, because the _ssl module
+        does not enable partial writes. And even though the offset is zero,
+        there will still be encrypted data in ssldata.
+        """
+        if self._state == self.S_UNWRAPPED:
+            # Pass through data in unwrapped mode.
+            return ([data[offset:]] if offset < len(data) else [], len(data))
+        ssldata = []
+        view = memoryview(data)
+        while True:
+            self._need_ssldata = False
+            try:
+                if offset < len(view):
+                    offset += self._sslobj.write(view[offset:])
+            except ssl.SSLError as e:
+                exc_errno = getattr(e, 'errno', None)
+                # It is not allowed to call write() after unwrap() until the
+                # close_notify is acknowledged. We return the condition to the
+                # caller as a short write.
+                if getattr(e, 'reason', None) == 'PROTOCOL_IS_SHUTDOWN':
+                    exc_errno = ssl.SSL_ERROR_WANT_READ
+                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
+                                     ssl.SSL_ERROR_WANT_WRITE,
+                                     ssl.SSL_ERROR_SYSCALL):
+                    raise
+                self._need_ssldata = exc_errno == ssl.SSL_ERROR_WANT_READ
+            # See if there's any record level data back for us.
+            if self._outgoing.pending:
+                ssldata.append(self._outgoing.read())
+            if offset == len(view) or self._need_ssldata:
+                break
+        return (ssldata, offset)
+
+
+class _SelectorSslTransport(_SelectorSocketTransport):
_buffer_factory = bytearray
@@ -697,48 +923,65 @@ class _SelectorSslTransport(_SelectorTra
sslcontext.set_default_verify_paths()
sslcontext.verify_mode = ssl.CERT_REQUIRED
- wrap_kwargs = {
- 'server_side': server_side,
- 'do_handshake_on_connect': False,
- }
- if server_hostname and not server_side and ssl.HAS_SNI:
- wrap_kwargs['server_hostname'] = server_hostname
- sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
+ if not server_hostname or server_side or not ssl.HAS_SNI:
+ server_hostname = None
- super().__init__(loop, sslsock, protocol, extra, server)
+ _SelectorTransport.__init__(self, loop, rawsock, protocol, extra, server)
+ self._waiter = waiter
+ self._eof = False
+ self._paused = False
+ self._loop.add_reader(self._sock_fd, self._read_ready)
self._server_hostname = server_hostname
- self._waiter = waiter
self._sslcontext = sslcontext
- self._paused = False
-
+ self._sslpipe = SSLPipe(sslcontext, server_side, server_hostname)
+ # App data write buffering (_SelectorSocketTransport handles SSL data).
+ self._write_backlog = []
+ self._write_buffer_size = 0
# SSL-specific extra info. (peercert is set later)
self._extra.update(sslcontext=sslcontext)
+ self._start_handshake()
+
+ def _start_handshake(self):
+ """Queue the SSL handshake on the write backlog and start processing it.
+
+ Raises RuntimeError if the transport is already closing.
+ """
+ if self._closing:
+ raise RuntimeError('SSL transport is closing/closed')
if self._loop.get_debug():
logger.debug("%r starts SSL handshake", self)
- start_time = self._loop.time()
+ self._handshake_start_time = self._loop.time()
else:
- start_time = None
- self._on_handshake(start_time)
+ self._handshake_start_time = None
+ self._in_handshake = True
+ # A backlog entry with empty data and a true offset asks
+ # _process_write_backlog() to run the handshake; it counts as one
+ # unit of buffered data until it has been processed.
+ self._write_backlog.append([b'', True])
+ self._write_buffer_size += 1
+ self._process_write_backlog()
- def _on_handshake(self, start_time):
+ def _on_handshake_complete(self, handshake_exc):
+ self._in_handshake = False
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
+ sslobj = self._sslpipe.ssl_object
+
+ peercert = None if handshake_exc else sslobj.getpeercert()
+
try:
- self._sock.do_handshake()
- except ssl.SSLWantReadError:
- self._loop.add_reader(self._sock_fd,
- self._on_handshake, start_time)
- return
- except ssl.SSLWantWriteError:
- self._loop.add_writer(self._sock_fd,
- self._on_handshake, start_time)
- return
+ if handshake_exc is not None:
+ raise handshake_exc
+ if not hasattr(self._sslcontext, 'check_hostname'):
+ # Verify hostname if requested, Python 3.4+ uses check_hostname
+ # and checks the hostname in do_handshake()
+ if (self._server_hostname and
+ self._sslcontext.verify_mode != ssl.CERT_NONE):
+ ssl.match_hostname(peercert, self._server_hostname)
except BaseException as exc:
if self._loop.get_debug():
- logger.warning("%r: SSL handshake failed",
- self, exc_info=True)
- self._loop.remove_reader(self._sock_fd)
- self._loop.remove_writer(self._sock_fd)
+ if isinstance(exc, ssl.CertificateError):
+ logger.warning("%r: SSL handshake failed "
+ "on verifying the certificate",
+ self, exc_info=True)
+ else:
+ logger.warning("%r: SSL handshake failed",
+ self, exc_info=True)
self._sock.close()
if self._waiter is not None:
self._waiter.set_exception(exc)
@@ -747,160 +990,140 @@ class _SelectorSslTransport(_SelectorTra
else:
raise
- self._loop.remove_reader(self._sock_fd)
- self._loop.remove_writer(self._sock_fd)
-
- peercert = self._sock.getpeercert()
- if not hasattr(self._sslcontext, 'check_hostname'):
- # Verify hostname if requested, Python 3.4+ uses check_hostname
- # and checks the hostname in do_handshake()
- if (self._server_hostname and
- self._sslcontext.verify_mode != ssl.CERT_NONE):
- try:
- ssl.match_hostname(peercert, self._server_hostname)
- except Exception as exc:
- if self._loop.get_debug():
- logger.warning("%r: SSL handshake failed "
- "on matching the hostname",
- self, exc_info=True)
- self._sock.close()
- if self._waiter is not None:
- self._waiter.set_exception(exc)
- return
-
# Add extra info that becomes available after handshake.
self._extra.update(peercert=peercert,
- cipher=self._sock.cipher(),
- compression=self._sock.compression(),
+ cipher=sslobj.cipher(),
+ compression=sslobj.compression(),
)
- self._read_wants_write = False
- self._write_wants_read = False
+ self._handshake_successful = True
+ self._protocol.connection_made(self)
+ if self._waiter is not None:
+ self._waiter._set_result_unless_cancelled(None)
+
self._loop.add_reader(self._sock_fd, self._read_ready)
- self._loop.call_soon(self._protocol.connection_made, self)
- if self._waiter is not None:
- # wait until protocol.connection_made() has been called
- self._loop.call_soon(self._waiter._set_result_unless_cancelled,
- None)
if self._loop.get_debug():
- dt = self._loop.time() - start_time
+ dt = self._loop.time() - self._handshake_start_time
logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
+    def _process_write_backlog(self):
+        """Try to make progress on the write backlog.
+
+        Each backlog entry is a ``[data, offset]`` pair. Non-empty *data* is
+        application data still to be encrypted starting at *offset*. Empty
+        *data* is a control entry: a true offset requests the SSL handshake,
+        a false offset requests the SSL shutdown. Record level data produced
+        by the SSL pipe is written through _SelectorSocketTransport.write().
+        Errors are routed to the handshake callback while handshaking and to
+        _fatal_error() otherwise.
+        """
+        try:
+            # Bound the number of entries handled by this call.
+            for _ in range(len(self._write_backlog)):
+                data, offset = self._write_backlog[0]
+                if data:
+                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
+                elif offset:
+                    ssldata = self._sslpipe.do_handshake(
+                        self._on_handshake_complete)
+                    offset = 1
+                else:
+                    ssldata = self._sslpipe.shutdown(self._ssl_eof_received)
+                    offset = 1
+                # Temporarily set _closing to False to prevent
+                # _SelectorSocketTransport.write() from raising an error, and
+                # always restore it, even if a write fails.
+                saved, self._closing = self._closing, False
+                try:
+                    for chunk in ssldata:
+                        super().write(chunk)
+                finally:
+                    self._closing = saved
+                if offset < len(data):
+                    self._write_backlog[0][1] = offset
+                    # A short write means that a write is blocked on a read.
+                    # Make sure the socket reader is registered so the
+                    # missing record level data can arrive. Re-adding an
+                    # existing reader just replaces it, and unlike
+                    # resume_reading() this does not un-pause the protocol.
+                    assert self._sslpipe.need_ssldata
+                    self._loop.add_reader(self._sock_fd, self._read_ready)
+                    break
+                # An entire chunk from the backlog was processed. We can
+                # delete it and reduce the outstanding buffer size.
+                del self._write_backlog[0]
+                self._write_buffer_size -= offset
+        except BaseException as exc:
+            if self._in_handshake:
+                self._on_handshake_complete(exc)
+            else:
+                self._fatal_error(exc, 'Fatal error on SSL transport')
+
def pause_reading(self):
- # XXX This is a bit icky, given the comment at the top of
- # _read_ready(). Is it possible to evoke a deadlock? I don't
- # know, although it doesn't look like it; write() will still
- # accept more data for the buffer and eventually the app will
- # call resume_reading() again, and things will flow again.
-
if self._closing:
raise RuntimeError('Cannot pause_reading() when closing')
if self._paused:
raise RuntimeError('Already paused')
self._paused = True
- self._loop.remove_reader(self._sock_fd)
+ # Only stop reading from the socket when the SSL pipe is not waiting
+ # for record level data: a write blocked on a read (see
+ # _process_write_backlog) still needs incoming data to make progress.
+ if not self._sslpipe.need_ssldata:
+ self._loop.remove_reader(self._sock_fd)
if self._loop.get_debug():
logger.debug("%r pauses reading", self)
- def resume_reading(self):
- if not self._paused:
- raise RuntimeError('Not paused')
- self._paused = False
- if self._closing:
- return
- self._loop.add_reader(self._sock_fd, self._read_ready)
- if self._loop.get_debug():
- logger.debug("%r resumes reading", self)
+    def _ssl_eof_received(self):
+        """Shutdown-complete callback handed to the SSL pipe.
+
+        Invoked once the SSL shutdown sequence has finished; it is handled
+        exactly like an EOF on the underlying stream.
+        """
+        # XXX we should be able to keep the connection alive (but cleartext)
+        self._eof_received()
+
+    def _eof_received(self):
+        """Handle EOF of the underlying stream, then close the transport.
+
+        The protocol is only told about the EOF once the handshake is over;
+        the transport is closed in every case.
+        """
+        try:
+            if self._in_handshake:
+                return
+            if self._loop.get_debug():
+                logger.debug("%r received EOF", self)
+            if self._protocol.eof_received():
+                logger.warning('returning true from eof_received() '
+                               'has no effect when using ssl')
+        finally:
+            self.close()
+
+    def _call_connection_lost(self, exc):
+        """Report the lost connection to the protocol, then finalize.
+
+        asyncio protocols must only receive connection_lost() after
+        connection_made(). This transport calls connection_made() only after
+        a successful handshake, right after setting ``_handshake_successful``,
+        so the protocol is notified only in that case (not while handshaking
+        and not after a failed handshake). The low-level cleanup always runs.
+        """
+        try:
+            if getattr(self, '_handshake_successful', False):
+                self._protocol.connection_lost(exc)
+        finally:
+            self._finalize_connection()
def _read_ready(self):
- if self._write_wants_read:
- self._write_wants_read = False
- self._write_ready()
-
- if self._buffer:
- self._loop.add_writer(self._sock_fd, self._write_ready)
-
try:
data = self._sock.recv(self.max_size)
- except (BlockingIOError, InterruptedError, ssl.SSLWantReadError):
+ except (BlockingIOError, InterruptedError):
pass
- except ssl.SSLWantWriteError:
- self._read_wants_write = True
- self._loop.remove_reader(self._sock_fd)
- self._loop.add_writer(self._sock_fd, self._write_ready)
except Exception as exc:
self._fatal_error(exc, 'Fatal read error on SSL transport')
else:
if data:
- self._protocol.data_received(data)
+ try:
+ ssldata, appdata = self._sslpipe.feed_ssldata(data)
+ except (ssl.SSLError, ssl.CertificateError) as e:
+ # The SSL layer rejected the record level data (this includes a
+ # failed handshake). logging uses %-style arguments, and
+ # CertificateError has no errno/reason attributes.
+ logger.warning('SSL error %s (reason %s)',
+ getattr(e, 'errno', None), getattr(e, 'reason', None))
+ self.abort()
+ return
+ for chunk in ssldata:
+ super().write(chunk)
+ if appdata and self._paused:
+ self._loop.remove_reader(self._sock_fd)
+ for chunk in appdata:
+ # Empty chunks (the close_notify marker or an empty drain after
+ # unwrapping) are not application data; data_received() must
+ # only get non-empty bytes.
+ # NOTE(review): close_notify is not acknowledged here; confirm
+ # that the SSL shutdown is started elsewhere (e.g. in close()).
+ if chunk:
+ self._protocol.data_received(chunk)
else:
- try:
- if self._loop.get_debug():
- logger.debug("%r received EOF", self)
- keep_open = self._protocol.eof_received()
- if keep_open:
- logger.warning('returning true from eof_received() '
- 'has no effect when using ssl')
- finally:
- self.close()
-
- def _write_ready(self):
- if self._read_wants_write:
- self._read_wants_write = False
- self._read_ready()
-
- if not (self._paused or self._closing):
- self._loop.add_reader(self._sock_fd, self._read_ready)
-
- if self._buffer:
- try:
- n = self._sock.send(self._buffer)
- except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError):
- n = 0
- except ssl.SSLWantReadError:
- n = 0
- self._loop.remove_writer(self._sock_fd)
- self._write_wants_read = True
- except Exception as exc:
- self._loop.remove_writer(self._sock_fd)
- self._buffer.clear()
- self._fatal_error(exc, 'Fatal write error on SSL transport')
- return
-
- if n:
- del self._buffer[:n]
-
- self._maybe_resume_protocol() # May append to buffer.
-
- if not self._buffer:
- self._loop.remove_writer(self._sock_fd)
- if self._closing:
- self._call_connection_lost(None)
+ # End of underlying stream
+ self._eof_received()
def write(self, data):
+ """Encrypt *data* through the SSL pipe and write it to the socket.
+
+ Writes are dropped (with a warning after repeated attempts) once
+ _conn_lost is non-zero, as the previous implementation did; otherwise
+ the data would pile up on the backlog of a finalized transport.
+ """
if not isinstance(data, (bytes, bytearray, memoryview)):
- raise TypeError('data argument must be byte-ish (%r)',
- type(data))
+ raise TypeError("data: expecting a bytes-like instance, got {!r}"
+ .format(type(data).__name__))
if not data:
return
if self._conn_lost:
if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
logger.warning('socket.send() raised exception.')
self._conn_lost += 1
return
-
- if not self._buffer:
- self._loop.add_writer(self._sock_fd, self._write_ready)
-
- # Add it to the buffer.
- self._buffer.extend(data)
- self._maybe_pause_protocol()
+ self._write_backlog.append([data, 0])
+ self._write_buffer_size += len(data)
+ self._maybe_pause_protocol()
+ self._process_write_backlog()
+    def get_write_buffer_size(self):
+        """Return the amount of data waiting to be written.
+
+        This is the application data still queued on the SSL write backlog
+        (a queued handshake request counts as one unit) plus the record level
+        data already handed to the socket transport but not yet sent
+        (``self._buffer``). Counting both lets write flow control
+        (_maybe_pause_protocol) react to a slow socket, not just to a
+        blocked SSL layer.
+        """
+        return self._write_buffer_size + len(self._buffer)
def can_write_eof(self):
return False
+    def write_eof(self):
+        """Deliberately skip _SelectorSocketTransport.write_eof().
+
+        The raw socket sits underneath the SSL layer and this transport
+        reports can_write_eof() as False, so the call is routed to the
+        _SelectorTransport implementation instead of the socket one.
+        """
+        _SelectorTransport.write_eof(self)
+
class _SelectorDatagramTransport(_SelectorTransport):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -1096,17 +1096,16 @@ class SelectorSslTransportTests(test_uti
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
- self.sock.fileno.return_value = 7
- self.sslsock = mock.Mock()
- self.sslsock.fileno.return_value = 1
+ self.sock.fileno.return_value = 1
+ self.sslobj = mock.Mock()
self.sslcontext = mock.Mock()
- self.sslcontext.wrap_socket.return_value = self.sslsock
+ self.sslcontext.wrap_bio.return_value = self.sslobj
def _make_one(self, create_waiter=None):
transport = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext)
self.sock.reset_mock()
- self.sslsock.reset_mock()
+ self.sslobj.reset_mock()
self.sslcontext.reset_mock()
self.loop.reset_counters()
return transport
@@ -1116,45 +1115,40 @@ class SelectorSslTransportTests(test_uti
tr = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext,
waiter=waiter)
- self.assertTrue(self.sslsock.do_handshake.called)
+ self.assertTrue(self.sslobj.do_handshake.called)
self.loop.assert_reader(1, tr._read_ready)
test_utils.run_briefly(self.loop)
self.assertIsNone(waiter.result())
def test_on_handshake_reader_retry(self):
self.loop.set_debug(False)
- self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ self.sslobj.do_handshake.side_effect = ssl.SSLWantReadError(
+ ssl.SSL_ERROR_WANT_READ, '')
transport = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext)
- self.loop.assert_reader(1, transport._on_handshake, None)
-
- def test_on_handshake_writer_retry(self):
- self.loop.set_debug(False)
- self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError
- transport = _SelectorSslTransport(
- self.loop, self.sock, self.protocol, self.sslcontext)
- self.loop.assert_writer(1, transport._on_handshake, None)
+ self.loop.assert_reader(1, transport._read_ready)
def test_on_handshake_exc(self):
exc = ValueError()
- self.sslsock.do_handshake.side_effect = exc
+ self.sslobj.do_handshake.side_effect = exc
with test_utils.disable_logger():
waiter = asyncio.Future(loop=self.loop)
transport = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext, waiter)
self.assertTrue(waiter.done())
self.assertIs(exc, waiter.exception())
- self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(self.sock.close.called)
def test_on_handshake_base_exc(self):
transport = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext)
- transport._waiter = asyncio.Future(loop=self.loop)
exc = BaseException()
- self.sslsock.do_handshake.side_effect = exc
+ self.sslobj.do_handshake.side_effect = exc
with test_utils.disable_logger():
- self.assertRaises(BaseException, transport._on_handshake, 0)
- self.assertTrue(self.sslsock.close.called)
+ self.assertRaises(BaseException, transport.__init__,
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ waiter=asyncio.Future(loop=self.loop))
+ self.assertTrue(self.sock.close.called)
self.assertTrue(transport._waiter.done())
self.assertIs(exc, transport._waiter.exception())
@@ -1171,77 +1165,20 @@ class SelectorSslTransportTests(test_uti
with self.assertRaises(RuntimeError):
tr.resume_reading()
- def test_write(self):
- transport = self._make_one()
- transport.write(b'data')
- self.assertEqual(list_to_buffer([b'data']), transport._buffer)
-
- def test_write_bytearray(self):
- transport = self._make_one()
- data = bytearray(b'data')
- transport.write(data)
- self.assertEqual(list_to_buffer([b'data']), transport._buffer)
- self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated.
- self.assertIsNot(data, transport._buffer) # Hasn't been incorporated.
-
- def test_write_memoryview(self):
- transport = self._make_one()
- data = memoryview(b'data')
- transport.write(data)
- self.assertEqual(list_to_buffer([b'data']), transport._buffer)
-
- def test_write_no_data(self):
- transport = self._make_one()
- transport._buffer.extend(b'data')
- transport.write(b'')
- self.assertEqual(list_to_buffer([b'data']), transport._buffer)
-
def test_write_str(self):
transport = self._make_one()
self.assertRaises(TypeError, transport.write, 'str')
- def test_write_closing(self):
+ def test_read_ready_recv(self):
transport = self._make_one()
- transport.close()
- self.assertEqual(transport._conn_lost, 1)
- transport.write(b'data')
- self.assertEqual(transport._conn_lost, 2)
-
- @mock.patch('asyncio.selector_events.logger')
- def test_write_exception(self, m_log):
- transport = self._make_one()
- transport._conn_lost = 1
- transport.write(b'data')
- self.assertEqual(transport._buffer, list_to_buffer())
- transport.write(b'data')
- transport.write(b'data')
- transport.write(b'data')
- transport.write(b'data')
- m_log.warning.assert_called_with('socket.send() raised exception.')
-
- def test_read_ready_recv(self):
- self.sslsock.recv.return_value = b'data'
- transport = self._make_one()
+ transport._sslpipe.feed_ssldata = mock.Mock(
+ return_value=([], [b'data']))
transport._read_ready()
- self.assertTrue(self.sslsock.recv.called)
+ self.assertTrue(self.sock.recv.called)
self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
- def test_read_ready_write_wants_read(self):
- self.loop.add_writer = mock.Mock()
- self.sslsock.recv.side_effect = BlockingIOError
- transport = self._make_one()
- transport._write_wants_read = True
- transport._write_ready = mock.Mock()
- transport._buffer.extend(b'data')
- transport._read_ready()
-
- self.assertFalse(transport._write_wants_read)
- transport._write_ready.assert_called_with()
- self.loop.add_writer.assert_called_with(
- transport._sock_fd, transport._write_ready)
-
def test_read_ready_recv_eof(self):
- self.sslsock.recv.return_value = b''
+ self.sock.recv.return_value = b''
transport = self._make_one()
transport.close = mock.Mock()
transport._read_ready()
@@ -1249,43 +1186,15 @@ class SelectorSslTransportTests(test_uti
self.protocol.eof_received.assert_called_with()
def test_read_ready_recv_conn_reset(self):
- err = self.sslsock.recv.side_effect = ConnectionResetError()
+ err = self.sock.recv.side_effect = ConnectionResetError()
transport = self._make_one()
transport._force_close = mock.Mock()
with test_utils.disable_logger():
transport._read_ready()
transport._force_close.assert_called_with(err)
- def test_read_ready_recv_retry(self):
- self.sslsock.recv.side_effect = ssl.SSLWantReadError
- transport = self._make_one()
- transport._read_ready()
- self.assertTrue(self.sslsock.recv.called)
- self.assertFalse(self.protocol.data_received.called)
-
- self.sslsock.recv.side_effect = BlockingIOError
- transport._read_ready()
- self.assertFalse(self.protocol.data_received.called)
-
- self.sslsock.recv.side_effect = InterruptedError
- transport._read_ready()
- self.assertFalse(self.protocol.data_received.called)
-
- def test_read_ready_recv_write(self):
- self.loop.remove_reader = mock.Mock()
- self.loop.add_writer = mock.Mock()
- self.sslsock.recv.side_effect = ssl.SSLWantWriteError
- transport = self._make_one()
- transport._read_ready()
- self.assertFalse(self.protocol.data_received.called)
- self.assertTrue(transport._read_wants_write)
-
- self.loop.remove_reader.assert_called_with(transport._sock_fd)
- self.loop.add_writer.assert_called_with(
- transport._sock_fd, transport._write_ready)
-
def test_read_ready_recv_exc(self):
- err = self.sslsock.recv.side_effect = OSError()
+ err = self.sock.recv.side_effect = OSError()
transport = self._make_one()
transport._fatal_error = mock.Mock()
transport._read_ready()
@@ -1293,104 +1202,6 @@ class SelectorSslTransportTests(test_uti
err,
'Fatal read error on SSL transport')
- def test_write_ready_send(self):
- self.sslsock.send.return_value = 4
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data'])
- transport._write_ready()
- self.assertEqual(list_to_buffer(), transport._buffer)
- self.assertTrue(self.sslsock.send.called)
-
- def test_write_ready_send_none(self):
- self.sslsock.send.return_value = 0
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data1', b'data2'])
- transport._write_ready()
- self.assertTrue(self.sslsock.send.called)
- self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
-
- def test_write_ready_send_partial(self):
- self.sslsock.send.return_value = 2
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data1', b'data2'])
- transport._write_ready()
- self.assertTrue(self.sslsock.send.called)
- self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer)
-
- def test_write_ready_send_closing_partial(self):
- self.sslsock.send.return_value = 2
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data1', b'data2'])
- transport._write_ready()
- self.assertTrue(self.sslsock.send.called)
- self.assertFalse(self.sslsock.close.called)
-
- def test_write_ready_send_closing(self):
- self.sslsock.send.return_value = 4
- transport = self._make_one()
- transport.close()
- transport._buffer = list_to_buffer([b'data'])
- transport._write_ready()
- self.assertFalse(self.loop.writers)
- self.protocol.connection_lost.assert_called_with(None)
-
- def test_write_ready_send_closing_empty_buffer(self):
- self.sslsock.send.return_value = 4
- transport = self._make_one()
- transport.close()
- transport._buffer = list_to_buffer()
- transport._write_ready()
- self.assertFalse(self.loop.writers)
- self.protocol.connection_lost.assert_called_with(None)
-
- def test_write_ready_send_retry(self):
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data'])
-
- self.sslsock.send.side_effect = ssl.SSLWantWriteError
- transport._write_ready()
- self.assertEqual(list_to_buffer([b'data']), transport._buffer)
-
- self.sslsock.send.side_effect = BlockingIOError()
- transport._write_ready()
- self.assertEqual(list_to_buffer([b'data']), transport._buffer)
-
- def test_write_ready_send_read(self):
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data'])
-
- self.loop.remove_writer = mock.Mock()
- self.sslsock.send.side_effect = ssl.SSLWantReadError
- transport._write_ready()
- self.assertFalse(self.protocol.data_received.called)
- self.assertTrue(transport._write_wants_read)
- self.loop.remove_writer.assert_called_with(transport._sock_fd)
-
- def test_write_ready_send_exc(self):
- err = self.sslsock.send.side_effect = OSError()
-
- transport = self._make_one()
- transport._buffer = list_to_buffer([b'data'])
- transport._fatal_error = mock.Mock()
- transport._write_ready()
- transport._fatal_error.assert_called_with(
- err,
- 'Fatal write error on SSL transport')
- self.assertEqual(list_to_buffer(), transport._buffer)
-
- def test_write_ready_read_wants_write(self):
- self.loop.add_reader = mock.Mock()
- self.sslsock.send.side_effect = BlockingIOError
- transport = self._make_one()
- transport._read_wants_write = True
- transport._read_ready = mock.Mock()
- transport._write_ready()
-
- self.assertFalse(transport._read_wants_write)
- transport._read_ready.assert_called_with()
- self.loop.add_reader.assert_called_with(
- transport._sock_fd, transport._read_ready)
-
def test_write_eof(self):
tr = self._make_one()
self.assertFalse(tr.can_write_eof())
@@ -1413,9 +1224,8 @@ class SelectorSslTransportTests(test_uti
_SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext,
server_hostname='localhost')
- self.sslcontext.wrap_socket.assert_called_with(
- self.sock, do_handshake_on_connect=False, server_side=False,
- server_hostname='localhost')
+ self.sslcontext.wrap_bio.assert_called_with(
+ mock.ANY, mock.ANY, server_side=False, server_hostname='localhost')
class SelectorSslWithoutSslTransportTests(unittest.TestCase):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment