@benweint
Created March 11, 2023 06:02
Connection pool metastability simulation
#!/usr/bin/env python
import simpy
import random
import math
import pandas
import matplotlib
import matplotlib.pyplot as plt

# Represents a shared database instance with a fixed CPU allocation.
# CPU is modelled as a simpy.Resource with capacity equal to the number of cores.
# DB operations are simulated by calls to execute(), and hold a single core for
# the duration of their execution.
class DB:
    def __init__(self, env, ncpu, connect_cost):
        self.env = env
        self.ncpu = ncpu
        self.cpu = simpy.Resource(env, ncpu)
        self.connect_cost = connect_cost
        self.total_cpu_time = 0
        self.cpu_time_start = env.now

    def execute(self, cost, cancel_event=None):
        start = self.env.now
        with self.cpu.request() as req:
            yield req
            yield self.env.timeout(cost)
        self.total_cpu_time += min(self.env.now - start, cost)

    def reset_stats(self):
        self.total_cpu_time = 0
        self.cpu_time_start = self.env.now
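
# Usage sketch (illustration only, not part of the simulation wiring below): with
# ncpu=8, up to 8 concurrent calls to execute() run at once and the rest queue on
# the simpy.Resource until a core frees up, e.g.:
#
#   env.process(db.execute(0.001))  # a 1 ms query occupies one core for 1 ms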

# Represents a single connection slot within a connection pool.
# Connections start off in the ready state, and may be either ready or dead.
class Connection:
    def __init__(self, state='ready'):
        self.state = state

    def is_ready(self):
        return self.state == 'ready'

    def mark_dead(self):
        self.state = 'dead'

    def mark_ready(self):
        self.state = 'ready'

# Represents a single application instance.
# Each instance maintains a pool of connections of size pool_size.
# - net_latency models the network latency between the application and the DB
# - request_timeout is the time after which a request will be aborted by the application
# - trip_threshold is the number of consecutive query failures required to trip the circuit breaker
# - max_trip_duration is the maximum time that the breaker will remain tripped for before allowing
#   through another query.
class App:
    def __init__(self, env, db, pool_size, net_latency, request_timeout, enable_circuit_breaker, trip_threshold, max_trip_duration):
        self.env = env
        self.db = db
        self.conns = simpy.Store(env, capacity=pool_size)
        self.net_latency = net_latency
        self.timeout = request_timeout
        self.reset_stats()

        # Circuit breaker settings and state
        self.enable_circuit_breaker = enable_circuit_breaker
        self.trip_threshold = trip_threshold
        self.max_trip_duration = max_trip_duration
        self.consecutive_failures = 0  # Number of consecutive observed query failures in this app instance
        self.tripped_until = 0  # Time after which this circuit breaker instance will no longer be considered tripped

        for i in range(pool_size):
            self.conns.put(Connection())
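
    # Note on the pool: self.conns is a simpy.Store holding Connection objects.
    # When all pool_size slots are checked out, get_connection() below blocks on
    # conns.get() until a slot is returned or the request timeout fires, so pool
    # exhaustion surfaces as request timeouts rather than errors.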
    # Handle a single simulated request by:
    #
    # 0. Aborting immediately if the circuit breaker is open
    # 1. Obtaining a connection slot from the pool
    # 2. Refreshing the connection if needed
    # 3. Issuing a query (will hold a DB CPU for query_cost seconds)
    # 4. Returning the connection to the pool
    #
    # A timeout during steps 2 or 3 will cause the connection to be marked as dead
    # and returned to the pool, where the next user will need to refresh it before
    # use.
    #
    def handle_request(self, query_cost):
        begin = self.env.now

        # Step 0 - Abort immediately if the circuit breaker is open
        if not self.check_breaker():
            self.total_requests += 1
            self.aborts += 1
            return

        # Step 1 - Obtain a connection slot from the pool
        conn = yield self.env.process(self.get_connection(begin))
        if not conn:
            return

        # Step 2 - Refresh the connection first if it was previously marked dead
        if not conn.is_ready():
            ok = yield self.env.process(self.execute(begin, conn, self.db.connect_cost))
            if not ok:
                return
            conn.mark_ready()

        # Step 3 - Issue the query itself
        ok = yield self.env.process(self.execute(begin, conn, query_cost))
        if not ok:
            return

        # Step 4 - Return the healthy connection to the pool and record the success
        yield self.conns.put(conn)
        self.total_requests += 1
        self.successes += 1
        if self.consecutive_failures > 0:
            self.consecutive_failures = 0
        self.note_latency(begin)
    def note_latency(self, begin: float) -> None:
        latency = self.env.now - begin
        self.total_latency += latency
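
    # Breaker backoff (worked example, assuming the defaults used below of
    # trip_threshold=5 and max_trip_duration=1): once failures exceed the
    # threshold, the trip duration grows as 2**failures / 1000 seconds, so
    # 6 failures -> 0.064 s, 8 -> 0.256 s, and 10 or more -> capped at 1 s.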
    def check_breaker(self) -> bool:
        if not self.enable_circuit_breaker:
            return True
        if self.consecutive_failures > self.trip_threshold and self.tripped_until == 0:
            if self.consecutive_failures > 1000:
                trip_duration = self.max_trip_duration
            else:
                trip_duration = min((2 ** self.consecutive_failures) / 1000.0, self.max_trip_duration)
            self.tripped_until = self.env.now + trip_duration
        if self.env.now > self.tripped_until:
            self.tripped_until = 0
            return True
        return False
    def get_connection(self, begin: float) -> Connection:
        get_conn = self.conns.get()
        timeout = self.env.timeout(self.timeout)
        result = yield get_conn | timeout
        if timeout in result:
            get_conn.cancel()
            return self.note_timeout(begin)
        return result[get_conn]
    def execute(self, begin: float, conn: Connection, cost: float) -> bool:
        # Simulate the network hop to the DB, bounded by the remaining request timeout
        remaining = self.timeout - (self.env.now - begin)
        timeout = self.env.timeout(remaining)
        net_latency = self.env.timeout(self.net_latency)
        result = yield net_latency | timeout
        if timeout in result:
            return self.note_timeout(begin, conn)

        # Run the DB operation itself, again bounded by the remaining request timeout
        timeout = self.env.timeout(self.timeout - (self.env.now - begin))
        query = self.env.process(self.db.execute(cost, timeout))
        result = yield query | timeout
        if timeout in result:
            return self.note_timeout(begin, conn)
        return True
    def note_timeout(self, begin, conn=None):
        if conn:
            conn.mark_dead()
            self.conns.put(conn)
        self.total_requests += 1
        self.timeouts += 1
        self.consecutive_failures += 1
        self.note_latency(begin)
        return False
    def reset_stats(self):
        self.aborts = 0
        self.successes = 0
        self.timeouts = 0
        self.total_requests = 0
        self.total_latency = 0
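
# Periodically samples aggregate stats across all app instances and appends one
# row per interval. DB CPU utilization is busy core-time divided by elapsed time
# times core count; e.g. 2.4 core-seconds of work in a 0.3 s interval on 8 cores
# gives 2.4 / (0.3 * 8) = 1.0, i.e. a fully saturated DB.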
def track_stats(env, db, apps, target_tput, rows, interval, breaker_enabled):
    while True:
        yield env.timeout(interval)
        total_successes = sum(a.successes for a in apps)
        total_requests = sum(a.total_requests for a in apps)
        total_latency = sum(a.total_latency for a in apps)
        total_timeouts = sum(a.timeouts for a in apps)
        total_aborts = sum(a.aborts for a in apps)
        db_cpu_utilization = db.total_cpu_time / ((env.now - db.cpu_time_start) * db.ncpu)
        for a in apps:
            a.reset_stats()
        db.reset_stats()
        if total_requests == 0:
            avg_latency = 0
        else:
            avg_latency = total_latency / total_requests
        row = [
            env.now,
            target_tput,
            total_requests / interval,
            total_successes / interval,
            total_aborts,
            avg_latency,
            db_cpu_utilization,
            breaker_enabled
        ]
        rows.append(row)
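
# Open-loop load generator: inter-arrival gaps are drawn from an exponential
# distribution (random.expovariate(rps)), i.e. a Poisson arrival process with
# mean rate rps; at rps=800 the mean gap is 1/800 = 1.25 ms. Requests are fired
# without waiting for earlier ones to finish, which is what lets queues build up.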
def generate_load(env, apps, rps, query_cost):
    while True:
        a = random.choice(apps)
        env.process(a.handle_request(query_cost))
        delay = random.expovariate(rps)
        yield env.timeout(delay)

# Cause all queries to time out by holding all ncpu database CPUs for duration
# seconds, starting after the given delay.
def nudge(env, db, delay, ncpu, duration, interval):
    yield env.timeout(delay)
    remaining = duration
    while remaining > 0:
        d = min(remaining, interval)
        for i in range(ncpu):
            env.process(db.execute(d))
        yield env.timeout(d)
        remaining -= d
    print(f'Done at {env.now}')
num_instances = 3
pool_size = 10
query_cost = 0.001
connect_cost = 0.005
net_latency = 0.001
db_cpus = 8
request_timeout = 2.0
trip_threshold = 5
max_trip_duration = 1
sim_duration = 70
interval = 0.3
nudge_delay = 10
nudge_duration = 3
all_rows = []
rps_values = [800, 1200, 1600]
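
# Rough capacity sketch (back-of-envelope, not computed by the simulation): with
# 8 DB cores and query_cost=1 ms, the DB can serve ~8000 queries/s when all
# connections are ready. If every request also has to pay connect_cost=5 ms to
# refresh a dead connection, per-request CPU rises to ~6 ms and capacity drops
# to ~1333 requests/s, below the 1600 RPS load level; that is what lets the
# degraded state sustain itself after the nudge.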
for breaker_enabled in [False, True]:
    for rps in rps_values:
        env = simpy.Environment()
        db = DB(env, db_cpus, connect_cost)
        apps = [App(env, db, pool_size, net_latency, request_timeout, breaker_enabled, trip_threshold, max_trip_duration) for i in range(num_instances)]
        env.process(generate_load(env, apps, rps, query_cost))
        env.process(track_stats(env, db, apps, rps, all_rows, interval, breaker_enabled))
        env.process(nudge(env, db, nudge_delay, db_cpus, nudge_duration, interval))
        env.run(until=sim_duration)
df = pandas.DataFrame(data=all_rows, columns=['ts', 'target_tput', 'tput', 'goodput', 'aborts', 'avg_latency', 'db_cpu', 'breaker_enabled'])
df2 = df[df['breaker_enabled']].pivot(index='ts', columns='target_tput', values='db_cpu')
df3 = df[df['breaker_enabled']==False].pivot(index='ts', columns='target_tput', values='db_cpu')
cols = df2.columns
fig = plt.figure(figsize=(12,8))
# fig.tight_layout()
axes = fig.subplots(nrows=len(cols), ncols=2, sharex=True, sharey=True)
for i in range(len(cols)):
    axes[i][0].axvspan(nudge_delay, nudge_delay + nudge_duration, alpha=0.1, color='red')
    axes[i][0].set_xlim(0, sim_duration)
    axes[i][0].set_ylim(0, 110)
    axes[i][0].set_title(f'Baseline @ {cols[i]} RPS')
    axes[i][0].set_ylabel('DB CPU %')
    axes[i][0].plot(df3[cols[i]] * 100)
    axes[i][1].axvspan(nudge_delay, nudge_delay + nudge_duration, alpha=0.1, color='red')
    axes[i][1].set_xlim(0, sim_duration)
    axes[i][1].set_ylim(0, 110)
    axes[i][1].set_title(f'With circuit breaker @ {cols[i]} RPS')
    axes[i][1].plot(df2[cols[i]] * 100)
axes[2][0].set_xlabel('Elapsed time (s)')
axes[2][1].set_xlabel('Elapsed time (s)')
plt.savefig('out.png', dpi=300)
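
# To reproduce (assuming Python 3 with simpy, pandas, and matplotlib installed),
# run this file directly; it simulates each (breaker, RPS) combination for 70
# simulated seconds and writes the comparison chart to out.png.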