Created
May 23, 2023 18:18
-
-
Save rob-blackbourn/100dab4ebd7952c41e9bf2f078e6022b to your computer and use it in GitHub Desktop.
Using asyncio start_tls with Python 3.11
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Example Client | |
The client connects without TLS, but using the fully qualified domain name. To | |
authenticate the server, the FQDN is required. This can be specified either at | |
connection time, or with the start_tls call. | |
First the client sends a "PING" over the unencrypted stream to the server. The | |
server should respond with "PONG". | |
Next the client sends "STARTTLS" to instruct the server to upgrade the | |
connection to TLS. The client then calls the upgrade method on the writer to | |
negotiate the upgrade. | |
The client sends "PING" to the server, this time over the encrypted stream. The | |
server should respond with "PONG". | |
Finally the client sends "QUIT" to the server and closes the connection. | |
""" | |
import asyncio | |
import socket | |
import ssl | |
async def start_client(host: str | None = None, port: int = 10001) -> None:
    """Run the PING / STARTTLS / PING / QUIT exchange against the example server.

    Args:
        host: Server host name. Defaults to ``socket.getfqdn()``. A fully
            qualified name is used so that the server certificate can be
            matched against it during the TLS handshake.
        port: Server TCP port. Defaults to 10001, the port the example
            server listens on.
    """
    if host is None:
        host = socket.getfqdn()

    print("Connect to the server using the fully qualified domain name")
    reader, writer = await asyncio.open_connection(host, port)
    try:
        print(f"The writer ssl context is {writer.get_extra_info('sslcontext')}")

        print("Sending PING")
        writer.write(b'PING\n')
        await writer.drain()
        response = (await reader.readline()).decode('utf-8').rstrip()
        print(f"Received: {response}")

        print("Sending STARTTLS")
        writer.write(b'STARTTLS\n')
        # NOTE(review): the handshake starts without waiting for any reply from
        # the server, so the server may read the first handshake bytes in the
        # same chunk as the "STARTTLS" line. A protocol-level acknowledgement
        # (changed on both client and server) would remove that race.

        print("Upgrade the connection to TLS")
        ctx = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH,
            cafile='/etc/ssl/certs/ca-certificates.crt'
        )
        # The default context enables check_hostname, which needs the server
        # name to verify the certificate, so pass it explicitly here.
        await writer.start_tls(ctx, server_hostname=host)
        print(f"The writer ssl context is {writer.get_extra_info('sslcontext')}")

        print("Sending PING")
        writer.write(b'PING\n')
        await writer.drain()
        response = (await reader.readline()).decode('utf-8').rstrip()
        print(f"Received: {response}")

        print("Sending QUIT")
        writer.write(b'QUIT\n')
        await writer.drain()
    finally:
        # Release the connection even if a read, write or the handshake failed.
        print("Closing client")
        writer.close()
        await writer.wait_closed()
    print("Client disconnected")
if __name__ == '__main__':
    # Script entry point: run the client coroutine on a fresh event loop.
    asyncio.run(start_client())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Example Server | |
The server listens for client connections. | |
On receiving a connection it enters a read loop. | |
When the server receives "PING" it responds with "PONG". | |
When the server receives "QUIT" it closes the connection. | |
When the server receives "STARTTLS" it calls the upgrade method on the writer | |
to negotiate the TLS connection. | |
""" | |
import asyncio | |
from asyncio import StreamReader, StreamWriter | |
from functools import partial | |
from os.path import expanduser | |
import socket | |
import ssl | |
async def handle_client(
        ctx: ssl.SSLContext,
        reader: StreamReader,
        writer: StreamWriter
) -> None:
    """Serve one client connection until it sends "QUIT" or disconnects.

    Commands are newline-terminated lines:

    * ``PING``: reply with ``PONG``.
    * ``STARTTLS``: upgrade this connection to TLS using *ctx*.
    * ``QUIT``: stop serving and close the connection.

    Any other line is printed and otherwise ignored.

    Args:
        ctx: Server-side SSL context used for the STARTTLS upgrade.
        reader: Stream carrying data from the client.
        writer: Stream carrying data to the client.
    """
    print("Client connected")
    try:
        while True:
            line = await reader.readline()
            if not line:
                # readline() returns b'' at EOF. Without this check the loop
                # would spin forever, never yielding, once the client goes
                # away without sending QUIT.
                print("Client disconnected without QUIT")
                break
            # errors='replace' turns malformed input into an unknown command
            # instead of raising UnicodeDecodeError.
            request = line.decode('utf8', errors='replace').rstrip()
            print(f"Read '{request}'")
            if request == 'QUIT':
                break
            elif request == 'PING':
                print("Sending pong")
                writer.write(b'PONG\n')
                await writer.drain()
            elif request == 'STARTTLS':
                print("Upgrading connection to TLS")
                # NOTE(review): bytes the client sent right after "STARTTLS"
                # that are already in reader's buffer are presumably not passed
                # to the TLS layer. Confirm the client waits before starting
                # its handshake, or add an explicit acknowledgement.
                await writer.start_tls(ctx)
    finally:
        # Always release the connection, even if I/O or the handshake failed.
        print("Closing client")
        writer.close()
        await writer.wait_closed()
    print("Client closed")
async def run_server():
    """Serve clients on port 10001 of this host's FQDN until cancelled.

    A server-side TLS context is built up front from ``~/.keys/server.crt``
    and ``~/.keys/server.key``. Each connection is handled by
    ``handle_client``, which applies the context only when the client sends
    STARTTLS.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # NOTE(review): the CA bundle only matters if client certificates are
    # requested, and verify_mode is not changed here. Confirm it is needed.
    tls_context.load_verify_locations(cafile="/etc/ssl/certs/ca-certificates.crt")
    certificate_path = expanduser("~/.keys/server.crt")
    private_key_path = expanduser("~/.keys/server.key")
    tls_context.load_cert_chain(certificate_path, private_key_path)

    print("Starting server")
    server = await asyncio.start_server(
        partial(handle_client, tls_context),
        socket.getfqdn(),
        10001,
    )
    async with server:
        await server.serve_forever()
if __name__ == '__main__':
    # Script entry point: run the server coroutine until interrupted.
    asyncio.run(run_server())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment