-
-
Save pitrou/f04fa9cbfec88cc37050 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py | |
--- a/Lib/asyncio/selector_events.py | |
+++ b/Lib/asyncio/selector_events.py | |
@@ -530,14 +530,19 @@ class _SelectorTransport(transports._Flo | |
try: | |
self._protocol.connection_lost(exc) | |
finally: | |
- self._sock.close() | |
- self._sock = None | |
- self._protocol = None | |
- self._loop = None | |
- server = self._server | |
- if server is not None: | |
- server._detach() | |
- self._server = None | |
+ self._finalize_connection() | |
+ | |
+ def _finalize_connection(self): | |
+ """Last steps to low-level close and unregister this transport. | |
+ """ | |
+ self._sock.close() | |
+ self._sock = None | |
+ self._protocol = None | |
+ self._loop = None | |
+ server = self._server | |
+ if server is not None: | |
+ server._detach() | |
+ self._server = None | |
def get_write_buffer_size(self): | |
return len(self._buffer) | |
@@ -667,7 +672,228 @@ class _SelectorSocketTransport(_Selector | |
return True | |
-class _SelectorSslTransport(_SelectorTransport): | |
+ | |
+class SSLPipe(object):
+    """An SSL "Pipe".
+
+    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
+    through memory buffers. It can be used to implement a security layer for an
+    existing connection where you don't have access to the connection's file
+    descriptor, or for some reason you don't want to use it.
+
+    An SSL pipe can be in "wrapped" or "unwrapped" mode. In unwrapped mode,
+    data is passed through untransformed. In wrapped mode, application level
+    data is encrypted to SSL record level data and vice versa. The SSL record
+    level is the lowest level in the SSL protocol suite and is what travels
+    as-is over the wire.
+
+    An SSLPipe initially is in "unwrapped" mode. To start SSL, call
+    :meth:`do_handshake`. To shutdown SSL again, call :meth:`shutdown`.
+    """
+
+    bufsize = 65536
+
+    # The SSL protocol instance is driven entirely through two memory BIOs
+    # (ssl.MemoryBIO) rather than a socketpair: record level data goes in via
+    # self._incoming and comes out via self._outgoing, so no file descriptor
+    # is needed.
+
+    S_UNWRAPPED, S_DO_HANDSHAKE, S_WRAPPED, S_SHUTDOWN = range(4)
+ | |
+    def __init__(self, context, server_side, server_hostname=None):
+        """
+        The *context* argument specifies the :class:`ssl.SSLContext` to use.
+        It must provide :meth:`ssl.SSLContext.wrap_bio`, which
+        :meth:`do_handshake` uses to create the internal SSL object.
+
+        The *server_side* argument indicates whether this is a server side or
+        client side transport.
+
+        The optional *server_hostname* argument can be used to specify the
+        hostname you are connecting to. You may only specify this parameter if
+        the _ssl module supports Server Name Indication (SNI).
+        """
+        self._context = context
+        self._server_side = server_side
+        self._server_hostname = server_hostname
+        self._state = self.S_UNWRAPPED
+        self._incoming = ssl.MemoryBIO()
+        self._outgoing = ssl.MemoryBIO()
+        self._sslobj = None
+        self._need_ssldata = False
+        # Completion callback; installed by do_handshake() / shutdown().
+        self._on_handshake_complete = None
+ @property | |
+ def context(self): | |
+ """The SSL context passed to the constructor.""" | |
+ return self._context | |
+ | |
+ @property | |
+ def ssl_object(self): | |
+ """The internal :class:`ssl.SSLObject` instance.""" | |
+ return self._sslobj | |
+ | |
+ @property | |
+ def need_ssldata(self): | |
+ """Whether more record level data is needed to complete a handshake | |
+ that is currently in progress.""" | |
+ return self._need_ssldata | |
+ | |
+ @property | |
+ def wrapped(self): | |
+ """Whether a security layer is currently in effect.""" | |
+ return self._state == self.S_WRAPPED | |
+ | |
+    def do_handshake(self, callback=None):
+        """Start the SSL handshake. Return a list of ssldata.
+
+        The optional *callback* argument can be used to install a callback that
+        will be called when the handshake is complete. The callback will be
+        called with None if successful, else an exception instance.
+
+        Raises RuntimeError if a handshake is already in progress or done.
+        """
+        if self._state != self.S_UNWRAPPED:
+            raise RuntimeError('handshake in progress or completed')
+        self._sslobj = self._context.wrap_bio(
+            self._incoming, self._outgoing,
+            server_side=self._server_side,
+            server_hostname=self._server_hostname)
+        self._state = self.S_DO_HANDSHAKE
+        self._on_handshake_complete = callback
+        # Feed no data: this only drives the handshake far enough to produce
+        # our first outgoing records (if any).
+        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
+        assert len(appdata) == 0
+        return ssldata
+ | |
+ def shutdown(self, callback=None): | |
+ """Start the SSL shutdown sequence. Return a list of ssldata. | |
+ | |
+ The optional *callback* argument can be used to install a callback that | |
+ will be called when the shutdown is complete. The callback will be | |
+ called without arguments. | |
+ """ | |
+ if self._state == self.S_UNWRAPPED: | |
+ raise RuntimeError('no security layer present') | |
+ self._state = self.S_SHUTDOWN | |
+ self._on_handshake_complete = callback | |
+ ssldata, appdata = self.feed_ssldata(b'') | |
+ assert appdata == [] or appdata == [b''] | |
+ return ssldata | |
+ | |
+ def feed_eof(self): | |
+ """Send a potentially "ragged" EOF. | |
+ | |
+ This method will raise an SSL_ERROR_EOF exception if the EOF is | |
+ unexpected. | |
+ """ | |
+ self._incoming.write_eof() | |
+ ssldata, appdata = self.feed_ssldata(b'') | |
+ assert appdata == [] or appdata == [b''] | |
+ | |
+ def feed_ssldata(self, data, only_handshake=False): | |
+ """Feed SSL record level data into the pipe. | |
+ | |
+ The data must be a bytes instance. It is OK to send an empty bytes | |
+ instance. This can be used to get ssldata for a handshake initiated by | |
+ this endpoint. | |
+ | |
+ Return a (ssldata, appdata) tuple. The ssldata element is a list of | |
+ buffers containing SSL data that needs to be sent to the remote SSL. | |
+ | |
+ The appdata element is a list of buffers containing plaintext data that | |
+ needs to be forwarded to the application. The appdata list may contain | |
+ an empty buffer indicating an SSL "close_notify" alert. This alert must | |
+ be acknowledged by calling :meth:`shutdown`. | |
+ """ | |
+ if self._state == self.S_UNWRAPPED: | |
+ # If unwrapped, pass plaintext data straight through. | |
+ return ([], [data] if data else []) | |
+ ssldata = []; appdata = [] | |
+ self._need_ssldata = False | |
+ if data: | |
+ self._incoming.write(data) | |
+ try: | |
+ if self._state == self.S_DO_HANDSHAKE: | |
+ # Call do_handshake() until it doesn't raise anymore. | |
+ self._sslobj.do_handshake() | |
+ self._state = self.S_WRAPPED | |
+ if self._on_handshake_complete: | |
+ self._on_handshake_complete(None) | |
+ if only_handshake: | |
+ return (ssldata, appdata) | |
+ if self._state == self.S_WRAPPED: | |
+ # Main state: read data from SSL until close_notify | |
+ while True: | |
+ chunk = self._sslobj.read(self.bufsize) | |
+ appdata.append(chunk) | |
+ if not chunk: # close_notify | |
+ break | |
+ if self._state == self.S_SHUTDOWN: | |
+ # Call shutdown() until it doesn't raise anymore. | |
+ self._sslobj.unwrap() | |
+ self._sslobj = None | |
+ self._state = self.S_UNWRAPPED | |
+ if self._on_handshake_complete: | |
+ self._on_handshake_complete() | |
+ if self._state == self.S_UNWRAPPED: | |
+ # Drain possible plaintext data after close_notify. | |
+ appdata.append(self._incoming.read()) | |
+ except (ssl.SSLError, ssl.CertificateError) as e: | |
+ if getattr(e, 'errno', -1) not in ( | |
+ ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, | |
+ ssl.SSL_ERROR_SYSCALL): | |
+ if self._state == self.S_DO_HANDSHAKE and self._on_handshake_complete: | |
+ self._on_handshake_complete(e) | |
+ raise | |
+ self._need_ssldata = e.errno == ssl.SSL_ERROR_WANT_READ | |
+ # Check for record level data that needs to be sent back. | |
+ # Happens for the initial handshake and renegotiations. | |
+ if self._outgoing.pending: | |
+ ssldata.append(self._outgoing.read()) | |
+ return (ssldata, appdata) | |
+ | |
+    def feed_appdata(self, data, offset=0):
+        """Feed plaintext data into the pipe.
+
+        Return an (ssldata, offset) tuple. The ssldata element is a list of
+        buffers containing record level data that needs to be sent to the
+        remote SSL instance. The offset is the number of plaintext bytes that
+        were processed, which may be less than the length of data.
+
+        NOTE: In case of short writes, this call MUST be retried with the SAME
+        buffer passed into the *data* argument (i.e. the ``id()`` must be the
+        same). This is an OpenSSL requirement. A further particularity is that
+        a short write will always have offset == 0, because the _ssl module
+        does not enable partial writes. And even though the offset is zero,
+        there will still be encrypted data in ssldata.
+        """
+        if self._state == self.S_UNWRAPPED:
+            # pass through data in unwrapped mode
+            return ([data[offset:]] if offset < len(data) else [], len(data))
+        ssldata = []
+        view = memoryview(data)
+        while True:
+            self._need_ssldata = False
+            try:
+                if offset < len(view):
+                    offset += self._sslobj.write(view[offset:])
+            except ssl.SSLError as e:
+                # It is not allowed to call write() after unwrap() until the
+                # close_notify is acknowledged. We return the condition to the
+                # caller as a short write. (SSLError carries the OpenSSL
+                # reason string in its ``reason`` attribute.)
+                if e.reason == 'PROTOCOL_IS_SHUTDOWN':
+                    e.errno = ssl.SSL_ERROR_WANT_READ
+                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
+                                   ssl.SSL_ERROR_SYSCALL):
+                    raise
+                self._need_ssldata = e.errno == ssl.SSL_ERROR_WANT_READ
+            # See if there's any record level data back for us.
+            if self._outgoing.pending:
+                ssldata.append(self._outgoing.read())
+            if offset == len(view) or self._need_ssldata:
+                break
+        return (ssldata, offset)
+ | |
+ | |
+class _SelectorSslTransport(_SelectorSocketTransport): | |
_buffer_factory = bytearray | |
@@ -697,48 +923,65 @@ class _SelectorSslTransport(_SelectorTra | |
sslcontext.set_default_verify_paths() | |
sslcontext.verify_mode = ssl.CERT_REQUIRED | |
- wrap_kwargs = { | |
- 'server_side': server_side, | |
- 'do_handshake_on_connect': False, | |
- } | |
- if server_hostname and not server_side and ssl.HAS_SNI: | |
- wrap_kwargs['server_hostname'] = server_hostname | |
- sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) | |
+ if not server_hostname or server_side or not ssl.HAS_SNI: | |
+ server_hostname = None | |
- super().__init__(loop, sslsock, protocol, extra, server) | |
+ _SelectorTransport.__init__(self, loop, rawsock, protocol, extra, server) | |
+ self._waiter = waiter | |
+ self._eof = False | |
+ self._paused = False | |
+ self._loop.add_reader(self._sock_fd, self._read_ready) | |
self._server_hostname = server_hostname | |
- self._waiter = waiter | |
self._sslcontext = sslcontext | |
- self._paused = False | |
- | |
+ self._sslpipe = SSLPipe(sslcontext, server_side, server_hostname) | |
+ # App data write buffering (_SelectorSocketTransport handles SSL data). | |
+ self._write_backlog = [] | |
+ self._write_buffer_size = 0 | |
# SSL-specific extra info. (peercert is set later) | |
self._extra.update(sslcontext=sslcontext) | |
+ self._start_handshake() | |
+ | |
+    def _start_handshake(self):
+        """Queue the SSL handshake and start driving it.
+
+        Raises RuntimeError if the transport is already closing/closed.
+        """
+        if self._closing:
+            raise RuntimeError('SSL transport is closing/closed')
         if self._loop.get_debug():
             logger.debug("%r starts SSL handshake", self)
-            start_time = self._loop.time()
+            self._handshake_start_time = self._loop.time()
         else:
-            start_time = None
-        self._on_handshake(start_time)
+            self._handshake_start_time = None
+        self._in_handshake = True
+        # A [b'', True] backlog entry makes _process_write_backlog() call
+        # SSLPipe.do_handshake(); it is accounted as one buffered byte.
+        self._write_backlog.append([b'', True])
+        self._write_buffer_size += 1
+        self._process_write_backlog()
- def _on_handshake(self, start_time): | |
+ def _on_handshake_complete(self, handshake_exc): | |
+ self._in_handshake = False | |
+ self._loop.remove_reader(self._sock_fd) | |
+ self._loop.remove_writer(self._sock_fd) | |
+ sslobj = self._sslpipe.ssl_object | |
+ | |
+ peercert = None if handshake_exc else sslobj.getpeercert() | |
+ | |
try: | |
- self._sock.do_handshake() | |
- except ssl.SSLWantReadError: | |
- self._loop.add_reader(self._sock_fd, | |
- self._on_handshake, start_time) | |
- return | |
- except ssl.SSLWantWriteError: | |
- self._loop.add_writer(self._sock_fd, | |
- self._on_handshake, start_time) | |
- return | |
+ if handshake_exc is not None: | |
+ raise handshake_exc | |
+ if not hasattr(self._sslcontext, 'check_hostname'): | |
+ # Verify hostname if requested, Python 3.4+ uses check_hostname | |
+ # and checks the hostname in do_handshake() | |
+ if (self._server_hostname and | |
+ self._sslcontext.verify_mode != ssl.CERT_NONE): | |
+ ssl.match_hostname(peercert, self._server_hostname) | |
except BaseException as exc: | |
if self._loop.get_debug(): | |
- logger.warning("%r: SSL handshake failed", | |
- self, exc_info=True) | |
- self._loop.remove_reader(self._sock_fd) | |
- self._loop.remove_writer(self._sock_fd) | |
+ if isinstance(exc, ssl.CertificateError): | |
+ logger.warning("%r: SSL handshake failed " | |
+ "on verifying the certificate", | |
+ self, exc_info=True) | |
+ else: | |
+ logger.warning("%r: SSL handshake failed", | |
+ self, exc_info=True) | |
self._sock.close() | |
if self._waiter is not None: | |
self._waiter.set_exception(exc) | |
@@ -747,160 +990,140 @@ class _SelectorSslTransport(_SelectorTra | |
else: | |
raise | |
- self._loop.remove_reader(self._sock_fd) | |
- self._loop.remove_writer(self._sock_fd) | |
- | |
- peercert = self._sock.getpeercert() | |
- if not hasattr(self._sslcontext, 'check_hostname'): | |
- # Verify hostname if requested, Python 3.4+ uses check_hostname | |
- # and checks the hostname in do_handshake() | |
- if (self._server_hostname and | |
- self._sslcontext.verify_mode != ssl.CERT_NONE): | |
- try: | |
- ssl.match_hostname(peercert, self._server_hostname) | |
- except Exception as exc: | |
- if self._loop.get_debug(): | |
- logger.warning("%r: SSL handshake failed " | |
- "on matching the hostname", | |
- self, exc_info=True) | |
- self._sock.close() | |
- if self._waiter is not None: | |
- self._waiter.set_exception(exc) | |
- return | |
- | |
# Add extra info that becomes available after handshake. | |
self._extra.update(peercert=peercert, | |
- cipher=self._sock.cipher(), | |
- compression=self._sock.compression(), | |
+ cipher=sslobj.cipher(), | |
+ compression=sslobj.compression(), | |
) | |
- self._read_wants_write = False | |
- self._write_wants_read = False | |
+ self._handshake_successful = True | |
+ self._protocol.connection_made(self) | |
+ if self._waiter is not None: | |
+ self._waiter._set_result_unless_cancelled(None) | |
+ | |
self._loop.add_reader(self._sock_fd, self._read_ready) | |
- self._loop.call_soon(self._protocol.connection_made, self) | |
- if self._waiter is not None: | |
- # wait until protocol.connection_made() has been called | |
- self._loop.call_soon(self._waiter._set_result_unless_cancelled, | |
- None) | |
if self._loop.get_debug(): | |
- dt = self._loop.time() - start_time | |
+ dt = self._loop.time() - self._handshake_start_time | |
logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) | |
+    def _process_write_backlog(self):
+        """Try to make progress on the write backlog.
+
+        Each backlog entry is a mutable [data, offset] pair. Non-empty
+        *data* is application data still to be encrypted from *offset* on;
+        an empty *data* with a true offset requests the SSL handshake, and
+        one with a false offset requests the SSL shutdown. The resulting
+        record level data is sent through _SelectorSocketTransport.write().
+        Errors go to _on_handshake_complete() while handshaking and to
+        _fatal_error() otherwise.
+        """
+        try:
+            for _ in range(len(self._write_backlog)):
+                data, offset = self._write_backlog[0]
+                if data:
+                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
+                elif offset:
+                    ssldata, offset = self._sslpipe.do_handshake(self._on_handshake_complete), 1
+                else:
+                    ssldata, offset = self._sslpipe.shutdown(self._ssl_eof_received), 1
+                # Temporarily set _closing to False to prevent
+                # _SelectorSocketTransport.write() from raising an error.
+                # Restore it in a finally clause so that, if write() fails,
+                # the error handling below still sees the real closing state.
+                saved, self._closing = self._closing, False
+                try:
+                    for chunk in ssldata:
+                        super().write(chunk)
+                finally:
+                    self._closing = saved
+                if offset < len(data):
+                    self._write_backlog[0][1] = offset
+                    # A short write means that a write is blocked on a read.
+                    # pause_reading() or _read_ready() may have removed the
+                    # reader; re-register it (leaving the protocol paused) so
+                    # the SSL data the pipe is waiting for can arrive.
+                    assert self._sslpipe.need_ssldata
+                    if self._paused:
+                        self._loop.add_reader(self._sock_fd, self._read_ready)
+                    break
+                # An entire chunk from the backlog was processed. We can
+                # delete it and reduce the outstanding buffer size.
+                del self._write_backlog[0]
+                self._write_buffer_size -= offset
+        except BaseException as exc:
+            if self._in_handshake:
+                self._on_handshake_complete(exc)
+            else:
+                self._fatal_error(exc, 'Fatal error on SSL transport')
+ | |
def pause_reading(self): | |
- # XXX This is a bit icky, given the comment at the top of | |
- # _read_ready(). Is it possible to evoke a deadlock? I don't | |
- # know, although it doesn't look like it; write() will still | |
- # accept more data for the buffer and eventually the app will | |
- # call resume_reading() again, and things will flow again. | |
- | |
if self._closing: | |
raise RuntimeError('Cannot pause_reading() when closing') | |
if self._paused: | |
raise RuntimeError('Already paused') | |
self._paused = True | |
- self._loop.remove_reader(self._sock_fd) | |
+ if not self._sslpipe.need_ssldata: | |
+ self._loop.remove_reader(self._sock_fd) | |
if self._loop.get_debug(): | |
logger.debug("%r pauses reading", self) | |
- def resume_reading(self): | |
- if not self._paused: | |
- raise RuntimeError('Not paused') | |
- self._paused = False | |
- if self._closing: | |
- return | |
- self._loop.add_reader(self._sock_fd, self._read_ready) | |
- if self._loop.get_debug(): | |
- logger.debug("%r resumes reading", self) | |
+ def _ssl_eof_received(self): | |
+ # An SSL EOF is received | |
+ # XXX we should be able to keep the connection alive (but cleartext) | |
+ self._eof_received() | |
+ | |
+ def _eof_received(self): | |
+ # A lower-level EOF is received | |
+ try: | |
+ if not self._in_handshake: | |
+ if self._loop.get_debug(): | |
+ logger.debug("%r received EOF", self) | |
+ keep_open = self._protocol.eof_received() | |
+ if keep_open: | |
+ logger.warning('returning true from eof_received() ' | |
+ 'has no effect when using ssl') | |
+ finally: | |
+ self.close() | |
+ | |
+ def _call_connection_lost(self, exc): | |
+ try: | |
+ if not self._in_handshake: | |
+ self._protocol.connection_lost(exc) | |
+ finally: | |
+ self._finalize_connection() | |
def _read_ready(self): | |
- if self._write_wants_read: | |
- self._write_wants_read = False | |
- self._write_ready() | |
- | |
- if self._buffer: | |
- self._loop.add_writer(self._sock_fd, self._write_ready) | |
- | |
try: | |
data = self._sock.recv(self.max_size) | |
- except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): | |
+ except (BlockingIOError, InterruptedError): | |
pass | |
- except ssl.SSLWantWriteError: | |
- self._read_wants_write = True | |
- self._loop.remove_reader(self._sock_fd) | |
- self._loop.add_writer(self._sock_fd, self._write_ready) | |
except Exception as exc: | |
self._fatal_error(exc, 'Fatal read error on SSL transport') | |
else: | |
if data: | |
- self._protocol.data_received(data) | |
+                try:
+                    ssldata, appdata = self._sslpipe.feed_ssldata(data)
+                except ssl.SSLError as e:
+                    # logging expects %-style placeholders, not str.format().
+                    logger.warning('SSL error %s (reason %s)', e.errno, e.reason)
+                    self.abort()
+                    return
+                for chunk in ssldata:
+                    super().write(chunk)
+                if appdata and self._paused:
+                    self._loop.remove_reader(self._sock_fd)
+                for chunk in appdata:
+                    if chunk:
+                        self._protocol.data_received(chunk)
+                    else:
+                        # An empty chunk is the peer's close_notify. It must
+                        # be acknowledged via SSLPipe.shutdown(), which calls
+                        # _ssl_eof_received() once the shutdown completes.
+                        # A [b'', False] entry is accounted as one byte.
+                        self._write_backlog.append([b'', False])
+                        self._write_buffer_size += 1
+                        self._process_write_backlog()
+                        break
else: | |
- try: | |
- if self._loop.get_debug(): | |
- logger.debug("%r received EOF", self) | |
- keep_open = self._protocol.eof_received() | |
- if keep_open: | |
- logger.warning('returning true from eof_received() ' | |
- 'has no effect when using ssl') | |
- finally: | |
- self.close() | |
- | |
- def _write_ready(self): | |
- if self._read_wants_write: | |
- self._read_wants_write = False | |
- self._read_ready() | |
- | |
- if not (self._paused or self._closing): | |
- self._loop.add_reader(self._sock_fd, self._read_ready) | |
- | |
- if self._buffer: | |
- try: | |
- n = self._sock.send(self._buffer) | |
- except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): | |
- n = 0 | |
- except ssl.SSLWantReadError: | |
- n = 0 | |
- self._loop.remove_writer(self._sock_fd) | |
- self._write_wants_read = True | |
- except Exception as exc: | |
- self._loop.remove_writer(self._sock_fd) | |
- self._buffer.clear() | |
- self._fatal_error(exc, 'Fatal write error on SSL transport') | |
- return | |
- | |
- if n: | |
- del self._buffer[:n] | |
- | |
- self._maybe_resume_protocol() # May append to buffer. | |
- | |
- if not self._buffer: | |
- self._loop.remove_writer(self._sock_fd) | |
- if self._closing: | |
- self._call_connection_lost(None) | |
+ # End of underlying stream | |
+ self._eof_received() | |
def write(self, data): | |
+ # Write *data* to the transport. | |
if not isinstance(data, (bytes, bytearray, memoryview)): | |
- raise TypeError('data argument must be byte-ish (%r)', | |
- type(data)) | |
+ raise TypeError("data: expecting a bytes-like instance, got {!r}" | |
+ .format(type(data).__name__)) | |
if not data: | |
return | |
+ self._write_backlog.append([data, 0]) | |
+ self._write_buffer_size += len(data) | |
+ self._maybe_pause_protocol() | |
+ self._process_write_backlog() | |
- if self._conn_lost: | |
- if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: | |
- logger.warning('socket.send() raised exception.') | |
- self._conn_lost += 1 | |
- return | |
- | |
- if not self._buffer: | |
- self._loop.add_writer(self._sock_fd, self._write_ready) | |
- | |
- # Add it to the buffer. | |
- self._buffer.extend(data) | |
- self._maybe_pause_protocol() | |
+ def get_write_buffer_size(self): | |
+ return self._write_buffer_size | |
def can_write_eof(self): | |
return False | |
+ def write_eof(self): | |
+ _SelectorTransport.write_eof(self) | |
+ | |
class _SelectorDatagramTransport(_SelectorTransport): | |
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py | |
--- a/Lib/test/test_asyncio/test_selector_events.py | |
+++ b/Lib/test/test_asyncio/test_selector_events.py | |
@@ -1096,17 +1096,16 @@ class SelectorSslTransportTests(test_uti | |
self.loop = self.new_test_loop() | |
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) | |
self.sock = mock.Mock(socket.socket) | |
- self.sock.fileno.return_value = 7 | |
- self.sslsock = mock.Mock() | |
- self.sslsock.fileno.return_value = 1 | |
+ self.sock.fileno.return_value = 1 | |
+ self.sslobj = mock.Mock() | |
self.sslcontext = mock.Mock() | |
- self.sslcontext.wrap_socket.return_value = self.sslsock | |
+ self.sslcontext.wrap_bio.return_value = self.sslobj | |
def _make_one(self, create_waiter=None): | |
transport = _SelectorSslTransport( | |
self.loop, self.sock, self.protocol, self.sslcontext) | |
self.sock.reset_mock() | |
- self.sslsock.reset_mock() | |
+ self.sslobj.reset_mock() | |
self.sslcontext.reset_mock() | |
self.loop.reset_counters() | |
return transport | |
@@ -1116,45 +1115,40 @@ class SelectorSslTransportTests(test_uti | |
tr = _SelectorSslTransport( | |
self.loop, self.sock, self.protocol, self.sslcontext, | |
waiter=waiter) | |
- self.assertTrue(self.sslsock.do_handshake.called) | |
+ self.assertTrue(self.sslobj.do_handshake.called) | |
self.loop.assert_reader(1, tr._read_ready) | |
test_utils.run_briefly(self.loop) | |
self.assertIsNone(waiter.result()) | |
def test_on_handshake_reader_retry(self): | |
self.loop.set_debug(False) | |
- self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError | |
+ self.sslobj.do_handshake.side_effect = ssl.SSLWantReadError( | |
+ ssl.SSL_ERROR_WANT_READ, '') | |
transport = _SelectorSslTransport( | |
self.loop, self.sock, self.protocol, self.sslcontext) | |
- self.loop.assert_reader(1, transport._on_handshake, None) | |
- | |
- def test_on_handshake_writer_retry(self): | |
- self.loop.set_debug(False) | |
- self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError | |
- transport = _SelectorSslTransport( | |
- self.loop, self.sock, self.protocol, self.sslcontext) | |
- self.loop.assert_writer(1, transport._on_handshake, None) | |
+ self.loop.assert_reader(1, transport._read_ready) | |
def test_on_handshake_exc(self): | |
exc = ValueError() | |
- self.sslsock.do_handshake.side_effect = exc | |
+ self.sslobj.do_handshake.side_effect = exc | |
with test_utils.disable_logger(): | |
waiter = asyncio.Future(loop=self.loop) | |
transport = _SelectorSslTransport( | |
self.loop, self.sock, self.protocol, self.sslcontext, waiter) | |
self.assertTrue(waiter.done()) | |
self.assertIs(exc, waiter.exception()) | |
- self.assertTrue(self.sslsock.close.called) | |
+ self.assertTrue(self.sock.close.called) | |
def test_on_handshake_base_exc(self): | |
transport = _SelectorSslTransport( | |
self.loop, self.sock, self.protocol, self.sslcontext) | |
- transport._waiter = asyncio.Future(loop=self.loop) | |
exc = BaseException() | |
- self.sslsock.do_handshake.side_effect = exc | |
+ self.sslobj.do_handshake.side_effect = exc | |
with test_utils.disable_logger(): | |
- self.assertRaises(BaseException, transport._on_handshake, 0) | |
- self.assertTrue(self.sslsock.close.called) | |
+ self.assertRaises(BaseException, transport.__init__, | |
+ self.loop, self.sock, self.protocol, self.sslcontext, | |
+ waiter=asyncio.Future(loop=self.loop)) | |
+ self.assertTrue(self.sock.close.called) | |
self.assertTrue(transport._waiter.done()) | |
self.assertIs(exc, transport._waiter.exception()) | |
@@ -1171,77 +1165,20 @@ class SelectorSslTransportTests(test_uti | |
with self.assertRaises(RuntimeError): | |
tr.resume_reading() | |
- def test_write(self): | |
- transport = self._make_one() | |
- transport.write(b'data') | |
- self.assertEqual(list_to_buffer([b'data']), transport._buffer) | |
- | |
- def test_write_bytearray(self): | |
- transport = self._make_one() | |
- data = bytearray(b'data') | |
- transport.write(data) | |
- self.assertEqual(list_to_buffer([b'data']), transport._buffer) | |
- self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. | |
- self.assertIsNot(data, transport._buffer) # Hasn't been incorporated. | |
- | |
- def test_write_memoryview(self): | |
- transport = self._make_one() | |
- data = memoryview(b'data') | |
- transport.write(data) | |
- self.assertEqual(list_to_buffer([b'data']), transport._buffer) | |
- | |
- def test_write_no_data(self): | |
- transport = self._make_one() | |
- transport._buffer.extend(b'data') | |
- transport.write(b'') | |
- self.assertEqual(list_to_buffer([b'data']), transport._buffer) | |
- | |
def test_write_str(self): | |
transport = self._make_one() | |
self.assertRaises(TypeError, transport.write, 'str') | |
- def test_write_closing(self): | |
+ def test_read_ready_recv(self): | |
transport = self._make_one() | |
- transport.close() | |
- self.assertEqual(transport._conn_lost, 1) | |
- transport.write(b'data') | |
- self.assertEqual(transport._conn_lost, 2) | |
- | |
- @mock.patch('asyncio.selector_events.logger') | |
- def test_write_exception(self, m_log): | |
- transport = self._make_one() | |
- transport._conn_lost = 1 | |
- transport.write(b'data') | |
- self.assertEqual(transport._buffer, list_to_buffer()) | |
- transport.write(b'data') | |
- transport.write(b'data') | |
- transport.write(b'data') | |
- transport.write(b'data') | |
- m_log.warning.assert_called_with('socket.send() raised exception.') | |
- | |
- def test_read_ready_recv(self): | |
- self.sslsock.recv.return_value = b'data' | |
- transport = self._make_one() | |
+ transport._sslpipe.feed_ssldata = mock.Mock( | |
+ return_value=([], [b'data'])) | |
transport._read_ready() | |
- self.assertTrue(self.sslsock.recv.called) | |
+ self.assertTrue(self.sock.recv.called) | |
self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) | |
- def test_read_ready_write_wants_read(self): | |
- self.loop.add_writer = mock.Mock() | |
- self.sslsock.recv.side_effect = BlockingIOError | |
- transport = self._make_one() | |
- transport._write_wants_read = True | |
- transport._write_ready = mock.Mock() | |
- transport._buffer.extend(b'data') | |
- transport._read_ready() | |
- | |
- self.assertFalse(transport._write_wants_read) | |
- transport._write_ready.assert_called_with() | |
- self.loop.add_writer.assert_called_with( | |
- transport._sock_fd, transport._write_ready) | |
- | |
def test_read_ready_recv_eof(self): | |
- self.sslsock.recv.return_value = b'' | |
+ self.sock.recv.return_value = b'' | |
transport = self._make_one() | |
transport.close = mock.Mock() | |
transport._read_ready() | |
@@ -1249,43 +1186,15 @@ class SelectorSslTransportTests(test_uti | |
self.protocol.eof_received.assert_called_with() | |
def test_read_ready_recv_conn_reset(self): | |
- err = self.sslsock.recv.side_effect = ConnectionResetError() | |
+ err = self.sock.recv.side_effect = ConnectionResetError() | |
transport = self._make_one() | |
transport._force_close = mock.Mock() | |
with test_utils.disable_logger(): | |
transport._read_ready() | |
transport._force_close.assert_called_with(err) | |
- def test_read_ready_recv_retry(self): | |
- self.sslsock.recv.side_effect = ssl.SSLWantReadError | |
- transport = self._make_one() | |
- transport._read_ready() | |
- self.assertTrue(self.sslsock.recv.called) | |
- self.assertFalse(self.protocol.data_received.called) | |
- | |
- self.sslsock.recv.side_effect = BlockingIOError | |
- transport._read_ready() | |
- self.assertFalse(self.protocol.data_received.called) | |
- | |
- self.sslsock.recv.side_effect = InterruptedError | |
- transport._read_ready() | |
- self.assertFalse(self.protocol.data_received.called) | |
- | |
- def test_read_ready_recv_write(self): | |
- self.loop.remove_reader = mock.Mock() | |
- self.loop.add_writer = mock.Mock() | |
- self.sslsock.recv.side_effect = ssl.SSLWantWriteError | |
- transport = self._make_one() | |
- transport._read_ready() | |
- self.assertFalse(self.protocol.data_received.called) | |
- self.assertTrue(transport._read_wants_write) | |
- | |
- self.loop.remove_reader.assert_called_with(transport._sock_fd) | |
- self.loop.add_writer.assert_called_with( | |
- transport._sock_fd, transport._write_ready) | |
- | |
def test_read_ready_recv_exc(self): | |
- err = self.sslsock.recv.side_effect = OSError() | |
+ err = self.sock.recv.side_effect = OSError() | |
transport = self._make_one() | |
transport._fatal_error = mock.Mock() | |
transport._read_ready() | |
@@ -1293,104 +1202,6 @@ class SelectorSslTransportTests(test_uti | |
err, | |
'Fatal read error on SSL transport') | |
- def test_write_ready_send(self): | |
- self.sslsock.send.return_value = 4 | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data']) | |
- transport._write_ready() | |
- self.assertEqual(list_to_buffer(), transport._buffer) | |
- self.assertTrue(self.sslsock.send.called) | |
- | |
- def test_write_ready_send_none(self): | |
- self.sslsock.send.return_value = 0 | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data1', b'data2']) | |
- transport._write_ready() | |
- self.assertTrue(self.sslsock.send.called) | |
- self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer) | |
- | |
- def test_write_ready_send_partial(self): | |
- self.sslsock.send.return_value = 2 | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data1', b'data2']) | |
- transport._write_ready() | |
- self.assertTrue(self.sslsock.send.called) | |
- self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer) | |
- | |
- def test_write_ready_send_closing_partial(self): | |
- self.sslsock.send.return_value = 2 | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data1', b'data2']) | |
- transport._write_ready() | |
- self.assertTrue(self.sslsock.send.called) | |
- self.assertFalse(self.sslsock.close.called) | |
- | |
- def test_write_ready_send_closing(self): | |
- self.sslsock.send.return_value = 4 | |
- transport = self._make_one() | |
- transport.close() | |
- transport._buffer = list_to_buffer([b'data']) | |
- transport._write_ready() | |
- self.assertFalse(self.loop.writers) | |
- self.protocol.connection_lost.assert_called_with(None) | |
- | |
- def test_write_ready_send_closing_empty_buffer(self): | |
- self.sslsock.send.return_value = 4 | |
- transport = self._make_one() | |
- transport.close() | |
- transport._buffer = list_to_buffer() | |
- transport._write_ready() | |
- self.assertFalse(self.loop.writers) | |
- self.protocol.connection_lost.assert_called_with(None) | |
- | |
- def test_write_ready_send_retry(self): | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data']) | |
- | |
- self.sslsock.send.side_effect = ssl.SSLWantWriteError | |
- transport._write_ready() | |
- self.assertEqual(list_to_buffer([b'data']), transport._buffer) | |
- | |
- self.sslsock.send.side_effect = BlockingIOError() | |
- transport._write_ready() | |
- self.assertEqual(list_to_buffer([b'data']), transport._buffer) | |
- | |
- def test_write_ready_send_read(self): | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data']) | |
- | |
- self.loop.remove_writer = mock.Mock() | |
- self.sslsock.send.side_effect = ssl.SSLWantReadError | |
- transport._write_ready() | |
- self.assertFalse(self.protocol.data_received.called) | |
- self.assertTrue(transport._write_wants_read) | |
- self.loop.remove_writer.assert_called_with(transport._sock_fd) | |
- | |
- def test_write_ready_send_exc(self): | |
- err = self.sslsock.send.side_effect = OSError() | |
- | |
- transport = self._make_one() | |
- transport._buffer = list_to_buffer([b'data']) | |
- transport._fatal_error = mock.Mock() | |
- transport._write_ready() | |
- transport._fatal_error.assert_called_with( | |
- err, | |
- 'Fatal write error on SSL transport') | |
- self.assertEqual(list_to_buffer(), transport._buffer) | |
- | |
- def test_write_ready_read_wants_write(self): | |
- self.loop.add_reader = mock.Mock() | |
- self.sslsock.send.side_effect = BlockingIOError | |
- transport = self._make_one() | |
- transport._read_wants_write = True | |
- transport._read_ready = mock.Mock() | |
- transport._write_ready() | |
- | |
- self.assertFalse(transport._read_wants_write) | |
- transport._read_ready.assert_called_with() | |
- self.loop.add_reader.assert_called_with( | |
- transport._sock_fd, transport._read_ready) | |
- | |
def test_write_eof(self): | |
tr = self._make_one() | |
self.assertFalse(tr.can_write_eof()) | |
@@ -1413,9 +1224,8 @@ class SelectorSslTransportTests(test_uti | |
_SelectorSslTransport( | |
self.loop, self.sock, self.protocol, self.sslcontext, | |
server_hostname='localhost') | |
- self.sslcontext.wrap_socket.assert_called_with( | |
- self.sock, do_handshake_on_connect=False, server_side=False, | |
- server_hostname='localhost') | |
+ self.sslcontext.wrap_bio.assert_called_with( | |
+ mock.ANY, mock.ANY, server_side=False, server_hostname='localhost') | |
class SelectorSslWithoutSslTransportTests(unittest.TestCase): |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment