Skip to content

Instantly share code, notes, and snippets.

@ptmcg
Last active November 19, 2020 05:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ptmcg/eae1eee566e8bacb948fc2d2d565a7a5 to your computer and use it in GitHub Desktop.
Save ptmcg/eae1eee566e8bacb948fc2d2d565a7a5 to your computer and use it in GitHub Desktop.
RCU class for read-copy-update synchronization of a shared resource
#
# rcu.py
#
# Paul McGuire - November, 2020
#
from contextlib import contextmanager
import copy
import threading
class RcuSynchronizer:
    """
    Class to implement Read-Copy-Update synchronization of a shared resource, in which
    multiple readers can access the resource without requiring any locking, while
    updaters may update the resource by solely synchronizing among each other (because
    they work with a copy of the shared object instead of modifying it in-place).

    Parameters:
        managed_object: object to be accessed using RCU synchronization
        copy_fn: (optional) method to make a copy of the managed object, with method
            signature:

                copy_function(obj: T) -> T:

            if not provided, copy.copy is used. That is a shallow copy, so pass
            copy.deepcopy (or a custom function) if updaters will modify objects
            nested inside the managed object.

    Readers may access the shared value by calling rcu.get() (or rcu.read()).

    Writers access the shared value using a context manager returned by
    calling rcu.updating():

        with rcu.updating() as shared_updater:
            value_copy = shared_updater.update_value

            ... code that modifies value_copy ...

            # if the shared object is mutable, then no additional
            # code is needed

            # if the shared object is immutable, then the writer must
            # write back into the context manager
            shared_updater.update_value = value_copy

    The new value is published only when the with-block exits normally. If the block
    raises, the modified copy is discarded, the previously shared value stays in
    place, and the exception propagates to the caller.
    """

    def __init__(self, managed_object, copy_fn=None):
        # The current published value lives in a one-element list. Publishing an
        # update is a single item assignment, so a reader gets either the old or the
        # new object, never one that an updater is still modifying. (This relies on
        # single reference reads/writes being atomic, as they are under CPython's GIL.)
        self._shared = [managed_object]
        self._copy_fn = copy_fn if copy_fn is not None else copy.copy
        # Serializes updaters among themselves; readers never take this lock.
        self._update_lock = threading.Lock()

    def read(self):
        """Return the currently published value, without taking any lock."""
        return self._shared[0]

    get = read

    @contextmanager
    def updating(self):
        """
        Context manager for updating the shared value.

        Acquires the update lock, stores a copy of the current value in
        self.update_value, and yields self. On normal exit, whatever is then held in
        update_value is published as the new shared value. If the with-block raises,
        nothing is published and the exception propagates. In every case the lock is
        released and the temporary update_value attribute is removed.
        """
        with self._update_lock:
            # Made outside the try: if copy_fn raises, there is nothing to publish
            # or clean up, and its exception propagates unchanged.
            self.update_value = self._get_copy()
            try:
                yield self
                # Only reached when the with-block completed without raising.
                self._shared[0] = self.update_value
            finally:
                del self.update_value

    def _get_copy(self):
        # copy of the currently published value, made with the configured copy_fn
        return self._copy_fn(self._shared[0])

    def _update(self, obj):
        # internal method to do synchronized update to shared object
        with self._update_lock:
            self._shared[0] = obj
if __name__ == '__main__':
    #
    # demo code
    #
    from collections import deque
    from random import randint, randrange, shuffle
    import time

    shared = RcuSynchronizer([randint(1, 10000) for _ in range(10)])

    # Sums of the most recently published versions of the shared list. Readers check
    # their own slowly computed sum against these to detect an inconsistent snapshot.
    shared_history = deque(maxlen=50)
    shared_history.append(sum(shared.get()))
    busted = False

    def do_read():
        global busted

        # get the current shared value of the list
        shared_list = shared.get()
        print("reading", flush=True)

        # compute sum slowly, giving updaters a chance to modify the list, but our copy
        # will remain unchanged - so we should always have a consistent snapshot for the life
        # of our access to the list, without having to make a copy
        slow_sum = 0
        for i in shared_list:
            time.sleep(0.005)
            slow_sum += i

        # call attention to instances where slow_sum is not the same as the sum of any
        # of the historical values of the shared value (if we got an inconsistent snapshot,
        # then it shouldn't match up with any of the historical lists).
        # Snapshot the history once so the check and the report use the same values.
        historical_sums = list(shared_history)
        if slow_sum not in historical_sums:
            print("\n>>>>>>>>>>>>BUST!!!<<<<<<<<<<<<<<<<<<<<" * 3, flush=True)
            print(
                f"\ngot {slow_sum}, expected one of {historical_sums}, (using list {shared_list})",
                flush=True,
            )
            busted = True

    def do_update():
        print("updating", flush=True)

        # modify the shared resource
        with shared.updating() as shared_updater:
            shared_list = shared_updater.update_value

            # make up to 5 changes in the list
            for _ in range(randint(1, 5)):
                if randint(1, 20) < len(shared_list) / 2:
                    shared_list.pop(randrange(len(shared_list)))
                else:
                    shared_list.append(randint(1, 10000))

            # update shared_history for reader validation (before the new list is
            # published on exit from the with-block)
            shared_history.append(sum(shared_list))

        # show the modified resource
        print(shared.get(), flush=True)

    iters = 500
    # set once all writers have finished, telling the readers to stop
    done_updating = threading.Event()

    def reader():
        """
        target method for threaded readers
        """
        while not done_updating.is_set():
            do_read()

    def writer():
        """
        target method for threaded updaters
        """
        for _ in range(iters):
            do_update()
            time.sleep(0.2)

    readers = [threading.Thread(target=reader) for _ in range(25)]
    writers = [threading.Thread(target=writer) for _ in range(5)]

    # kick off threads - after shuffling readers and writers
    all_threads = readers + writers
    shuffle(all_threads)
    for t in all_threads:
        t.start()

    try:
        # wait for all writers to finish updates
        for t in writers:
            t.join()
    finally:
        # tell readers that updating is finished (also if the main thread is
        # interrupted), so they can stop reading
        done_updating.set()

    for t in readers:
        t.join()

    # any failures?
    print(f"{busted=}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment