Last active
December 22, 2023 21:36
-
-
Save rsiemens/499f04e61e35075dc4c641b267380b94 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import argparse | |
import time | |
from copy import copy | |
from threading import Lock, Thread, local | |
from typing import Iterator, Literal | |
# Shared in-memory "database": object name -> current integer value.
# Every transaction thread reads and writes this one module-level dict.
DB: dict[str, int] = dict(A=10, B=12, C=6)
class Deadlock(Exception):
    """Signal that the requesting transaction was chosen as a deadlock victim.

    The message carries the comma-separated members of the detected cycle.
    """
class LockManager:
    """Grant shared ("S") and exclusive ("X") locks and detect deadlocks.

    ``locks`` maps a lock name to ``(mode, holders)``.  ``wait_for_graph``
    maps a blocked transaction to the set of transactions it is waiting on.
    ``acquire`` and ``release`` run under an internal latch so the two
    structures are always updated together.
    """

    def __init__(self):
        self._latch = Lock()
        self.locks: dict[str, tuple[Literal["S", "X"], set[Tx]]] = {}
        self.wait_for_graph: dict[Tx, set[Tx]] = {}

    def acquire(self, tx: Tx, lock_name: str, mode: Literal["S", "X"]) -> bool:
        """Try to take ``lock_name`` in ``mode`` on behalf of ``tx``.

        Returns:
            True if the lock is now held by ``tx``; False if ``tx`` has to
            wait (the caller is expected to retry).

        Raises:
            Deadlock: if waiting closes a cycle in the wait-for graph and
                ``tx`` is the youngest member (largest ``monotonic_time``).
        """
        with self._latch:
            lock = self.locks.get(lock_name)
            granted = False

            if lock is None:
                # Nobody holds the object yet: create the lock.
                lock = (mode, {tx})
                self.locks[lock_name] = lock
                granted = True
            elif lock[0] == "S" and mode == "S":
                # Shared locks are compatible with each other.
                lock = ("S", {tx} | lock[1])
                self.locks[lock_name] = lock
                granted = True
            elif lock[0] == "S" and mode == "X" and lock[1] == {tx}:
                # Sole reader upgrades from S to X.
                lock = ("X", {tx})
                self.locks[lock_name] = lock
                granted = True
            elif lock[0] == "X" and lock[1] == {tx}:
                # Owner of the X lock asks again (S or X): grant, never downgrade.
                granted = True

            if granted:
                self.wait_for_graph.pop(tx, None)
                return True

            # Blocked: record who we wait on.  A transaction never waits on
            # itself; without the subtraction an S->X upgrade with other
            # readers present would create a self-loop and be reported as a
            # deadlock even though the other readers may simply finish.
            self.wait_for_graph[tx] = lock[1] - {tx}

            cycle = self.cycle()
            if cycle:
                # Kill the youngest (highest time) transaction in the cycle.
                victim = max(cycle, key=lambda member: member.monotonic_time)
                if tx == victim:
                    # The victim stops waiting, so drop its edges right away.
                    self.wait_for_graph.pop(tx, None)
                    raise Deadlock(", ".join(str(i) for i in cycle))

            # Keep waiting for a release.
            return False

    def release(self, tx: Tx, lock_name: str):
        """Give up ``tx``'s hold on ``lock_name``.

        Raises:
            Exception: if the lock does not exist or ``tx`` is not a holder.
        """
        with self._latch:
            lock = self.locks.get(lock_name)
            if lock is None or tx not in lock[1]:
                raise Exception(
                    f"You can't release that which you don't own ({tx} tried to release {lock_name})!"
                )

            remaining = lock[1] - {tx}
            if lock[0] == "S" and remaining:
                # Other readers still hold the shared lock.
                self.locks[lock_name] = (lock[0], remaining)
            else:
                # Last reader or the writer: free the lock entirely.
                del self.locks[lock_name]

            self.wait_for_graph.pop(tx, None)

    def cycle(self) -> list[Tx]:
        """Return the members of one cycle in the wait-for graph, or ``[]``."""
        self.visited: set[Tx] = set()
        self.stack: list[Tx] = []
        self._cycle: list[Tx] = []
        for tx in self.wait_for_graph:
            if tx not in self.visited:
                self.dfs(tx)
                if self._cycle:
                    break
        return self._cycle

    def dfs(self, tx: Tx) -> None:
        """Depth-first walk from ``tx``; stores a found cycle in ``_cycle``.

        Relies on ``visited``, ``stack`` and ``_cycle`` having been reset by
        ``cycle()``.
        """
        self.stack.append(tx)
        self.visited.add(tx)
        for dep_tx in self.wait_for_graph.get(tx, set()):
            if dep_tx not in self.visited:
                self.dfs(dep_tx)
            elif dep_tx in self.stack:
                # Back edge: the cycle is exactly the part of the current
                # path that starts at the transaction we just reached again.
                self._cycle = self.stack[self.stack.index(dep_tx):]
            if self._cycle:
                return
        self.stack.pop()
class TxManager: | |
def __init__(self, schedule: list[Tx], serializable: bool): | |
self.schedule = schedule | |
self.serializable = serializable | |
self.lock_manager = LockManager() | |
self.tlocal = local() | |
def print_op(self, tx, op): | |
print(f"{tx}: {' ' * (tx.monotonic_time - 1)}{op}") | |
def execute(self, tx: Tx) -> None: | |
self.tlocal.granted = set() | |
snapshot = copy(DB) | |
rollback: set[str] = set() | |
def do_rollback(tx): | |
for v in rollback: | |
DB[v] = snapshot[v] | |
self.release(tx) | |
try: | |
for op in tx: | |
match op: | |
case ["BEGIN"]: | |
self.print_op(tx, "BEGIN") | |
case ["R", v]: | |
self.s_lock(tx, v) | |
self.print_op(tx, f"R({v}) -> {DB[v]}") | |
case ["W", v, x]: | |
self.x_lock(tx, v) | |
DB[v] = x | |
rollback.add(v) | |
self.print_op(tx, f"W({v}) <- {x}") | |
case ["COMMIT"]: | |
self.release(tx) | |
self.print_op(tx, "COMMIT") | |
case ["ABORT"]: | |
do_rollback(tx) | |
self.print_op(tx, "ABORT") | |
case ["SLEEP", v]: | |
time.sleep(v) | |
except Deadlock as e: | |
# abort | |
do_rollback(tx) | |
self.print_op(tx, f"ABORT (Deadlock victim!)") | |
def s_lock(self, tx: Tx, v: str) -> None: | |
if not self.serializable: | |
return | |
while not self.lock_manager.acquire(tx, v, mode="S"): | |
time.sleep(0.1) | |
self.tlocal.granted.add(v) | |
def x_lock(self, tx: Tx, v: str) -> None: | |
if not self.serializable: | |
return | |
while not self.lock_manager.acquire(tx, v, mode="X"): | |
time.sleep(0.1) | |
self.tlocal.granted.add(v) | |
def release(self, tx: Tx) -> None: | |
if not self.serializable: | |
return | |
for lock in self.tlocal.granted: | |
self.lock_manager.release(tx, lock) | |
def run_schedule(self) -> None: | |
print(f"Start: A={DB['A']} B={DB['B']} C={DB['C']}") | |
threads = [] | |
for tx in self.schedule: | |
t = Thread(target=self.execute, args=[tx]) | |
threads.append(t) | |
t.start() | |
for t in threads: | |
t.join() | |
print(f"End: A={DB['A']} B={DB['B']} C={DB['C']}") | |
class Tx:
    """A scripted transaction: a fixed list of operation tuples plus a timestamp.

    ``monotonic_time`` doubles as the transaction's number when printed.
    """

    def __init__(self, monotonic_time: int, operations: list[tuple]):
        self.operations = operations
        self.monotonic_time = monotonic_time

    def __repr__(self) -> str:
        # Same text as the informal string form.
        return str(self)

    def __str__(self) -> str:
        return "Tx" + str(self.monotonic_time)

    def __iter__(self) -> Iterator[tuple]:
        # A fresh iterator each time, so a Tx can be walked more than once.
        return iter(self.operations)
class TxStream(Tx):
    """Interactive transaction whose operations are typed in, one per prompt.

    The object is its own iterator: each ``next()`` prompts with ``"> "`` and
    yields the whitespace-separated tokens of the line as a tuple.  The first
    line must be ``BEGIN``; iteration stops after ``COMMIT`` or ``ABORT``.
    """

    def __init__(self, monotonic_time: int):
        # No scripted operations; initialise the base so `operations` exists.
        super().__init__(monotonic_time, [])
        self._began = False
        self._ended = False

    def __iter__(self) -> Iterator[tuple]:
        return self

    def __next__(self):
        if self._ended:
            raise StopIteration("Tx closed")

        # split() without a separator collapses runs of whitespace, so
        # "R  A" still yields ("R", "A") instead of ("R", "", "A").
        val = input("> ").split()
        command = val[0] if val else ""

        if not self._began:
            if command != "BEGIN":
                # A for-loop swallows StopIteration's message, so show it.
                print('Tx must start with "BEGIN"')
                raise StopIteration('Tx must start with "BEGIN"')
            self._began = True

        if command in {"COMMIT", "ABORT"}:
            self._ended = True
        return tuple(val)
def parse_args():
    """Parse the command line and return the resulting namespace.

    Options: ``--2pl`` (stored as ``two_phase_locking``) and an optional
    ``--conflict`` naming which canned schedule to run.
    """
    conflict_kinds = ["R-W", "W-R", "W-W", "Deadlock"]

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--2pl",
        dest="two_phase_locking",
        action="store_true",
        help="make things serializable with 2 phase locking",
    )
    cli.add_argument(
        "--conflict",
        choices=conflict_kinds,
        type=str,
        action="store",
        required=False,
        help="The conflict type that can occur",
    )
    return cli.parse_args()
def main():
    """Run live mode, or one of the canned conflict schedules, per the CLI flags."""
    args = parse_args()

    # Live mode: a single interactive transaction read from stdin.
    if not args.conflict:
        print(f"Live mode (2PL={args.two_phase_locking})")
        TxManager([TxStream(1)], serializable=args.two_phase_locking).run_schedule()
        return

    if args.conflict == "Deadlock" and not args.two_phase_locking:
        print("Deadlock detection is only valid with --2pl")
        return

    # Building blocks shared by the scripted schedules below.
    begin = ("BEGIN",)
    commit = ("COMMIT",)
    pause = ("SLEEP", 0.2)

    # conflict name -> (title, list of transactions)
    schedules = {
        "R-W": (
            "R-W conflict (unrepetable-read)",
            [
                Tx(1, [begin, ("R", "A"), pause, ("R", "A"), commit]),
                Tx(2, [begin, ("W", "A", 5), commit]),
            ],
        ),
        "W-R": (
            "W-R conflict (dirty-read)",
            [
                Tx(1, [begin, ("W", "B", 2), pause, commit]),
                Tx(2, [begin, ("SLEEP", 0.1), ("R", "B"), commit]),
            ],
        ),
        "W-W": (
            "W-W conflict (lost-update)",
            [
                Tx(1, [begin, ("R", "A"), ("W", "A", 4), pause, commit]),
                Tx(2, [begin, ("R", "A"), ("W", "A", 5), commit]),
            ],
        ),
        "Deadlock": (
            "Deadlock",
            [
                # Each tx writes one object, then reads the next one around
                # the ring A -> B -> C -> A.
                Tx(1, [begin, ("W", "A", 4), pause, ("R", "B"), commit]),
                Tx(2, [begin, ("W", "B", 5), pause, ("R", "C"), commit]),
                Tx(3, [begin, ("W", "C", 5), pause, ("R", "A"), commit]),
            ],
        ),
    }

    title, schedule = schedules[args.conflict]

    if not args.two_phase_locking:
        print(f"{title} occurs without 2PL\n")
        TxManager(schedule, serializable=False).run_schedule()
        return

    if args.conflict == "Deadlock":
        print(f"{title} detection with 2PL\n")
    else:
        print(f"{title} avoided with 2PL\n")
    TxManager(schedule, serializable=True).run_schedule()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment