Skip to content

Instantly share code, notes, and snippets.

@lloesche
Created August 15, 2023 19:24
Show Gist options
  • Save lloesche/6c03ae3ffb5115852fa53dca7c803f05 to your computer and use it in GitHub Desktop.
Save lloesche/6c03ae3ffb5115852fa53dca7c803f05 to your computer and use it in GitHub Desktop.
Python async server that uses multiple CPU cores while sharing a single socket.
#!/bin/env python3
import signal
import socket
import multiprocessing
import asyncio
import logging
import argparse
from threading import Event
from asyncio import StreamReader, StreamWriter
# Log line layout: timestamp | level | PID | thread name, then the message.
# The PID column tells the worker processes apart in the shared log output.
log_format = "%(asctime)s|%(levelname)5s|%(process)d|%(threadName)10s %(message)s"
logging.basicConfig(level=logging.WARNING, format=log_format)
log = logging.getLogger("multicoreserver")
log.setLevel(logging.DEBUG)

# Canned reply sent for every GET request. HTTP/1.1 requires CRLF line
# endings and an empty line between the header section and the body; without
# that separator a client parses "Hello, World!" as a malformed header line.
HTTP_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"Hello, World!\n"
)
async def handle_client(reader: StreamReader, writer: StreamWriter) -> None:
    """Serve a single client connection and always close it afterwards.

    Reads up to 4096 bytes from the client. If the data starts with ``GET``
    the canned ``HTTP_RESPONSE`` is written back; any other input gets no
    reply. In every case the connection is closed before returning.

    Args:
        reader: Stream the request bytes are read from.
        writer: Stream the response is written to; closed before returning.
    """
    try:
        request: bytes = await reader.read(4096)
        if request.startswith(b"GET"):
            writer.write(HTTP_RESPONSE)
            await writer.drain()
    except ConnectionError:
        # The peer went away mid-exchange; there is nobody left to answer.
        pass
    finally:
        # Close in ``finally`` so a failed read or drain cannot leak the socket.
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            # The transport was already torn down by the peer.
            pass
async def worker_task(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
    """Serve clients on the shared socket until a shutdown command arrives.

    Starts an asyncio server on ``sock`` with ``handle_client`` as the
    connection callback, then polls ``command_queue`` about once per second.
    The loop ends when the string ``"shutdown"`` is received or the queue
    raises ``EOFError``; the server is closed on the way out.

    Args:
        sock: Already bound and listening socket to accept connections on.
        command_queue: Queue polled for control commands.
    """
    # Imported locally because the module header does not import ``queue``.
    from queue import Empty

    try:
        server: asyncio.AbstractServer = await asyncio.start_server(handle_client, sock=sock)
    except Exception as e:
        log.error(f"Error while starting server: {e}")
        return
    try:
        while True:
            try:
                # Never block here: a blocking get() would stall the event
                # loop, and empty() followed by get() races with the other
                # workers consuming from the same queue.
                command = command_queue.get_nowait()
            except Empty:
                command = None
            except EOFError:
                break
            except Exception as e:
                log.error(f"Error while getting command from queue: {e}")
                # Fall through to the sleep below instead of retrying in a
                # tight loop while the queue keeps failing.
                command = None
            if command == "shutdown":
                break
            await asyncio.sleep(1)
    finally:
        server.close()
        await server.wait_closed()
def worker(sock: socket.socket, command_queue: multiprocessing.Queue) -> None:
    """Process entry point: run ``worker_task`` on a fresh event loop.

    SIGINT is set to be ignored before anything else happens, so an
    interrupt signal does not stop this process; ``worker_task`` ends when
    it reads a shutdown command from ``command_queue``.

    Args:
        sock: Listening socket handed through to ``worker_task``.
        command_queue: Control queue handed through to ``worker_task``.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    event_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        event_loop.run_until_complete(worker_task(sock, command_queue))
    finally:
        # Release the loop's resources whether the task finished or raised.
        event_loop.close()
def main() -> None:
    """Parse the command line, open the listening socket and supervise workers.

    One IPv6 socket is bound to the requested port and handed to
    ``--num-workers`` worker processes. The parent then checks about once per
    second for dead workers and replaces them, until SIGINT or SIGTERM sets
    the shutdown event. On shutdown one ``"shutdown"`` command per worker is
    queued; workers still alive after a shared 5 second grace period are
    terminated and reaped.
    """
    # Imported locally because the module header does not import ``time``.
    from time import monotonic

    parser = argparse.ArgumentParser(description="Multi-core HTTP Server")
    parser.add_argument("-p", "--port", type=int, default=8080, help="Port to listen on. (default: 8080)")
    parser.add_argument("-n", "--num-workers", type=int, default=16, help="Number of worker processes. (default: 16)")
    args = parser.parse_args()
    port = args.port
    num_workers = args.num_workers
    command_queue = multiprocessing.Queue()
    shutdown_event = Event()
    log.info(f"Starting server on port {port}...")
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # IPV6_V6ONLY off: the IPv6 wildcard socket also accepts IPv4 clients.
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        sock.bind(("::", port))
        sock.listen(5)
        processes = []

        def start_new_process():
            """Start one worker on the shared socket and track it on success."""
            try:
                p: multiprocessing.Process = multiprocessing.Process(target=worker, args=(sock, command_queue))
                p.start()
            except Exception as e:
                log.error(f"Error while starting worker process: {e}")
            else:
                log.debug(f"Started worker process with PID {p.pid}")
                processes.append(p)

        for _ in range(num_workers):
            start_new_process()

        def signal_handler(signum, frame):
            """Request shutdown; the supervision loop below notices the event."""
            shutdown_event.set()

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        while not shutdown_event.is_set():
            # Iterate over a copy because dead workers are removed in the loop.
            for p in list(processes):
                if not p.is_alive():
                    log.warning(f"Worker process with PID {p.pid} died. Restarting it...")
                    processes.remove(p)
                    start_new_process()
            shutdown_event.wait(timeout=1)

        log.info("Shutting down...")
        try:
            # One command per worker: each worker consumes exactly one.
            for _ in processes:
                command_queue.put("shutdown")
        except Exception as e:
            log.error(f"Error while sending shutdown command to worker processes: {e}")

        # A single deadline for all workers keeps the total wait at about
        # 5 seconds instead of 5 seconds per worker.
        deadline = monotonic() + 5
        for p in processes:
            p.join(timeout=max(0.0, deadline - monotonic()))
            if p.is_alive():
                log.warning(f"Worker process with PID {p.pid} did not shut down in time. Forcefully terminating it...")
                p.terminate()
                # Reap the terminated child so it does not linger as a zombie.
                p.join(timeout=1)
        log.info("Shutdown complete")


if __name__ == "__main__":
    main()
@vvanglro
Copy link

vvanglro commented May 7, 2024

Thank you for your explanation. I went and looked into the signaling mechanism: Ctrl+C sends signal.SIGINT, which Python translates into a KeyboardInterrupt.

But when I delete the SIGINT line in the worker, run the code, and then stop the program, it triggers KeyboardInterrupt. That means that interrupts from the keyboard reach the child process, right?

I see what you mean: your design notifies the task to end gracefully via the process queue, waits up to 5 seconds, and then uses SIGTERM to force the process to shut down if it hasn't ended. Your design is fine.

I've made the following simple changes in reference to uvicorn's code:

import argparse
import asyncio
import logging
import multiprocessing
import os
import signal
import socket
import sys
from asyncio import StreamReader, StreamWriter
from threading import Event
from types import FrameType

# Log line layout: timestamp | level | PID | thread name, then the message.
# The PID column tells the worker processes apart in the shared log output.
log_format = "%(asctime)s|%(levelname)5s|%(process)d|%(threadName)10s  %(message)s"
logging.basicConfig(level=logging.WARNING, format=log_format)

log = logging.getLogger("multicoreserver")
log.setLevel(logging.DEBUG)


# Canned reply sent for every GET request. HTTP/1.1 requires CRLF line
# endings and an empty line between the header section and the body; without
# that separator a client parses "Hello, World!" as a malformed header line.
HTTP_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"Hello, World!\n"
)

# Signals on which both the parent and the workers install exit handlers.
HANDLED_SIGNALS = (
    signal.SIGINT,  # Unix signal 2. Sent by Ctrl+C.
    signal.SIGTERM,  # Unix signal 15. Sent by `kill <pid>`.
)
if sys.platform == "win32":  # pragma: py-not-win32
    HANDLED_SIGNALS += (signal.SIGBREAK,)  # Windows signal 21. Sent by Ctrl+Break.


class Server:
    """Multi-process HTTP server whose workers share one listening socket.

    ``main`` runs in the parent: it binds the socket, starts the workers and
    restarts any that die. ``worker`` runs in each child process and serves
    connections with asyncio until a handled signal sets ``should_exit``.
    """

    def __init__(self):
        # Set to True by handle_exit(); polled by worker_task().
        self.should_exit = False

    async def handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
        """Serve a single client connection and always close it afterwards.

        Reads up to 4096 bytes. If the data starts with ``GET`` the canned
        ``HTTP_RESPONSE`` is written back; any other input gets no reply.

        Args:
            reader: Stream the request bytes are read from.
            writer: Stream the response is written to; closed before returning.
        """
        try:
            request: bytes = await reader.read(4096)

            if request.startswith(b"GET"):
                writer.write(HTTP_RESPONSE)
                await writer.drain()
        except ConnectionError:
            # The peer went away mid-exchange; there is nobody left to answer.
            pass
        finally:
            # Close in ``finally`` so a failed read or drain cannot leak the socket.
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                # The transport was already torn down by the peer.
                pass

    async def worker_task(self, sock: socket.socket):
        """Serve clients on ``sock`` until ``should_exit`` becomes true.

        Args:
            sock: Already bound and listening socket to accept connections on.
        """
        try:
            server = await asyncio.start_server(self.handle_client, sock=sock)
        except Exception as e:
            log.error(f"Error while starting server: {e}")
            return
        try:
            # The flag is flipped by handle_exit(); check it twice per second.
            while not self.should_exit:
                await asyncio.sleep(0.5)
        finally:
            # Clean up even if the wait loop is cancelled or raises.
            server.close()
            sock.close()
            await server.wait_closed()

    def handle_exit(self, sig: int, frame: FrameType) -> None:
        """Signal callback: ask ``worker_task`` to stop serving.

        Args:
            sig: Number of the received signal (unused).
            frame: Interrupted stack frame (unused); ``worker`` passes
                ``None`` when registering through the event loop.
        """
        self.should_exit = True

    def worker(self, sock: socket.socket) -> None:
        """Process entry point: install exit handlers and run ``worker_task``.

        Args:
            sock: Listening socket handed through to ``worker_task``.
        """
        process_id = os.getpid()
        loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            for sig in HANDLED_SIGNALS:
                loop.add_signal_handler(sig, self.handle_exit, sig, None)
        except NotImplementedError:  # pragma: no cover
            # Windows
            for sig in HANDLED_SIGNALS:
                signal.signal(sig, self.handle_exit)
        log.info("Started server process [%d]", process_id)
        try:
            loop.run_until_complete(self.worker_task(sock))
        finally:
            loop.close()
        log.info("Finished server process [%d]", process_id)

    def main(self) -> None:
        """Parse the command line, open the socket and supervise the workers.

        Dead workers are replaced about once per second until one of
        ``HANDLED_SIGNALS`` arrives. On shutdown every worker is sent a
        terminate request; workers still alive 5 seconds later are killed.
        """
        parser = argparse.ArgumentParser(description="Multi-core HTTP Server")
        parser.add_argument("-p", "--port", type=int, default=8080, help="Port to listen on. (default: 8080)")
        parser.add_argument("-n", "--num-workers", type=int, default=16, help="Number of worker processes. (default: 16)")
        args = parser.parse_args()

        port = args.port
        num_workers = args.num_workers
        shutdown_event = Event()

        log.info(f"Starting server on port {port}...")

        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # IPV6_V6ONLY off: the IPv6 wildcard socket also accepts IPv4 clients.
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            sock.bind(("::", port))
            sock.listen(5)
            processes = []

            def start_new_process():
                """Start one worker on the shared socket and track it on success."""
                try:
                    p: multiprocessing.Process = multiprocessing.Process(target=self.worker, args=(sock,))
                    p.start()
                except Exception as e:
                    log.error(f"Error while starting worker process: {e}")
                else:
                    processes.append(p)

            def signal_handler(signum, frame):
                """Request shutdown; the supervision loop below notices the event."""
                shutdown_event.set()

            for sig in HANDLED_SIGNALS:
                signal.signal(sig, signal_handler)

            for _idx in range(num_workers):
                start_new_process()

            while not shutdown_event.is_set():
                # Iterate over a copy because dead workers are removed in the loop.
                for p in list(processes):
                    if not p.is_alive():
                        log.warning(f"Worker process with PID {p.pid} died. Restarting it...")
                        processes.remove(p)
                        start_new_process()
                shutdown_event.wait(timeout=1)

            log.info("Shutting down...")

            # Signal every worker first so they wind down in parallel rather
            # than one after another.
            for p in processes:
                p.terminate()
            for p in processes:
                # Bounded wait: a worker stuck on an open connection must not
                # hang the parent forever.
                p.join(timeout=5)
                if p.is_alive():
                    log.warning(f"Worker process with PID {p.pid} did not shut down in time. Killing it...")
                    p.kill()
                    p.join()
            log.info("Shutdown complete")
            log.info(f"Stopping parent process [{os.getpid()}]")

# Script entry point: build a Server and run its supervising main() only when
# this file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    Server().main()

@lloesche
Copy link
Author

lloesche commented May 8, 2024

Ah I see, that makes sense. I have no idea how Windows interprets these signals or what other signals are being sent on a Windows terminal. I've only got access to Linux and BSD systems. I would think that those try/except NotImplementedError blocks aren't necessary as you already gate the adding of SIGBREAK on win32 via the if sys.platform check, but I guess they don't hurt either.

I'm wondering, is SO_REUSEADDR even an option on Windows? I thought it was a Linux thing. Like does the rest of the code work as expected?

@lloesche
Copy link
Author

lloesche commented May 8, 2024

Just answered my own question via https://learn.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
Looks like SO_REUSEADDR has been around since Windows 95.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment