Skip to content

Instantly share code, notes, and snippets.

@dustinlennon
Created October 5, 2024 21:02
Show Gist options
  • Save dustinlennon/795b5c283a21b47d002c88ae79e64917 to your computer and use it in GitHub Desktop.
Save dustinlennon/795b5c283a21b47d002c88ae79e64917 to your computer and use it in GitHub Desktop.
a fork and monitor pattern using asyncio and named sockets
"""A fork and monitor pattern, asyncio with named sockets
author: Dustin Lennon
email: dustin.lennon@gmail.com
Generate heartbeats in a parent process; check heartbeats in a child process. The
most recent heartbeat is saved in a file. The child exits when the parent terminates,
specifically when its parent pid changes.
This code also sets up a named socket in the filesystem, which enables shell interaction
with the monitor process. This should be useful for debugging, e.g.:
    echo "hello" | socat - UNIX-CLIENT:/tmp/monitor.socket
    echo -n "" | socat - UNIX-CLIENT:/tmp/monitor.socket
    socat - UNIX-CLIENT:/tmp/monitor.socket
asyncio.open_unix_connection can be a bit fussy with being handed a socket. In particular,
it expects an already accepted socket on which it will block indefinitely if, say, the
client connects and does nothing. So, we provide a `safe_unix_connection` async context
manager to make sure it doesn't get stuck and that the writer is closed appropriately.
Server logic is encapsulated much like the callback function in asyncio.start_server.
"""
import socket
import struct
import asyncio
import os
from pathlib import Path
import time
from contextlib import asynccontextmanager
# Path of the heartbeat file; it holds only the most recent heartbeat timestamp.
heartbeat_path = "/tmp/monitor.heartbeat"
# Path of the named (UNIX domain) socket allowing control from the shell.
socket_path = "/tmp/monitor.socket"
# Timeout, in seconds, for connecting and read/write operations on an accepted
# connection.
connect_timeout_s = 3
# Timeout, in seconds, for waiting on accept. This becomes the de facto polling
# interval for checking whether the parent process has terminated / the child has
# been orphaned.
socket_timeout_s = 1
@asynccontextmanager
async def safe_unix_connection(conn : socket.socket, timeout : float | None):
"""Call asyncio.open_unix_connection using an accepted socket.
This calls asyncio.open_unix_connection within an asyncio.timeout and yields a tuple
(StreamReader, StreamWriter). On exit, it closes the writer. This addresses the
scenario where a client connects and then does nothing; this should be preempted
to prevent hanging on a subsequent read.
Args:
conn: an accepted socket
timeout: a timeout for the asyncio.timeout call
"""
try:
async with asyncio.timeout(timeout):
reader, writer = await asyncio.open_unix_connection(sock = conn)
yield (reader, writer)
except TimeoutError:
print("monitor: timed out")
pass
finally:
writer.close()
await writer.wait_closed()
async def monitor(ppid : int):
    """Accept connections on the named socket until the parent process goes away.

    Each accepted connection is served by its own task, which wraps the socket with
    `safe_unix_connection` and hands the resulting streams to `handle_request`. The
    task owns the connection for its whole lifetime, so the writer is only closed
    once the handler has finished (or timed out).

    The wait for a new connection is bounded by `socket_timeout_s`, which makes it
    the polling interval for the parent check. The loop ends when the parent process
    id changes, as this is a reliable indicator that the parent has terminated. The
    socket file and the heartbeat file are removed on the way out, even if the
    coroutine is cancelled or fails.

    Args:
        ppid: the parent process id
    """
    async def serve(conn : socket.socket):
        # Keep the connection context open for as long as the handler runs.
        async with safe_unix_connection(conn, connect_timeout_s) as streams:
            await handle_request(*streams)

    loop = asyncio.get_running_loop()

    # Remove a stale socket file left behind by a previous run; bind would fail on it.
    Path(socket_path).unlink(missing_ok = True)
    s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    # keep a reference for active tasks
    background_tasks = set()

    try:
        # loop.sock_accept requires a non-blocking socket; awaiting it (rather than
        # calling s.accept()) lets handler tasks run while we wait for a client.
        s.setblocking(False)
        s.bind(socket_path)
        s.listen()

        # postcondition: the parent has already terminated
        while os.getppid() == ppid:
            try:
                async with asyncio.timeout(socket_timeout_s):
                    conn, _ = await loop.sock_accept(s)
            except TimeoutError:
                # No client showed up in time; go around and re-check the parent.
                continue

            task = asyncio.create_task(serve(conn))
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

        # Let in-flight handlers finish before the heartbeat file is removed; each
        # one is bounded by connect_timeout_s.
        if background_tasks:
            await asyncio.wait(background_tasks)
    finally:
        # clean up
        s.close()
        Path(socket_path).unlink(missing_ok = True)
        Path(heartbeat_path).unlink(missing_ok = True)
def main():
    """Write an initial heartbeat, fork, and run the coroutine for this process.

    The child (where fork returns 0) runs `monitor` with the pid recorded before the
    fork; the parent runs `driver`. A KeyboardInterrupt ends either one quietly.
    """
    initialize()

    parent_pid = os.getpid()
    in_child = os.fork() == 0
    entry_point = monitor(parent_pid) if in_child else driver()

    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        return
# start: application-specific logic
def initialize():
    """Prime the heartbeat file by sending one heartbeat up front."""
    send_heartbeat()
def send_heartbeat():
    """Write the most recent heartbeat timestamp to the heartbeat file.

    The timestamp (time.time() packed as a native double) is written to a temporary
    file next to the heartbeat file and then moved into place with os.replace.
    Another process reading the heartbeat file therefore sees either the previous
    or the new timestamp, never a truncated or partially written file.
    """
    # Include the pid so that two writers can never share a temporary file.
    tmp_path = f"{heartbeat_path}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as f:
        f.write(struct.pack('d', time.time()))
    # Atomic on POSIX: readers never observe the intermediate state.
    os.replace(tmp_path, heartbeat_path)
def last_heartbeat():
    """Return the timestamp stored in the heartbeat file.

    The first 8 bytes of the file are decoded as a native double.
    """
    with open(heartbeat_path, "rb") as source:
        raw = source.read(8)
    (stamp,) = struct.unpack('d', raw)
    return stamp
async def driver(heartbeat_wait_s : float = 5):
    """Application logic: loop forever and generate heartbeats.

    Args:
        heartbeat_wait_s: seconds to sleep before each heartbeat. Defaults to 5,
            the interval that was previously hard-coded.
    """
    while True:
        await asyncio.sleep(heartbeat_wait_s)
        send_heartbeat()
async def handle_request(reader : asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Application logic: handle a request.

    Reads a single message of up to 1024 bytes. A non-empty message is printed and
    acknowledged with a "thanks!" line on the writer. In every case the time since
    the last heartbeat is printed to stdout.

    Args:
        reader: a StreamReader
        writer: a StreamWriter
    """
    max_msg_size = 1024
    msg = await reader.read(max_msg_size)

    # An empty read means the client sent nothing before EOF; skip the reply.
    if msg:
        print(msg)
        writer.write(b"thanks!\n")
        await writer.drain()

    # echo the time since last heartbeat
    elapsed = time.time() - last_heartbeat()
    # Fixed-point formatting: a stale heartbeat (1000 s or more) must not be
    # rendered in scientific notation.
    print(f"last_heartbeat: {elapsed:.3f}s ago")
if __name__ == '__main__':
    # Start the demo only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment