Created
October 5, 2024 21:02
-
-
Save dustinlennon/795b5c283a21b47d002c88ae79e64917 to your computer and use it in GitHub Desktop.
a fork and monitor pattern using asyncio and named sockets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A fork and monitor pattern, asyncio with named sockets

author: Dustin Lennon
email: dustin.lennon@gmail.com

Generate heartbeats in a parent process; check heartbeats in a child process.  The
most recent heartbeat is saved in a file.  The child exits when the parent terminates,
specifically when its parent pid changes.

This code also sets up a named socket in the filespace which enables shell interaction
with the monitor process.  This should be useful for debugging, e.g.:

    echo "hello" | socat - UNIX-CLIENT:/tmp/monitor.socket
    echo -n "" | socat - UNIX-CLIENT:/tmp/monitor.socket
    socat - UNIX-CLIENT:/tmp/monitor.socket

asyncio.open_unix_connection can be a bit fussy with being handed a socket.  In
particular, it expects an already accepted socket on which it will block indefinitely
if, say, the client connects and does nothing.  So, we provide a `safe_unix_connection`
async context manager to make sure it doesn't get stuck and that the writer is closed
appropriately.

Server logic is encapsulated much like the callback function in asyncio.start_server.
"""
import socket | |
import struct | |
import asyncio | |
import os | |
from pathlib import Path | |
import time | |
from contextlib import asynccontextmanager | |
# File holding the timestamp of the most recent heartbeat.
heartbeat_path = "/tmp/monitor.heartbeat"

# Named (filesystem) socket through which the monitor can be poked from a shell.
socket_path = "/tmp/monitor.socket"

# Upper bound, in seconds, on connecting plus the read/write work on a connection.
connect_timeout_s = 3

# Timeout, in seconds, on the listening socket's accept().  Because the monitor
# re-checks its parent pid after every accept attempt, this is effectively the
# polling interval for noticing that the parent terminated (child orphaned).
socket_timeout_s = 1
@asynccontextmanager | |
async def safe_unix_connection(conn : socket.socket, timeout : float | None): | |
"""Call asyncio.open_unix_connection using an accepted socket. | |
This calls asyncio.open_unix_connection within an asyncio.timeout and yields a tuple | |
(StreamReader, StreamWriter). On exit, it closes the writer. This addresses the | |
scenario where a client connects and then does nothing; this should be preempted | |
to prevent hanging on a subsequent read. | |
Args: | |
conn: an accepted socket | |
timeout: a timeout for the asyncio.timeout call | |
""" | |
try: | |
async with asyncio.timeout(timeout): | |
reader, writer = await asyncio.open_unix_connection(sock = conn) | |
yield (reader, writer) | |
except TimeoutError: | |
print("monitor: timed out") | |
pass | |
finally: | |
writer.close() | |
await writer.wait_closed() | |
async def monitor(ppid : int):
    """Loop over accepted socket connections, create a safe unix connection, and handle request.

    The loop ends when the parent process id changes as this is a reliable indicator
    that the parent has terminated.  The named socket and the heartbeat file are
    removed on the way out, whether the loop ends normally, is cancelled, or fails.

    Each request is handled inside the `safe_unix_connection` context, so the
    connection timeout also covers the handler's reads and writes, and the writer
    is closed only after the handler has finished.  (Starting the handler as a
    background task and leaving the context immediately would close the writer
    as soon as the task was created.)

    Args:
        ppid: the parent process id
    """
    # Remove a stale socket file left by a previous run; bind() would fail on it.
    Path(socket_path).unlink(missing_ok = True)

    s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        s.settimeout(socket_timeout_s)
        s.bind(socket_path)
        s.listen()

        # postcondition: the parent has already terminated
        while os.getppid() == ppid:
            # Yield to the event loop on every pass, including idle ones, so a
            # pending cancellation is delivered instead of being starved by the
            # blocking accept() below.
            await asyncio.sleep(0)

            try:
                # Blocking accept; its timeout doubles as the polling interval
                # for the parent-pid check in the loop condition.
                conn, _ = s.accept()
            except TimeoutError:
                continue

            try:
                async with safe_unix_connection(conn, connect_timeout_s) as (reader, writer):
                    await handle_request(reader, writer)
            except Exception as exc:
                # Per-connection boundary: one misbehaving client or a failed
                # request must not bring the whole monitor down.
                print(f"monitor: request failed: {exc!r}")
    finally:
        # clean up
        s.close()
        Path(socket_path).unlink(missing_ok = True)
        Path(heartbeat_path).unlink(missing_ok = True)
def main():
    """Write an initial heartbeat, fork, and run one coroutine per process.

    The forked child runs the monitor, watching the pid of the process that
    forked it; the parent runs the driver.  A KeyboardInterrupt in either
    process ends it quietly.
    """
    initialize()

    # Capture our pid before forking: in the child this is the parent's pid.
    parent_pid = os.getpid()
    in_child = os.fork() == 0
    coro = monitor(parent_pid) if in_child else driver()

    try:
        asyncio.run(coro)
    except KeyboardInterrupt:
        pass
# start: application-specific logic | |
def initialize():
    """Seed the heartbeat file by sending one heartbeat."""
    send_heartbeat()
def send_heartbeat():
    """Write the most recent heartbeat timestamp to the heartbeat file.

    The timestamp (time.time()) is stored as a single native-format double.
    The file is replaced atomically: the value is written to a temporary file
    next to the heartbeat file and then renamed over it.  Opening the heartbeat
    file itself with "wb" would truncate it first, leaving a window in which a
    reader in another process sees an empty or partial file.
    """
    msg = struct.pack('d', time.time())

    # Include the pid so that two writers can never share a temporary file.
    tmp_path = f"{heartbeat_path}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as f:
        f.write(msg)

    # Same directory, hence same filesystem: the rename is atomic on POSIX.
    os.replace(tmp_path, heartbeat_path)
def last_heartbeat():
    """Return the heartbeat timestamp currently stored in the heartbeat file.

    The file is expected to start with one native-format double, as written
    by struct.pack('d', ...).
    """
    with open(heartbeat_path, "rb") as fh:
        raw = fh.read(struct.calcsize('d'))
    (stamp,) = struct.unpack('d', raw)
    return stamp
async def driver():
    """Application logic: sleep, send a heartbeat, repeat forever."""
    pause_s = 5
    while True:
        await asyncio.sleep(pause_s)
        send_heartbeat()
async def handle_request(reader : asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Application logic: serve one request on an open connection.

    Reads up to 1024 bytes.  A non-empty message is printed and acknowledged
    to the client; in every case the time elapsed since the last recorded
    heartbeat is then printed.

    Args:
        reader: a StreamReader
        writer: a StreamWriter
    """
    payload = await reader.read(1024)
    if payload:
        print(payload)
        writer.write(b"thanks!\n")
        await writer.drain()

    # report how long ago the last heartbeat was recorded
    since_beat = time.time() - last_heartbeat()
    print(f"last_heartbeat: {since_beat:<0.3}s ago")
if __name__ == '__main__':
    # start the demo
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment