@GeoffChurch
Created May 7, 2024 01:53
Speculative execution on both branches of a conditional
"""
Provides speculative_if(cond, branch1, branch2), which runs branch1 followed by
branch2 while cond is running.
"""
from dataclasses import dataclass
import itertools
import multiprocessing as mp
import time
import matplotlib.pyplot as plt
def enqueue_ret(key, f, q):
    """
    Runs the function f and returns its result through q,
    keyed by key. Used as a process target.
    """
    q.put((key, f()))

def get_result(q: mp.Queue, key):
    """
    Blocks until a result tuple whose first element equals key is found in the queue.
    The elements encountered before the result are put back into the queue at the end,
    and will be encountered in reverse order. This is a side effect of optimizing to
    avoid any unnecessary copying and/or scanning. A deque could match the asymptotic
    complexity while only inducing a rotation of the elements, at the cost of overhead.
    """
    backlog = []
    while True:
        key_, val = q.get()
        if key_ == key:
            # Requeue the skipped items before returning the requested result.
            while backlog:
                q.put(backlog.pop())
            return val
        backlog.append((key_, val))

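# Illustrative behavior (single consumer, as in speculative_if below):
# if q holds (0, 'a'), (1, 'b'), (2, 'c') in that order, then get_result(q, 2)
# returns 'c' and requeues the skipped items, so q then yields (1, 'b'), (0, 'a').
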
def speculative_if(cond, branch1, branch2):
    """
    Runs branch1, followed by branch2, while cond is running,
    and returns the selected result as soon as it and cond are ready.
    """
    idcond, id1, id2 = range(3)
    # all results will be put into this queue
    q = mp.Queue()
    # start cond_p
    cond_p = mp.Process(target=enqueue_ret, args=(idcond, cond, q))
    cond_p.start()
    # start branch1_p
    branch1_p = mp.Process(target=enqueue_ret, args=(id1, branch1, q))
    branch1_p.start()
    # wait for either of cond_p or branch1_p to finish
    key, ret = q.get()
    if key == idcond:
        assert isinstance(ret, bool)
        if ret:
            return get_result(q, id1)
        else:
            # There is another possible implementation, where we send the
            # termination signal to branch1, asynchronously start branch2,
            # and then join on both. That would be better in case branch1 is
            # slow to terminate (see the sketch after this function).
            # forcibly stop branch1_p
            branch1_p.terminate()
            branch1_p.join()
            # run branch2 in the main process
            return branch2()
    else:
        assert key == id1
        # start branch2
        branch2_p = mp.Process(target=enqueue_ret, args=(id2, branch2, q))
        branch2_p.start()
        # wait for cond_p to finish
        cond_ret = get_result(q, idcond)
        assert isinstance(cond_ret, bool)
        if cond_ret:
            # forcibly stop branch2_p
            branch2_p.terminate()
            branch2_p.join()
            return ret
        else:
            return get_result(q, id2)

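# Sketch of the alternative mentioned in the else-branch above (illustrative only):
# ask branch1 to terminate without waiting for it, start branch2 immediately,
# and only then join branch1, so a slow-to-exit branch1 does not delay branch2:
#
#     branch1_p.terminate()
#     branch2_p = mp.Process(target=enqueue_ret, args=(id2, branch2, q))
#     branch2_p.start()
#     branch1_p.join()
#     return get_result(q, id2)
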
class SlowF:
    """
    A function object that sleeps for sleep_time seconds and then returns ret.
    A class is needed because closures are not pickleable.
    """
    def __init__(self, sleep_time, ret):
        self.sleep_time = sleep_time
        self.ret = ret

    def __call__(self):
        time.sleep(self.sleep_time)
        return self.ret

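# Example (illustrative): SlowF(0.5, "done")() sleeps for half a second, then returns "done".
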
def get_expected_time(cond_t, t1, t2, which):
    """
    Returns the ideal expected time for the speculative_if function,
    assuming zero overhead from spawning, switching, and signalling
    processes.
    """
    if which:
        return max(cond_t, t1)  # cond and branch1 run in parallel
    else:
        if cond_t <= t1:
            return cond_t + t2  # branch2 starts right after cond
        else:
            t2_end = t1 + t2  # branch2 starts right after branch1
            return max(cond_t, t2_end)  # but might have to wait for cond

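# Worked example (illustrative): cond_t=3, t1=1, t2=1, which=False.
# branch2 starts when branch1 ends at t=1 and finishes at t=2, but the result
# cannot be returned before cond finishes, so the ideal time is max(3, 1 + 1) = 3.
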
@dataclass
class TestResult:
    cond_t: float
    t1: float
    t2: float
    which: bool
    actual_time: float

    @property
    def expected_time(self):
        return get_expected_time(self.cond_t, self.t1, self.t2, self.which)

    @property
    def diff(self):
        return self.actual_time - self.expected_time

    def __str__(self):
        return (
            f"cond_t={self.cond_t}, t1={self.t1}, t2={self.t2}, "
            f"which={'t1' if self.which else 't2'}, "
            f"expected_time={self.expected_time:1.2f}, "
            f"actual_time={self.actual_time:1.2f}, diff={self.diff:1.2f}"
        )

def get_test_results():
    method_start_time = time.time()
    rets = []
    for cond_t, t1, t2 in itertools.product(range(4), repeat=3):
        # cond_t, t1, t2 = 3 * cond_t, 3 * t1, 3 * t2
        for which in [True, False]:
            cond = SlowF(cond_t, which)
            branch1 = SlowF(t1, 1)
            branch2 = SlowF(t2, 2)
            start_time = time.time()
            ret = speculative_if(cond, branch1, branch2)
            actual_time = time.time() - start_time
            assert ret == (1 if which else 2)
            rets.append(TestResult(cond_t, t1, t2, which, actual_time))
            print(rets[-1])
    print(f"get_test_results time: {time.time() - method_start_time}")
    return rets

def main():
    rets = get_test_results()
    print("Sorted by diff:")
    decreasing_diffs = sorted(rets, key=lambda r: r.diff, reverse=True)
    for r in decreasing_diffs:
        print(r)
    expected_times = sorted(set(r.expected_time for r in rets))
    plt.plot(expected_times, expected_times, color="red", label="theory")
    plt.scatter([r.expected_time for r in rets], [r.actual_time for r in rets], marker="x", label="practice")
    # label the worst offender for each expected time
    outliers = []
    for expected_time in expected_times:
        same_expected_time = [r for r in rets if r.expected_time == expected_time]
        outlier = max(same_expected_time, key=lambda r: r.diff)
        outliers.append(outlier)
    for r in outliers:
        plt.text(r.expected_time, r.actual_time, f"cond_t={r.cond_t} t1={r.t1} t2={r.t2} which={'t1' if r.which else 't2'} diff={r.diff:1.2f}")
    plt.xlabel("Expected time")
    plt.ylabel("Actual time")
    plt.legend()
    plt.show()

if __name__ == '__main__':
    main()