Skip to content

Instantly share code, notes, and snippets.

@lloesche
Created August 15, 2023 19:24
Show Gist options
  • Save lloesche/6c03ae3ffb5115852fa53dca7c803f05 to your computer and use it in GitHub Desktop.
Save lloesche/6c03ae3ffb5115852fa53dca7c803f05 to your computer and use it in GitHub Desktop.
Python async server that uses multiple CPU cores while sharing a single socket.
#!/bin/env python3
import signal
import socket
import multiprocessing
import asyncio
import logging
import argparse
from threading import Event
from asyncio import StreamReader, StreamWriter
# Record layout: timestamp|level|pid|thread name, followed by the message.
log_format = "%(asctime)s|%(levelname)5s|%(process)d|%(threadName)10s %(message)s"

# Loggers that inherit the root level emit WARNING and above ...
logging.basicConfig(format=log_format, level=logging.WARNING)

# ... while the server's own logger is explicitly opened up to DEBUG.
log = logging.getLogger("multicoreserver")
log.setLevel(logging.DEBUG)
# Canned reply sent for every GET request.
# HTTP/1.1 requires CRLF line endings and an empty line between the header
# section and the body. Without that separator "Hello, World!" is parsed as
# a malformed header line and the client never receives a body.
HTTP_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"Hello, World!\n"
)
async def handle_client(reader: StreamReader, writer: StreamWriter) -> None:
    """Serve one client connection, answering a GET with ``HTTP_RESPONSE``.

    Reads up to 4096 bytes from the client. If that data starts with ``GET``
    the canned response is written and flushed; any other input gets no reply.
    The connection is closed in every case, including when reading or
    writing fails part-way.

    Args:
        reader: Stream carrying the client's request bytes.
        writer: Stream used for the reply and for closing the connection.
    """
    try:
        request: bytes = await reader.read(4096)
        if request.startswith(b"GET"):
            writer.write(HTTP_RESPONSE)
            await writer.drain()
    finally:
        # Close even when read()/drain() raised (e.g. the peer reset the
        # connection) so the writer is never left open.
        writer.close()
        await writer.wait_closed()
async def worker_task(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
    """Serve connections on ``sock`` until a shutdown command arrives.

    Starts an asyncio server on the given socket with ``handle_client`` as
    the connection callback, then polls ``command_queue`` once a second. The
    loop ends when the string ``"shutdown"`` is received or the queue raises
    ``EOFError``; the server is closed on the way out in every case.

    Args:
        sock: Socket handed to ``asyncio.start_server``.
        command_queue: Queue polled for commands. It is shared with other
            workers, so it is read without blocking.
    """
    from queue import Empty  # raised by get_nowait() when nothing is queued

    try:
        server: asyncio.AbstractServer = await asyncio.start_server(handle_client, sock=sock)
    except Exception as e:
        log.error(f"Error while starting server: {e}")
        return

    try:
        while True:
            try:
                # Non-blocking read: an empty()-then-get() sequence can lose
                # the race against another worker and leave get() blocking
                # this process's whole event loop.
                command = command_queue.get_nowait()
            except Empty:
                command = None
            except EOFError:
                break
            except Exception as e:
                log.error(f"Error while getting command from queue: {e}")
                # Fall through to the sleep below so a persistent error does
                # not turn into a busy loop that starves the event loop.
                command = None
            if command == "shutdown":
                break
            await asyncio.sleep(1)
    finally:
        server.close()
        await server.wait_closed()
def worker(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
    """Process entry point: run ``worker_task`` on a private event loop.

    SIGINT is set to be ignored in this process before the loop is created.
    The loop is closed once ``worker_task`` has returned or raised.

    Args:
        sock: Socket passed through to ``worker_task``.
        command_queue: Command queue passed through to ``worker_task``.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    serve_forever = worker_task(sock, command_queue)
    try:
        event_loop.run_until_complete(serve_forever)
    finally:
        event_loop.close()
def main() -> None:
    """Parse the command line, start the worker processes and supervise them.

    One IPv6 listening socket (``IPV6_V6ONLY`` cleared) is bound on the
    requested port and passed to every worker process. The parent then checks
    once a second for dead workers and replaces them until SIGINT or SIGTERM
    arrives. On shutdown each worker is sent a ``"shutdown"`` command through
    the shared queue and gets five seconds to exit before it is terminated
    and reaped.
    """
    parser = argparse.ArgumentParser(description="Multi-core HTTP Server")
    parser.add_argument("-p", "--port", type=int, default=8080, help="Port to listen on. (default: 8080)")
    parser.add_argument("-n", "--num-workers", type=int, default=16, help="Number of worker processes. (default: 16)")
    args = parser.parse_args()
    port = args.port
    num_workers = args.num_workers

    command_queue = multiprocessing.Queue()
    shutdown_event = Event()

    log.info(f"Starting server on port {port}...")
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        sock.bind(("::", port))
        sock.listen(5)
        processes = []

        def start_new_process():
            """Start one worker and track it; a failed start is only logged."""
            try:
                p: multiprocessing.Process = multiprocessing.Process(target=worker, args=(sock, command_queue))
                p.start()
            except Exception as e:
                log.error(f"Error while starting worker process: {e}")
            else:
                log.debug(f"Started worker process with PID {p.pid}")
                processes.append(p)

        for _ in range(num_workers):
            start_new_process()

        def signal_handler(signum, frame):
            """Ask the supervision loop below to stop."""
            shutdown_event.set()

        # NOTE(review): installed only after the workers have been started,
        # presumably so forked children do not inherit these handlers.
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        while not shutdown_event.is_set():
            # Iterate over a copy: the list is modified while replacing
            # dead workers.
            for p in list(processes):
                if not p.is_alive():
                    log.warning(f"Worker process with PID {p.pid} died. Restarting it...")
                    processes.remove(p)
                    start_new_process()
            shutdown_event.wait(timeout=1)

        log.info("Shutting down...")
        try:
            # One command per worker; each worker consumes a single item.
            for _ in processes:
                command_queue.put("shutdown")
        except Exception as e:
            log.error(f"Error while sending shutdown command to worker processes: {e}")

        for p in processes:
            p.join(timeout=5)
            if p.is_alive():
                log.warning(f"Worker process with PID {p.pid} did not shut down in time. Forcefully terminating it...")
                p.terminate()
                # Reap the terminated worker so it is really gone before
                # "Shutdown complete" is reported and the socket is closed.
                p.join(timeout=5)
        log.info("Shutdown complete")
# Start the server only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
@vvanglro
Copy link

vvanglro commented May 6, 2024

Maybe it should be.

def worker(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
-   signal.signal(signal.SIGINT, signal.SIG_IGN)
    loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        loop.run_until_complete(worker_task(sock, command_queue))
+   except KeyboardInterrupt:
+       log.info("Received SIGINT, shutting down...")
    finally:
        loop.close()

@lloesche
Copy link
Author

lloesche commented May 6, 2024

Uhm, no. The worker functions run in separate processes and the line you deleted is not related to keyboard interrupts. It just makes sure the forked processes ignore SIGINT and only shut down on SIGTERM (which they receive via p.terminate() when the parent process shuts down, if the worker didn't shut down cleanly by receiving the shutdown command via the queue).
Keyboard interrupts on your main tty wouldn't even reach the forked child processes. So excepting KeyboardInterrupt inside the worker() function makes no sense as it'll never be connected to a tty from which it could receive them.

@vvanglro
Copy link

vvanglro commented May 7, 2024

Thank you for your explanation. I went and looked into the signaling mechanism: Ctrl+C sends signal.SIGINT, which Python translates into a KeyboardInterrupt.

But when I run the code, and I delete the SIGINT in the worker and stop the program, it triggers KeyboardInterrupt. That means that interrupts from the keyboard reach the child process, right?

I see what you mean, your design is such that you notify the task to end gracefully via the process queue, wait up to 5 seconds, and then use SIGTERM to force the process to shut down if it hasn't ended. Your design is fine.

I've made the following simple changes in reference to uvicorn's code:

import argparse
import asyncio
import logging
import multiprocessing
import os
import signal
import socket
import sys
from asyncio import StreamReader, StreamWriter
from threading import Event
from types import FrameType

# Record layout: timestamp|level|pid|thread name, followed by the message.
log_format = "%(asctime)s|%(levelname)5s|%(process)d|%(threadName)10s  %(message)s"

# Loggers that inherit the root level emit WARNING and above ...
logging.basicConfig(format=log_format, level=logging.WARNING)

# ... while the server's own logger is explicitly opened up to DEBUG.
log = logging.getLogger("multicoreserver")
log.setLevel(logging.DEBUG)


# Canned reply sent for every GET request.
# HTTP/1.1 requires CRLF line endings and an empty line between the header
# section and the body. Without that separator "Hello, World!" is parsed as
# a malformed header line and the client never receives a body.
HTTP_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"Hello, World!\n"
)

# Signals that make the server shut down:
#   SIGINT   - Unix signal 2. Sent by Ctrl+C.
#   SIGTERM  - Unix signal 15. Sent by `kill <pid>`.
#   SIGBREAK - Windows signal 21. Sent by Ctrl+Break. Only added on win32.
HANDLED_SIGNALS = (signal.SIGINT, signal.SIGTERM) + (
    (signal.SIGBREAK,) if sys.platform == "win32" else ()  # pragma: py-not-win32
)


class Server:
    """Multi-process HTTP server whose workers all accept on one shared socket.

    The parent process (``main``) binds the listening socket, starts the
    worker processes and restarts any that die. Each worker (``worker``) runs
    its own asyncio event loop and serves clients until one of the signals in
    ``HANDLED_SIGNALS`` sets ``should_exit``.
    """

    def __init__(self):
        # Polled by worker_task(); set to True by handle_exit().
        self.should_exit = False

    async def handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
        """Serve one client connection, answering a GET with ``HTTP_RESPONSE``.

        Reads up to 4096 bytes. If they start with ``GET`` the canned response
        is written and flushed; any other input gets no reply. The connection
        is closed in every case, including when reading or writing fails.

        Args:
            reader: Stream carrying the client's request bytes.
            writer: Stream used for the reply and for closing the connection.
        """
        try:
            request: bytes = await reader.read(4096)

            if request.startswith(b"GET"):
                writer.write(HTTP_RESPONSE)
                await writer.drain()
        finally:
            # Close even when read()/drain() raised (e.g. the peer reset the
            # connection) so the writer is never left open.
            writer.close()
            await writer.wait_closed()

    async def worker_task(self, sock: socket.socket) -> None:
        """Serve connections on ``sock`` until ``should_exit`` becomes true.

        A failure to start the server is logged and ends the task. Otherwise
        the flag is polled every half second and the server and socket are
        closed on the way out, even if the wait is interrupted.

        Args:
            sock: Socket handed to ``asyncio.start_server``.
        """
        try:
            server = await asyncio.start_server(self.handle_client, sock=sock)
        except Exception as e:
            log.error(f"Error while starting server: {e}")
            return
        try:
            while not self.should_exit:
                await asyncio.sleep(0.5)
        finally:
            server.close()
            sock.close()
            await server.wait_closed()

    def handle_exit(self, sig: int, frame: FrameType) -> None:
        """Signal callback: request that ``worker_task`` stop.

        Args:
            sig: Number of the received signal (unused).
            frame: Frame supplied by ``signal.signal``; ``worker`` passes
                ``None`` when registering through the event loop (unused).
        """
        self.should_exit = True

    def worker(self, sock: socket.socket) -> None:
        """Process entry point: install signal handling and run ``worker_task``.

        ``handle_exit`` is registered for every signal in ``HANDLED_SIGNALS``
        through the event loop; where the loop raises ``NotImplementedError``
        the handlers are installed with ``signal.signal`` instead.

        Args:
            sock: Socket passed through to ``worker_task``.
        """
        process_id = os.getpid()
        loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            for sig in HANDLED_SIGNALS:
                loop.add_signal_handler(sig, self.handle_exit, sig, None)
        except NotImplementedError:  # pragma: no cover
            # Windows
            for sig in HANDLED_SIGNALS:
                signal.signal(sig, self.handle_exit)
        log.info("Started server process [%d]", process_id)
        try:
            loop.run_until_complete(self.worker_task(sock))
        finally:
            loop.close()
        log.info("Finished server process [%d]", process_id)

    def main(self) -> None:
        """Parse the command line, start the workers and supervise them.

        One IPv6 listening socket (``IPV6_V6ONLY`` cleared) is bound on the
        requested port and passed to every worker process. Dead workers are
        replaced once a second until one of ``HANDLED_SIGNALS`` arrives. On
        shutdown all workers are terminated together, then each gets five
        seconds to exit before it is killed.
        """
        parser = argparse.ArgumentParser(description="Multi-core HTTP Server")
        parser.add_argument("-p", "--port", type=int, default=8080, help="Port to listen on. (default: 8080)")
        parser.add_argument("-n", "--num-workers", type=int, default=16, help="Number of worker processes. (default: 16)")
        args = parser.parse_args()

        port = args.port
        num_workers = args.num_workers
        shutdown_event = Event()

        log.info(f"Starting server on port {port}...")

        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            sock.bind(("::", port))
            sock.listen(5)
            processes = []

            def start_new_process():
                """Start one worker and track it; a failed start is only logged."""
                try:
                    p: multiprocessing.Process = multiprocessing.Process(target=self.worker, args=(sock,))
                    p.start()
                except Exception as e:
                    log.error(f"Error while starting worker process: {e}")
                else:
                    processes.append(p)

            def signal_handler(signum, frame):
                """Ask the supervision loop below to stop."""
                shutdown_event.set()

            for sig in HANDLED_SIGNALS:
                signal.signal(sig, signal_handler)

            for _idx in range(num_workers):
                start_new_process()

            while not shutdown_event.is_set():
                # Iterate over a copy: the list is modified while replacing
                # dead workers.
                for p in list(processes):
                    if not p.is_alive():
                        log.warning(f"Worker process with PID {p.pid} died. Restarting it...")
                        processes.remove(p)
                        start_new_process()
                shutdown_event.wait(timeout=1)

            log.info("Shutting down...")

            # Signal every worker before waiting on any of them so they wind
            # down in parallel instead of one after another.
            for p in processes:
                p.terminate()
            for p in processes:
                # Bounded wait: a worker stuck in its cleanup must not hang
                # the parent forever.
                p.join(timeout=5)
                if p.is_alive():
                    log.warning(f"Worker process with PID {p.pid} did not shut down in time. Forcefully terminating it...")
                    p.kill()
                    p.join()
            log.info("Shutdown complete")
            log.info(f"Stopping parent process [{os.getpid()}]")


# Script entry point: build a Server and hand control to its main loop.
if __name__ == "__main__":
    server = Server()
    server.main()

@lloesche
Copy link
Author

lloesche commented May 8, 2024

Ah I see, that makes sense. I have no idea how Windows interprets these signals or what other signals are being sent on a Windows terminal. I've only got access to Linux and BSD systems. I would think that those try/except NotImplementedError blocks aren't necessary as you already gate the adding of SIGBREAK on win32 via the if sys.platform check, but I guess they don't hurt either.

I'm wondering, is SO_REUSEADDR even an option on Windows? I thought it was a Linux thing. Like does the rest of the code work as expected?

@lloesche
Copy link
Author

lloesche commented May 8, 2024

Just answered my own question via https://learn.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
Looks like SO_REUSEADDR has been around since Windows 95.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment