Last active
May 29, 2023 19:16
-
-
Save bobjansen/0ad44c508baa9799937b39d3b6c6485a to your computer and use it in GitHub Desktop.
How to persist orders with minimum overhead?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Time different ways of persisting orders received over websockets.

Each simulated exchange listens on its own port (starting at 8080) and streams
orders to a client; the client persists a random sample of them through one of
several writers. Start the servers with
> python main.py server 5
and start the experiment with
> python main.py client 5
Servers can be reused across client runs.

The client runs these scenarios in order against every exchange:
- an in-memory writer whose overhead blocks (time.sleep),
- the same writer with awaited overhead (asyncio.sleep),
- a no-op writer (baseline),
- an SQLite writer ('test_db.sqlite3') called on the sync path,
- the same SQLite writer called on the async path.

Requires Python 3.11+ (asyncio.TaskGroup) plus the numpy and websockets
packages.

Example output:
Connecting to 5 exchanges with sync writer
Connecting to 8080
Connecting to 8081
Connecting to 8082
Connecting to 8083
Connecting to 8084
Took: 7.550986182s
Took: 7.557617441s
Took: 7.563337804s
Took: 7.568995488s
Took: 7.572868576s
Connecting to 5 exchanges with async writer
Connecting to 8080
Connecting to 8081
Connecting to 8082
Connecting to 8083
Connecting to 8084
Took: 3.553296876s
Took: 3.571782329s
Took: 3.574900966s
Took: 3.58377087s
Took: 3.639079402s
Connecting to 5 exchanges with NoopConsumer
Connecting to 8080
Connecting to 8081
Connecting to 8082
Connecting to 8083
Connecting to 8084
Took: 3.632937903s
Took: 3.647552666s
Took: 3.67249288s
Took: 3.673100654s
Took: 3.708134698s
Connecting to 5 exchanges with SqliteConsumer synced writes
Connecting to 8080
Connecting to 8081
Connecting to 8082
Connecting to 8083
Connecting to 8084
Took: 24.687743786s
Took: 24.687742097s
Took: 24.688043714s
Took: 24.688395203s
Took: 24.689093151s
Connecting to 5 exchanges with SqliteConsumer async writes
Connecting to 8080
Connecting to 8081
Connecting to 8082
Connecting to 8083
Connecting to 8084
Took: 25.103373169s
Took: 25.104250181s
Took: 25.113094467s
Took: 25.120974656s
Took: 25.120453852s
"""
import asyncio | |
import gc | |
import hashlib | |
import sqlite3 | |
import sys | |
import time | |
import numpy as np | |
from websockets.server import serve | |
from websockets import connect | |
# When True, the client run dumps the messages recorded by writers that keep them.
DEBUG = False

# Server
hostname = "localhost"
base_port = 8080  # exchange i (0-based) listens on base_port + i
average_wait_time = 2  # Poisson mean of the wait before each order, in ticks
num_orders = 5000  # orders sent by each exchange per feed
p = 0.25  # probability that the client persists a received order
speed = 10000  # ticks per second: a wait of w ticks sleeps w / speed seconds
seed = 42  # seeds numpy's global RNG (order waits and client coin flips)

# Client
sync_overhead = 0.001  # seconds Consumer.write blocks per persisted order
async_overhead = 0.001  # seconds Consumer.async_write awaits per persisted order

np.random.seed(seed)
# NOTE(review): not referenced anywhere else in this file; kept only in case
# outside code imports it.
hasher = hashlib.sha256()
def print_settings():
    """Print the experiment's module-level settings as a short report."""
    report_lines = (
        "Settings:",
        f"hostname: {hostname}",
        f"base_port: {base_port}",
        f"num_orders: {num_orders}",
        f"p: {p}",
        f"speed: {speed}",
        f"seed: {seed}",
    )
    print("\n".join(report_lines))
class NoopConsumer:
    """Writer that discards every message.

    This is the baseline for the experiment. It accepts the same ``write`` /
    ``async_write`` calls as the other writers but does no work, so it can be
    used on either the sync or the async client path.
    """

    def write(self, exchange_id, order_id):
        """Discard the message."""

    async def async_write(self, exchange_id, order_id):
        """Discard the message (async-path counterpart of ``write``)."""
class Consumer:
    """Record every message it is given and simulate a fixed cost per write.

    ``write`` pays the cost with ``time.sleep``, which blocks the calling
    thread. ``async_write`` pays it with ``await asyncio.sleep``.
    """

    def __init__(self, overhead):
        # Seconds of simulated work per write.
        self.overhead = overhead
        # Messages stored as "exchange_id|order_id" strings, in arrival order.
        self.messages_received = []

    def _record(self, exchange_id, order_id):
        # Shared by both write paths so the stored format stays identical.
        self.messages_received.append(f"{exchange_id}|{order_id}")

    def write(self, exchange_id, order_id):
        """Store the message, then block for ``overhead`` seconds."""
        self._record(exchange_id, order_id)
        time.sleep(self.overhead)

    async def async_write(self, exchange_id, order_id):
        """Store the message, then await ``overhead`` seconds."""
        self._record(exchange_id, order_id)
        await asyncio.sleep(self.overhead)
class SqliteConsumer:
    """Persist each order as a row of an SQLite ``orders`` table.

    Not thread-safe: a single connection and cursor are shared by every call.
    Each insert is committed in its own transaction.

    ``async_write`` runs the same blocking code as ``write``. It contains no
    ``await``, so it never yields to the event loop; it exists only so the
    class can be driven on the client's async path.
    """

    def __init__(self, path="test_db.sqlite3"):
        """Open the database at ``path`` and recreate an empty ``orders`` table.

        Any existing ``orders`` table in that database is dropped.
        """
        self.con = sqlite3.connect(path)
        self.cur = self.con.cursor()
        self.cur.execute("DROP TABLE IF EXISTS orders")
        self.cur.execute("CREATE TABLE orders(exchange_id int, order_id int)")

    def write(self, exchange_id, order_id):
        """Insert one order and commit immediately."""
        # The values come straight off a websocket message, so bind them as
        # parameters. Interpolating them into the SQL text would allow SQL
        # injection.
        self.cur.execute(
            "INSERT INTO orders (exchange_id, order_id) VALUES (?, ?)",
            (exchange_id, order_id),
        )
        self.con.commit()

    async def async_write(self, exchange_id, order_id):
        """Coroutine wrapper around ``write``; blocks the event loop while it runs."""
        self.write(exchange_id, order_id)

    def close(self):
        """Close the underlying database connection."""
        self.con.close()
class Exchange:
    """
    Minimal stand-in for an exchange's order feed.

    The exchange serves a websocket on hostname:port. Each time a client sends
    "start", it emits ``num_orders`` messages of the form
    "exchange_id:order_id". Before each message it sleeps a Poisson-distributed
    number of 1/speed-second ticks. It finishes the feed with "done".
    """

    def __init__(self, exchange_id, hostname, port, average_wait_time):
        self.exchange_id = exchange_id
        self.hostname = hostname
        self.port = port
        # Drawn once here, so every "start" replays the same wait schedule.
        self.wait_times = np.random.poisson(average_wait_time, num_orders)
        total_ticks = self.wait_times.sum()
        print(f"Creating {self.exchange_id} on {self.hostname}:{self.port}")
        print(f"Average wait time param: {average_wait_time}")
        print(f"Average wait time: {total_ticks / (speed * num_orders)}")
        print(f"Total wait time: {total_ticks / speed}")

    async def order_feed(self, websocket):
        """Websocket handler that replays the order schedule on each "start".

        When a feed finishes, it prints the feed's wall-clock duration, the
        time spent inside ``websocket.send``, and the current settings.
        """
        async for request in websocket:
            if request != "start":
                continue  # any other message is ignored
            feed_began = time.time_ns()
            ns_sending = 0
            for order_number, ticks in enumerate(self.wait_times):
                await asyncio.sleep(ticks / speed)
                before_send = time.time_ns()
                await websocket.send(f"{self.exchange_id}:{order_number}")
                ns_sending += time.time_ns() - before_send
            await websocket.send("done")
            elapsed_s = (time.time_ns() - feed_began) / 1e9
            print(f"Took {elapsed_s}s to send all orders on {self.exchange_id}")
            print(f"Total send time: {ns_sending / 1e9}s")
            print_settings()

    async def run(self):
        """Serve ``order_feed`` on hostname:port until the task is cancelled."""
        async with serve(self.order_feed, self.hostname, self.port):
            # Nothing ever resolves this future, so the server stays open.
            never_done = asyncio.Future()
            await never_done
class Client:
    """Connect to one exchange, consume its feed, and persist a random sample.

    The feed is started by sending 'start' and ends when 'done' is received.
    For each order a coin flip (probability ``p``) decides whether it is
    persisted through ``writer``. Persisting uses ``async_write`` when
    ``write_async`` is true and ``write`` otherwise.
    """

    def __init__(self, port, writer, write_async):
        self.port, self.writer, self.write_async = port, writer, write_async

    async def run(self):
        """Consume one full feed and print how long it took."""
        print(f"Connecting to {self.port}")
        url = f"ws://{hostname}:{self.port}"
        async with connect(url) as websocket:
            await websocket.send("start")
            t0 = time.time_ns()
            async for message in websocket:
                if message == "done":
                    break
                keep = np.random.uniform(0, 1) < p
                if not keep:
                    continue
                exchange_id, order_id = message.split(":")
                if self.write_async:
                    await self.writer.async_write(exchange_id, order_id)
                else:
                    self.writer.write(exchange_id, order_id)
            print(f"Took: {(time.time_ns() - t0) / 1e9}s")
class Strategy: | |
"""A strategy holds mulitple connections""" | |
def __init__(self, clients): | |
self.clients = clients | |
async def run_all(self): | |
gc.collect() | |
async with asyncio.TaskGroup() as tg: | |
for client in self.clients: | |
tg.create_task(client.run()) | |
def _run_servers(num_servers):
    """Create ``num_servers`` exchanges on consecutive ports and serve them forever."""

    async def run_servers():
        async with asyncio.TaskGroup() as tg:
            for i in range(num_servers):
                # Exchange ids are 1-based; ports start at base_port.
                exchange = Exchange(
                    i + 1, hostname, base_port + i, average_wait_time
                )
                tg.create_task(exchange.run())

    asyncio.run(run_servers())


def _run_client_experiments(num_servers):
    """Run each writer scenario in turn against ``num_servers`` exchanges."""
    # (label used in the banner, writer factory, whether Client uses async_write)
    scenarios = [
        ("sync writer", lambda: Consumer(sync_overhead), False),
        ("async writer", lambda: Consumer(async_overhead), True),
        ("NoopConsumer", NoopConsumer, False),
        ("SqliteConsumer synced writes", SqliteConsumer, False),
        ("SqliteConsumer async writes", SqliteConsumer, True),
    ]
    for label, make_writer, write_async in scenarios:
        print(f"Connecting to {num_servers} exchanges with {label}")
        writer = make_writer()
        clients = [
            Client(base_port + i, writer, write_async) for i in range(num_servers)
        ]
        asyncio.run(Strategy(clients).run_all())
        # Only writers that record messages (e.g. Consumer) have anything to dump.
        if DEBUG and hasattr(writer, "messages_received"):
            print("\n".join(writer.messages_received))
        # Release any SQLite connection before the next scenario reopens the file.
        con = getattr(writer, "con", None)
        if con is not None:
            con.close()


def main(argv):
    """Command-line entry point: ``main.py (server|client) COUNT``."""
    if len(argv) < 3:
        print("Provide either 'server' or 'client' as argument and a count")
        return
    arg = argv[1]
    try:
        num_servers = int(argv[2])
    except ValueError:
        print(f"Count must be an integer, got '{argv[2]}'")
        return
    if arg == "server":
        print_settings()
        _run_servers(num_servers)
    elif arg == "client":
        _run_client_experiments(num_servers)
    else:
        print(f"Unknown arg '{arg}', exiting")


if __name__ == "__main__":
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment