Skip to content

Instantly share code, notes, and snippets.

@CalebMuhia
Forked from mrcfps/bank_test.py
Created June 13, 2022 17:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CalebMuhia/ff17a151dce3ac3c6a6526399e12d006 to your computer and use it in GitHub Desktop.
Exploration of how to test concurrency in Python
from __future__ import print_function
import platform
import sys
import threading
import time
class UnsyncedBankAccount(object):
    """Bank account with no synchronization lock, prone to race condition.

    Serves as the negative control for the concurrency experiment: the balance
    is updated with a plain read-modify-write (``+=`` / ``-=``), so concurrent
    callers can lose updates.

    Every operation raises ``ValueError`` when the account is closed or the
    amount is invalid.
    """

    def __init__(self):
        # Accounts start closed and empty; open() must be called first.
        self.is_open = False
        self.balance = 0

    def get_balance(self):
        """Return the current balance; raise ValueError if closed."""
        if not self.is_open:
            raise ValueError
        return self.balance

    def open(self):
        """Mark the account as open."""
        self.is_open = True

    def deposit(self, amount):
        """Add a strictly positive ``amount``; raise ValueError otherwise."""
        if not (self.is_open and amount > 0):
            raise ValueError
        self.balance += amount

    def withdraw(self, amount):
        """Remove ``amount`` (0 < amount <= balance); raise ValueError otherwise."""
        if not (self.is_open and 0 < amount <= self.balance):
            raise ValueError
        self.balance -= amount

    def close(self):
        """Mark the account as closed; the balance is left untouched."""
        self.is_open = False
class SyncedBankAccount(object):
    """Bank account with synchronization strategy, thread-safe.

    Every read or write of ``is_open`` and ``balance`` happens while holding
    ``self.lock``. This includes ``open()`` and ``close()``: once ``close()``
    returns, any transaction that was already in progress has finished, and
    later transactions fail.

    Every operation raises ``ValueError`` when the account is closed or the
    amount is invalid.
    """

    def __init__(self):
        # Accounts start closed and empty; open() must be called first.
        self.is_open = False
        self.balance = 0
        self.lock = threading.Lock()

    def get_balance(self):
        """Return the current balance; raise ValueError if closed."""
        with self.lock:
            if self.is_open:
                return self.balance
            raise ValueError

    def open(self):
        """Mark the account as open (serialized with in-flight transactions)."""
        # Without the lock, opening/closing could interleave with a
        # transaction that has already checked is_open.
        with self.lock:
            self.is_open = True

    def deposit(self, amount):
        """Add a strictly positive ``amount``; raise ValueError otherwise."""
        with self.lock:
            if self.is_open and amount > 0:
                self.balance += amount
            else:
                raise ValueError

    def withdraw(self, amount):
        """Remove ``amount`` (0 < amount <= balance); raise ValueError otherwise."""
        with self.lock:
            if self.is_open and 0 < amount <= self.balance:
                self.balance -= amount
            else:
                raise ValueError

    def close(self):
        """Mark the account as closed; waits for any in-flight transaction."""
        with self.lock:
            self.is_open = False
def adjust_balance_concurrently(account, num_threads=1000):
    """Run many concurrent deposit/withdraw pairs against ``account``.

    Each of ``num_threads`` threads deposits 5, sleeps 1 ms, then withdraws 5.
    A correctly synchronized account therefore ends with the balance it
    started with.

    While the threads run, the interpreter's thread-switch interval is made as
    small as possible, so that a thread is more likely to be preempted in the
    middle of a transaction. The previous interval is restored before
    returning, even if creating or starting a thread fails. Exceptions raised
    inside worker threads are not propagated to the caller.

    Args:
        account: object exposing ``deposit(amount)`` and ``withdraw(amount)``.
        num_threads: number of transaction threads to start (default 1000).
    """
    def transact():
        account.deposit(5)
        time.sleep(0.001)
        account.withdraw(5)

    # Greatly improve the chance of an operation being interrupted
    # by thread switch, thus testing synchronization effectively.
    # Feel free to tweak the parameters to see their impact.
    try:
        get_interval = sys.getswitchinterval
        set_interval = sys.setswitchinterval
        fast_interval = 1e-12
    except AttributeError:
        # Python 2 compatible
        get_interval = sys.getcheckinterval
        set_interval = sys.setcheckinterval
        fast_interval = 1
    previous_interval = get_interval()
    set_interval(fast_interval)
    try:
        threads = []
        for _ in range(num_threads):
            t = threading.Thread(target=transact)
            threads.append(t)
            t.start()
        for thread in threads:
            thread.join()
    finally:
        # Undo the process-wide change. A sub-microsecond interval reads back
        # as 0.0, which setswitchinterval() rejects, so fall back to the tiny
        # value in that case.
        set_interval(previous_interval or fast_interval)
if __name__ == '__main__':
    # Initialization: one account of each kind, opened and seeded with 1000.
    seeded = []
    for account_cls in (UnsyncedBankAccount, SyncedBankAccount):
        acct = account_cls()
        acct.open()
        acct.deposit(1000)
        seeded.append(acct)
    unsync_account, sync_account = seeded

    # Exercise the unsynced account first, then the synced one, 10 rounds each.
    for acct in (unsync_account, sync_account):
        for _round in range(10):
            adjust_balance_concurrently(acct)

    # Report results
    print("Python version: {}\n".format(platform.python_version()))
    report = (
        ("Balance of unsynced account after concurrent transactions:", unsync_account),
        ("Balance of synced account after concurrent transactions:", sync_account),
    )
    for heading, acct in report:
        print(heading)
        print("{}. Expected: 1000\n".format(acct.balance))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment