Skip to content

Instantly share code, notes, and snippets.

@agronholm
Created August 25, 2020 20:36
Show Gist options
  • Save agronholm/f2d674a7d3b1ecfd585901dfb771f883 to your computer and use it in GitHub Desktop.
Save agronholm/f2d674a7d3b1ecfd585901dfb771f883 to your computer and use it in GitHub Desktop.
Typed dynamic attributes for AnyIO
class TypedAttribute(Generic[T_Attr]):
    """
    Key object naming one typed extra attribute of a :class:`~TypedAttributeContainer`.

    Create one by subscripting with the value type and instantiating it, for example
    ``TypedAttribute[bytes]()``. Containers recognise a key by object identity (``is``),
    so each key should be created once and shared.
    """
class TypedAttributeContainer(metaclass=ABCMeta):
    """
    Abstract base for objects that expose typed "extra" attributes.

    Concrete subclasses implement :meth:`get_extra_attribute`. The two ``@overload`` stubs
    exist only so type checkers can infer the return type of each call form.
    """

    @overload
    def get_extra_attribute(self, attribute: TypedAttribute[T_Attr]) -> T_Attr: ...

    @overload
    def get_extra_attribute(self, attribute: TypedAttribute[T_Attr],
                            default: T_Default) -> Union[T_Attr, T_Default]: ...

    @abstractmethod
    def get_extra_attribute(self, attribute, default=undefined):
        """
        Look up the value associated with ``attribute``.

        A container that wraps another container should fall back to the wrapped one
        whenever it cannot supply the value itself.

        :param attribute: the typed attribute key to resolve
        :param default: the value to hand back when nothing is found for ``attribute``
        """
class TLSAttribute:
    """Contains attribute keys provided by :class:`TLSStream`."""

    #: the selected ALPN protocol (``None`` if no protocol was negotiated)
    alpn_protocol = TypedAttribute[Union[str, None]]()
    #: the selected cipher, as a ``(cipher name, protocol version, secret bits)`` tuple
    cipher = TypedAttribute[Tuple[str, str, int]]()
    #: the peer certificate (``None`` if the peer did not provide one)
    peer_certificate = TypedAttribute[Union[Dict[str, Union[str, tuple]], None]]()
    #: the peer certificate in binary form (``None`` if the peer did not provide one)
    peer_certificate_binary = TypedAttribute[Union[bytes, None]]()
    #: ``True`` if this is the server side of the connection
    server_side = TypedAttribute[bool]()
    #: ciphers shared between both ends of the TLS connection (may be ``None``)
    shared_ciphers = TypedAttribute[Union[List[Tuple[str, str, int]], None]]()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object = TypedAttribute[ssl.SSLObject]()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is
    #: being closed
    standard_compatible = TypedAttribute[bool]()
    #: the negotiated TLS protocol version string (e.g. ``TLSv1.3``)
    tls_version = TypedAttribute[str]()
class TLSStream(ByteStream):
    def get_extra_attribute(self, attribute: TypedAttribute[T_Attr],
                            default: T_Default = None) -> Union[T_Attr, T_Default]:
        """
        Return the value of an extra attribute, answering TLS attributes locally.

        Keys defined on :class:`TLSAttribute` are answered from ``self._ssl_object``.
        Any other key is delegated to ``self.transport_stream``.

        :param attribute: the attribute key to look up (matched by identity)
        :param default: forwarded to the transport stream's lookup when ``attribute`` is not
            answered here
        :return: the attribute value, or whatever the transport stream returns for it
        """
        # NOTE(review): the base class declares ``default=undefined`` while this override
        # uses ``None``, so "no default given" and "default is None" are indistinguishable
        # here. Kept as-is because callers may rely on the ``None`` default; confirm intent.
        if attribute is TLSAttribute.alpn_protocol:
            return self._ssl_object.selected_alpn_protocol()
        elif attribute is TLSAttribute.cipher:
            # Must return here; otherwise the cipher would be computed, discarded, and the
            # lookup would wrongly fall through to the transport stream.
            return self._ssl_object.cipher()
        elif attribute is TLSAttribute.peer_certificate:
            return self._ssl_object.getpeercert(False)
        elif attribute is TLSAttribute.peer_certificate_binary:
            return self._ssl_object.getpeercert(True)
        elif attribute is TLSAttribute.server_side:
            return self._ssl_object.server_side
        elif attribute is TLSAttribute.shared_ciphers:
            return self._ssl_object.shared_ciphers()
        elif attribute is TLSAttribute.ssl_object:
            return self._ssl_object
        elif attribute is TLSAttribute.tls_version:
            return self._ssl_object.version()

        # NOTE(review): TLSAttribute.standard_compatible is declared but not answered above,
        # so it falls through to the transport stream. TODO: confirm where it should come from.
        return self.transport_stream.get_extra_attribute(attribute, default)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment