-
-
Save lloesche/6c03ae3ffb5115852fa53dca7c803f05 to your computer and use it in GitHub Desktop.
#!/bin/env python3 | |
import signal | |
import socket | |
import multiprocessing | |
import asyncio | |
import logging | |
import argparse | |
from threading import Event | |
from asyncio import StreamReader, StreamWriter | |
# Every record carries the PID so output from different worker processes can be told apart.
log_format = "%(asctime)s|%(levelname)5s|%(process)d|%(threadName)10s %(message)s"
logging.basicConfig(level=logging.WARNING, format=log_format)
log = logging.getLogger("multicoreserver")
log.setLevel(logging.DEBUG)

# Canned reply sent for every GET request. HTTP/1.1 requires CRLF line endings
# and an empty line between the header section and the body; Content-Length
# lets clients know the body is complete without waiting for the socket close.
HTTP_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Content-Length: 14\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"Hello, World!\n"
)
async def handle_client(reader: StreamReader, writer: StreamWriter) -> None:
    """Answer one client connection with HTTP_RESPONSE, then close it.

    Only requests whose first bytes are ``GET`` get a reply; anything else is
    closed without a response. The writer is always closed, even if the peer
    resets the connection mid-exchange.
    """
    try:
        # A single read is enough here: only the method token is inspected.
        request: bytes = await reader.read(4096)
        if request.startswith(b"GET"):
            writer.write(HTTP_RESPONSE)
            await writer.drain()
    except ConnectionError as e:
        log.debug(f"Client connection error: {e}")
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            # The peer already went away; there is nothing left to clean up.
            pass
async def worker_task(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
    """Serve connections on the shared listening socket until told to stop.

    Polls ``command_queue`` about once per second and returns after it reads
    the string ``"shutdown"`` or the queue raises EOFError. If the asyncio
    server cannot be started on ``sock`` the error is logged and the function
    returns immediately.
    """
    from queue import Empty  # multiprocessing.Queue.get_nowait() raises queue.Empty

    try:
        server: asyncio.AbstractServer = await asyncio.start_server(handle_client, sock=sock)
    except Exception as e:
        log.error(f"Error while starting server: {e}")
        return
    try:
        while True:
            try:
                # Never block here: several workers share this queue, so an
                # empty()/get() pair can lose the race to another worker and
                # then block the whole event loop inside get().
                command = command_queue.get_nowait()
            except Empty:
                command = None
            except EOFError:
                break
            except Exception as e:
                log.error(f"Error while getting command from queue: {e}")
                command = None
            if command == "shutdown":
                break
            # Always yield to the event loop between polls, including after
            # an error, so client handling is never starved.
            await asyncio.sleep(1)
    finally:
        server.close()
        await server.wait_closed()
def worker(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
    """Worker-process entry point: run worker_task on a fresh event loop.

    SIGINT is ignored in the worker; shutdown is driven by the parent through
    ``command_queue`` and, as a fallback, SIGTERM via ``Process.terminate()``.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    # With the fork start method a child inherits the parent's Python-level
    # signal handlers. Workers respawned after main() installed its SIGTERM
    # handler would otherwise swallow p.terminate() and keep running, so
    # restore the default disposition explicitly.
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(worker_task(sock, command_queue))
    finally:
        loop.close()
def main() -> None:
    """Parse CLI options, bind the shared listening socket, and supervise workers.

    The parent binds one IPv6 TCP socket (with IPV6_V6ONLY disabled) and
    passes it to ``num_workers`` child processes running ``worker``. About
    once per second, dead workers are replaced. On SIGINT/SIGTERM each worker
    is sent ``"shutdown"`` via the queue, given 5 seconds to exit, then sent
    SIGTERM, and finally SIGKILL if it still has not exited.
    """
    parser = argparse.ArgumentParser(description="Multi-core HTTP Server")
    parser.add_argument("-p", "--port", type=int, default=8080, help="Port to listen on. (default: 8080)")
    parser.add_argument("-n", "--num-workers", type=int, default=16, help="Number of worker processes. (default: 16)")
    args = parser.parse_args()
    port = args.port
    num_workers = args.num_workers
    if num_workers < 1:
        # With no workers the parent would just supervise nothing forever.
        parser.error("--num-workers must be at least 1")
    command_queue = multiprocessing.Queue()
    shutdown_event = Event()
    log.info(f"Starting server on port {port}...")
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # 0 = also accept IPv4 connections (as IPv4-mapped addresses).
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        sock.bind(("::", port))
        # Let the OS choose the backlog; one queue of 5 pending connections
        # shared by all workers is easily overflowed by bursts.
        sock.listen()
        processes = []

        def start_new_process():
            # Spawn one worker on the shared socket and track it.
            try:
                p: multiprocessing.Process = multiprocessing.Process(target=worker, args=(sock, command_queue))
                p.start()
            except Exception as e:
                log.error(f"Error while starting worker process: {e}")
            else:
                log.debug(f"Started worker process with PID {p.pid}")
                processes.append(p)

        for _ in range(num_workers):
            start_new_process()

        def signal_handler(signum, frame):
            shutdown_event.set()

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)
        while not shutdown_event.is_set():
            for p in list(processes):
                if not p.is_alive():
                    log.warning(f"Worker process with PID {p.pid} died. Restarting it...")
                    processes.remove(p)
                    start_new_process()
            shutdown_event.wait(timeout=1)
        log.info("Shutting down...")
        try:
            # One message per worker; each worker consumes exactly one.
            for _ in processes:
                command_queue.put("shutdown")
        except Exception as e:
            log.error(f"Error while sending shutdown command to worker processes: {e}")
        for p in processes:
            p.join(timeout=5)
            if p.is_alive():
                log.warning(f"Worker process with PID {p.pid} did not shut down in time. Forcefully terminating it...")
                p.terminate()
                p.join(timeout=5)
                if p.is_alive():
                    # SIGTERM can be caught (e.g. a handler inherited across
                    # fork); SIGKILL cannot.
                    log.warning(f"Worker process with PID {p.pid} ignored SIGTERM. Killing it...")
                    p.kill()
                    p.join()
    log.info("Shutdown complete")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0.
    raise SystemExit(main())
Thank you for your explanation. I looked into the signaling mechanism: Ctrl+C sends `signal.SIGINT`, which Python translates into `KeyboardInterrupt`.
But when I run the code with the `SIGINT` line removed from the worker and then stop the program, the workers raise `KeyboardInterrupt`. Doesn't that mean keyboard interrupts reach the child processes?
I see what you mean: your design notifies each task to end gracefully via the process queue, waits up to 5 seconds, and then uses SIGTERM to force the process to shut down if it hasn't ended. Your design is fine.
I've made the following simple changes, based on uvicorn's code:
import argparse
import asyncio
import logging
import multiprocessing
import os
import signal
import socket
import sys
from asyncio import StreamReader, StreamWriter
from threading import Event
from types import FrameType
# Every record carries the PID so output from different worker processes can be told apart.
log_format = "%(asctime)s|%(levelname)5s|%(process)d|%(threadName)10s %(message)s"
logging.basicConfig(level=logging.WARNING, format=log_format)
log = logging.getLogger("multicoreserver")
log.setLevel(logging.DEBUG)

# Canned reply sent for every GET request. HTTP/1.1 requires CRLF line endings
# and an empty line between the header section and the body; Content-Length
# lets clients know the body is complete without waiting for the socket close.
HTTP_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Content-Length: 14\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"Hello, World!\n"
)
# Signals that trigger a graceful shutdown:
#   SIGINT  (2)  - Ctrl+C
#   SIGTERM (15) - `kill <pid>`
#   SIGBREAK (21, Windows only) - Ctrl+Break
HANDLED_SIGNALS = (signal.SIGINT, signal.SIGTERM) + (
    (signal.SIGBREAK,) if sys.platform == "win32" else ()
)
class Server:
    """Multi-process HTTP server: one parent supervising N worker processes.

    All workers accept on a single listening socket created by the parent.
    Each worker stops when one of HANDLED_SIGNALS sets ``should_exit``. On
    shutdown the parent sends SIGTERM to every worker and reaps it.
    """

    def __init__(self):
        # Set by handle_exit(); polled by worker_task() to stop serving.
        self.should_exit = False

    async def handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
        """Answer one connection with HTTP_RESPONSE if it is a GET, then close it.

        The writer is always closed, even if the peer resets the connection.
        """
        try:
            # A single read is enough here: only the method token is inspected.
            request: bytes = await reader.read(4096)
            if request.startswith(b"GET"):
                writer.write(HTTP_RESPONSE)
                await writer.drain()
        except ConnectionError as e:
            log.debug(f"Client connection error: {e}")
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                # The peer already went away; nothing left to clean up.
                pass

    async def worker_task(self, sock: socket.socket):
        """Serve on ``sock`` until ``should_exit`` is set, then close the server."""
        try:
            server = await asyncio.start_server(self.handle_client, sock=sock)
        except Exception as e:
            log.error(f"Error while starting server: {e}")
            return
        try:
            while not self.should_exit:
                await asyncio.sleep(0.5)
        finally:
            # Runs even if the task is cancelled, so the listener never leaks.
            server.close()
            sock.close()
            await server.wait_closed()

    def handle_exit(self, sig: int, frame: "FrameType | None") -> None:
        # frame is None when registered through loop.add_signal_handler().
        self.should_exit = True

    def worker(self, sock: socket.socket) -> None:
        """Worker-process entry point: install signal handlers and serve."""
        process_id = os.getpid()
        loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            for sig in HANDLED_SIGNALS:
                loop.add_signal_handler(sig, self.handle_exit, sig, None)
        except NotImplementedError:  # pragma: no cover
            # Windows event loops do not implement add_signal_handler().
            for sig in HANDLED_SIGNALS:
                signal.signal(sig, self.handle_exit)
        log.info("Started server process [%d]", process_id)
        try:
            loop.run_until_complete(self.worker_task(sock))
        finally:
            loop.close()
        log.info("Finished server process [%d]", process_id)

    def main(self) -> None:
        """Parse CLI options, start the workers, and supervise them until a signal."""
        parser = argparse.ArgumentParser(description="Multi-core HTTP Server")
        parser.add_argument("-p", "--port", type=int, default=8080, help="Port to listen on. (default: 8080)")
        parser.add_argument("-n", "--num-workers", type=int, default=16, help="Number of worker processes. (default: 16)")
        args = parser.parse_args()
        port = args.port
        num_workers = args.num_workers
        if num_workers < 1:
            # With no workers the parent would just supervise nothing forever.
            parser.error("--num-workers must be at least 1")
        shutdown_event = Event()
        log.info(f"Starting server on port {port}...")
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # 0 = also accept IPv4 connections (as IPv4-mapped addresses).
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            sock.bind(("::", port))
            # Let the OS choose the backlog; one queue of 5 pending
            # connections shared by all workers is easily overflowed.
            sock.listen()
            processes = []

            def start_new_process():
                # Spawn one worker on the shared socket and track it.
                try:
                    p: multiprocessing.Process = multiprocessing.Process(target=self.worker, args=(sock,))
                    p.start()
                except Exception as e:
                    log.error(f"Error while starting worker process: {e}")
                else:
                    processes.append(p)

            def signal_handler(signum, frame):
                shutdown_event.set()

            for sig in HANDLED_SIGNALS:
                signal.signal(sig, signal_handler)
            for _idx in range(num_workers):
                start_new_process()
            while not shutdown_event.is_set():
                for p in list(processes):
                    if shutdown_event.is_set():
                        # Workers handle the same signals themselves and may
                        # already be exiting; don't respawn them.
                        break
                    if not p.is_alive():
                        log.warning(f"Worker process with PID {p.pid} died. Restarting it...")
                        processes.remove(p)
                        start_new_process()
                shutdown_event.wait(timeout=1)
            log.info("Shutting down...")
            # Signal every worker first so they wind down in parallel, instead
            # of paying each worker's shutdown latency one after another.
            for p in processes:
                p.terminate()
            for p in processes:
                p.join(timeout=5)
                if p.is_alive():
                    log.warning(f"Worker process with PID {p.pid} did not shut down in time. Killing it...")
                    p.kill()
                    p.join()
            log.info("Shutdown complete")
        log.info(f"Stopping parent process [{str(os.getpid())}]")
if __name__ == "__main__":
    # Script entry point: build the server and hand control to its supervisor loop.
    app = Server()
    app.main()
Ah, I see — that makes sense. I have no idea how Windows interprets these signals or what other signals are sent on a Windows terminal; I only have access to Linux and BSD systems. I would think the `try`/`except NotImplementedError` blocks aren't necessary, since you already gate adding `SIGBREAK` on win32 via the `if sys.platform` check, but I guess they don't hurt either.
I'm wondering: is `SO_REUSEADDR` even an option on Windows? I thought it was a Linux thing. Does the rest of the code work as expected there?
Just answered my own question via https://learn.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse — it looks like `SO_REUSEADDR` has been around since Windows 95.
Uhm, no. The worker functions run in separate processes, and the line you deleted is not related to keyboard interrupts. It just makes sure the forked processes ignore `SIGINT` and only shut down on `SIGTERM` (which they receive via `p.terminate()` when the parent process shuts down, if the worker didn't shut down cleanly by receiving the `shutdown` command via the queue). Keyboard interrupts on your main tty wouldn't even reach the forked child processes, so catching `KeyboardInterrupt` inside the `worker()` function makes no sense, as it will never be connected to a tty from which it could receive them.