Skip to content

Instantly share code, notes, and snippets.

@rsiemens
Last active December 22, 2023 21:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rsiemens/499f04e61e35075dc4c641b267380b94 to your computer and use it in GitHub Desktop.
Save rsiemens/499f04e61e35075dc4c641b267380b94 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import argparse
import time
from copy import copy
from threading import Lock, Thread, local
from typing import Iterator, Literal
# The shared in-memory "database": variable name -> current value.
# Every transaction thread reads and writes this one dict directly.
DB = dict(A=10, B=12, C=6)
class Deadlock(Exception):
    """Raised in the transaction chosen as the victim of a wait-for cycle."""
class LockManager:
    """Grant shared (S) / exclusive (X) locks to transactions and detect deadlocks.

    All state is guarded by one internal latch, so ``acquire`` and ``release``
    may be called from several threads. ``acquire`` does not wait for a lock
    to become free: it returns False and the caller is expected to retry.
    While a transaction waits, it is recorded in a wait-for graph that is
    checked for cycles on every failed request.
    """

    def __init__(self):
        self._latch = Lock()
        # lock name -> (mode, set of transactions currently holding it)
        self.locks: dict[str, tuple[Literal["S", "X"], set[Tx]]] = {}
        # waiting tx -> transactions it is blocked on (wait-for graph edges)
        self.wait_for_graph: dict[Tx, set[Tx]] = {}

    def acquire(self, tx: Tx, lock_name: str, mode: Literal["S", "X"]) -> bool:
        """Try to grant ``tx`` a ``mode`` lock on ``lock_name``.

        Returns:
            True if the lock was granted, False if ``tx`` must wait and retry.

        Raises:
            Deadlock: if waiting closes a cycle in the wait-for graph and
                ``tx`` is the youngest (highest ``monotonic_time``) member.
        """
        with self._latch:
            lock = self.locks.get(lock_name)
            if lock is None:
                # no lock for the object yet: create it
                granted = (mode, {tx})
            elif lock[0] == "S" and mode == "S":
                # S is compatible with S: join the holder set
                granted = ("S", lock[1] | {tx})
            elif lock[0] == "S" and mode == "X" and lock[1] == {tx}:
                # the sole S holder upgrades to X
                granted = ("X", {tx})
            elif lock[0] == "X" and lock[1] == {tx}:
                # tx already owns the X lock: grant either mode, never downgrade
                granted = lock
            else:
                granted = None

            if granted is not None:
                self.locks[lock_name] = granted
                self.wait_for_graph.pop(tx, None)
                return True

            # Wait on every *other* holder. Excluding tx matters for an S->X
            # upgrade while others share the lock: a self-edge would otherwise
            # be reported as a one-transaction deadlock.
            self.wait_for_graph[tx] = lock[1] - {tx}

            cycle = self.cycle()
            if cycle:
                # kill the youngest (highest monotonic_time) tx in the cycle;
                # other members keep waiting until the victim releases its locks
                victim = max(cycle, key=lambda t: t.monotonic_time)
                if tx == victim:
                    # the victim stops waiting, so drop its edges from the graph
                    del self.wait_for_graph[tx]
                    raise Deadlock(", ".join(str(t) for t in cycle))
            # not granted: wait for a release and retry
            return False

    def release(self, tx: Tx, lock_name: str):
        """Drop ``tx``'s hold on ``lock_name``.

        Raises:
            Exception: if ``tx`` does not currently hold ``lock_name``.
        """
        with self._latch:
            lock = self.locks.get(lock_name)
            if lock is None or tx not in lock[1]:
                raise Exception(
                    f"You can't release that which you don't own ({tx} tried to release {lock_name})!"
                )
            remaining = lock[1] - {tx}
            if lock[0] == "S" and remaining:
                # other readers still share the S lock
                self.locks[lock_name] = ("S", remaining)
            else:
                # last S holder, or the X holder: free the lock entirely
                del self.locks[lock_name]
            self.wait_for_graph.pop(tx, None)

    def cycle(self) -> list[Tx]:
        """Return the transactions on one cycle of the wait-for graph, or [].

        Must be called with the latch held (``acquire`` does this).
        """
        self.visited: set[Tx] = set()
        self.stack: list[Tx] = []
        self._cycle: list[Tx] = []
        for tx in self.wait_for_graph:
            if self._cycle:
                break
            if tx not in self.visited:
                self.dfs(tx)
        return self._cycle

    def dfs(self, tx: Tx) -> None:
        """Depth-first search from ``tx``; store the first cycle found in ``self._cycle``."""
        self.stack.append(tx)
        self.visited.add(tx)
        for dep_tx in self.wait_for_graph.get(tx, set()):
            if dep_tx not in self.visited:
                self.dfs(dep_tx)
                if self._cycle:
                    return
            elif dep_tx in self.stack:
                # back edge tx -> dep_tx: the cycle is every node on the
                # current DFS path from dep_tx down to tx (inclusive)
                self._cycle = self.stack[self.stack.index(dep_tx) :]
                return
        self.stack.pop()
class TxManager:
def __init__(self, schedule: list[Tx], serializable: bool):
self.schedule = schedule
self.serializable = serializable
self.lock_manager = LockManager()
self.tlocal = local()
def print_op(self, tx, op):
print(f"{tx}: {' ' * (tx.monotonic_time - 1)}{op}")
def execute(self, tx: Tx) -> None:
self.tlocal.granted = set()
snapshot = copy(DB)
rollback: set[str] = set()
def do_rollback(tx):
for v in rollback:
DB[v] = snapshot[v]
self.release(tx)
try:
for op in tx:
match op:
case ["BEGIN"]:
self.print_op(tx, "BEGIN")
case ["R", v]:
self.s_lock(tx, v)
self.print_op(tx, f"R({v}) -> {DB[v]}")
case ["W", v, x]:
self.x_lock(tx, v)
DB[v] = x
rollback.add(v)
self.print_op(tx, f"W({v}) <- {x}")
case ["COMMIT"]:
self.release(tx)
self.print_op(tx, "COMMIT")
case ["ABORT"]:
do_rollback(tx)
self.print_op(tx, "ABORT")
case ["SLEEP", v]:
time.sleep(v)
except Deadlock as e:
# abort
do_rollback(tx)
self.print_op(tx, f"ABORT (Deadlock victim!)")
def s_lock(self, tx: Tx, v: str) -> None:
if not self.serializable:
return
while not self.lock_manager.acquire(tx, v, mode="S"):
time.sleep(0.1)
self.tlocal.granted.add(v)
def x_lock(self, tx: Tx, v: str) -> None:
if not self.serializable:
return
while not self.lock_manager.acquire(tx, v, mode="X"):
time.sleep(0.1)
self.tlocal.granted.add(v)
def release(self, tx: Tx) -> None:
if not self.serializable:
return
for lock in self.tlocal.granted:
self.lock_manager.release(tx, lock)
def run_schedule(self) -> None:
print(f"Start: A={DB['A']} B={DB['B']} C={DB['C']}")
threads = []
for tx in self.schedule:
t = Thread(target=self.execute, args=[tx])
threads.append(t)
t.start()
for t in threads:
t.join()
print(f"End: A={DB['A']} B={DB['B']} C={DB['C']}")
class Tx:
    """A transaction: a number plus a fixed, re-iterable list of operations.

    ``monotonic_time`` is shown as ``Tx<n>`` and is also the ordering key the
    lock manager uses when picking a deadlock victim (higher = younger).
    """

    def __init__(self, monotonic_time: int, operations: list[tuple]):
        self.operations = operations
        self.monotonic_time = monotonic_time

    def __iter__(self) -> Iterator[tuple]:
        ops = self.operations
        return iter(ops)

    def __str__(self) -> str:
        return "Tx{}".format(self.monotonic_time)

    def __repr__(self) -> str:
        return str(self)
class TxStream(Tx):
    """A transaction whose operations are typed interactively on stdin.

    Each line is one whitespace-separated operation, for example ``BEGIN``,
    ``R A``, ``W A 5``, ``SLEEP 0.5``, or ``COMMIT``. The stream must start
    with BEGIN and ends after COMMIT or ABORT.
    """

    def __init__(self, monotonic_time: int):
        super().__init__(monotonic_time, [])
        self._began = False
        self._ended = False

    def __iter__(self) -> Iterator[tuple]:
        return self

    def __next__(self) -> tuple:
        if self._ended:
            raise StopIteration("Tx closed")
        try:
            line = input("> ")
        except EOFError:
            # stdin closed mid-transaction: abort instead of leaving it open
            if self._began:
                self._ended = True
                return ("ABORT",)
            raise StopIteration("stdin closed")
        # split() tolerates repeated spaces; a blank line still yields ("",)
        val = line.split() or [""]
        if not self._began and val[0] != "BEGIN":
            raise StopIteration('Tx must start with "BEGIN"')
        self._began = True
        if val[0] in {"COMMIT", "ABORT"}:
            self._ended = True
        op, *args = val
        # values arrive as text; W's value and SLEEP's duration must be
        # numbers (time.sleep rejects str, and DB should hold numbers)
        if op == "W" and len(args) == 2:
            args[1] = self._parse_number(args[1])
        elif op == "SLEEP" and len(args) == 1:
            args[0] = self._parse_number(args[0])
        return (op, *args)

    @staticmethod
    def _parse_number(token: str) -> int | float | str:
        """Return ``token`` as an int, else a float, else unchanged."""
        for convert in (int, float):
            try:
                return convert(token)
            except ValueError:
                pass
        return token
def parse_args():
    """Parse the command line.

    ``--2pl`` sets ``two_phase_locking``; ``--conflict`` selects one of the
    canned demo schedules (``None`` means live mode).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--2pl",
        dest="two_phase_locking",
        action="store_true",
        help="make things serializable with 2 phase locking",
    )
    conflict_kinds = ["R-W", "W-R", "W-W", "Deadlock"]
    cli.add_argument(
        "--conflict",
        choices=conflict_kinds,
        type=str,
        required=False,
        action="store",
        help="The conflict type that can occur",
    )
    namespace = cli.parse_args()
    return namespace
def _conflict_schedules() -> dict[str, tuple[str, list[Tx]]]:
    """Return the canned demo schedules, keyed by ``--conflict`` choice, as (title, txs)."""
    return {
        "R-W": (
            "R-W conflict (unrepeatable-read)",
            [
                Tx(
                    1, [("BEGIN",), ("R", "A"), ("SLEEP", 0.2), ("R", "A"), ("COMMIT",)]
                ),
                Tx(2, [("BEGIN",), ("W", "A", 5), ("COMMIT",)]),
            ],
        ),
        "W-R": (
            "W-R conflict (dirty-read)",
            [
                Tx(1, [("BEGIN",), ("W", "B", 2), ("SLEEP", 0.2), ("COMMIT",)]),
                Tx(2, [("BEGIN",), ("SLEEP", 0.1), ("R", "B"), ("COMMIT",)]),
            ],
        ),
        "W-W": (
            "W-W conflict (lost-update)",
            [
                Tx(
                    1,
                    [
                        ("BEGIN",),
                        ("R", "A"),
                        ("W", "A", 4),
                        ("SLEEP", 0.2),
                        ("COMMIT",),
                    ],
                ),
                Tx(2, [("BEGIN",), ("R", "A"), ("W", "A", 5), ("COMMIT",)]),
            ],
        ),
        "Deadlock": (
            "Deadlock",
            [
                # Tx1 -> Tx2 -> Tx3 -> Tx1: each holds one X lock and then
                # reads the variable the next tx has locked
                Tx(
                    1,
                    [
                        ("BEGIN",),
                        ("W", "A", 4),
                        ("SLEEP", 0.2),
                        ("R", "B"),
                        ("COMMIT",),
                    ],
                ),
                Tx(
                    2,
                    [
                        ("BEGIN",),
                        ("W", "B", 5),
                        ("SLEEP", 0.2),
                        ("R", "C"),
                        ("COMMIT",),
                    ],
                ),
                Tx(
                    3,
                    [
                        ("BEGIN",),
                        ("W", "C", 5),
                        ("SLEEP", 0.2),
                        ("R", "A"),
                        ("COMMIT",),
                    ],
                ),
            ],
        ),
    }


def main():
    """Entry point: run live (stdin) mode, or a canned conflict schedule with or without 2PL."""
    args = parse_args()
    if not args.conflict:
        print(f"Live mode (2PL={args.two_phase_locking})")
        TxManager([TxStream(1)], serializable=args.two_phase_locking).run_schedule()
        return
    if args.conflict == "Deadlock" and not args.two_phase_locking:
        # without locks nothing ever waits, so no deadlock can form
        print("Deadlock detection is only valid with --2pl")
        return
    title, schedule = _conflict_schedules()[args.conflict]
    if args.two_phase_locking:
        outcome = "detection" if args.conflict == "Deadlock" else "avoided"
        print(f"{title} {outcome} with 2PL\n")
    else:
        print(f"{title} occurs without 2PL\n")
    TxManager(schedule, serializable=args.two_phase_locking).run_schedule()
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment