Skip to content

Instantly share code, notes, and snippets.

@elliotchance
Created February 19, 2017 08:23
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save elliotchance/21e31b8ffb18cbbca23b8031639e1c3f to your computer and use it in GitHub Desktop.
Implementing all four SQL transaction isolation levels in Python
# -*- coding: utf8 -*-
from __future__ import print_function
class LockManager:
    """Remembers which transaction holds a lock on which record id.

    Locks are only ever added; this class offers no way to release one.
    """

    def __init__(self):
        # One [transaction, record_id] pair per lock, in acquisition order.
        self.locks = []

    def add(self, transaction, record_id):
        """Give *transaction* a lock on *record_id* unless it already has one."""
        already_held = self.exists(transaction, record_id)
        if already_held:
            return
        self.locks.append([transaction, record_id])

    def exists(self, transaction, record_id):
        """Return True when this exact *transaction* object locks *record_id*.

        Transactions are matched by identity, record ids by equality.
        """
        for holder, locked_id in self.locks:
            if holder is transaction and locked_id == record_id:
                return True
        return False
class Table:
    """In-memory store of row versions plus the transaction bookkeeping.

    ``next_xid`` holds the most recently issued transaction id; the first
    transaction created gets xid 2.
    """

    def __init__(self):
        self.records = []           # every row version, appended in order
        self.locks = LockManager()  # locks taken by transactions on record ids
        self.active_xids = set()    # xids not yet committed or rolled back
        self.next_xid = 1

    def new_transaction(self, transaction_type):
        """Start a transaction of class *transaction_type* with a fresh xid.

        The xid is marked active before the transaction object is built, so
        the new transaction finds its own xid in ``active_xids``.
        """
        xid = self.next_xid + 1
        self.next_xid = xid
        self.active_xids.add(xid)
        return transaction_type(self, xid)
class RollbackException(Exception):
    """Signals that a write hit a row locked by another transaction.

    ``Transaction.delete_record`` raises it; the test harness treats it as
    the anomaly having been prevented.
    """
class Transaction:
    """Base class for a transaction over a ``Table``.

    Rows are never rewritten in place except for their ``expired_xid`` field:
    an update expires the old version and appends a new one.  Subclasses
    define the isolation level by providing ``record_is_visible(record)`` and
    ``record_is_locked(record)``.
    """

    def __init__(self, table, xid):
        self.table = table
        self.xid = xid
        # Undo log replayed in reverse by rollback().  Each entry is
        # [action, index into table.records]:
        #   "delete" -> expire a row this transaction appended;
        #   "add"    -> un-expire a row this transaction expired.
        self.rollback_actions = []

    def add_record(self, id, name):
        """Append a new row version created by this transaction."""
        record = {
            'id': id,
            'name': name,
            'created_xid': self.xid,
            'expired_xid': 0,
        }
        # The new row's index is the current length, captured before append.
        self.rollback_actions.append(["delete", len(self.table.records)])
        self.table.records.append(record)

    def delete_record(self, id):
        """Expire every row version with this id that is visible to us.

        Raises:
            RollbackException: a visible matching row is locked.
        """
        for i, record in enumerate(self.table.records):
            if self.record_is_visible(record) and record['id'] == id:
                if self.record_is_locked(record):
                    raise RollbackException("Row locked by another transaction.")
                record['expired_xid'] = self.xid
                self.rollback_actions.append(["add", i])

    def update_record(self, id, name):
        """Expire the visible version(s) of row *id* and append a new one."""
        self.delete_record(id)
        self.add_record(id, name)

    def fetch_record(self, id):
        """Return the first visible row whose id equals *id*, or None."""
        # Compare by value: ``is`` only happens to work for small ints that
        # the interpreter caches.
        for record in self.table.records:
            if self.record_is_visible(record) and record['id'] == id:
                return record
        return None

    def count_records(self, min_id, max_id):
        """Count visible rows with ``min_id <= id <= max_id``."""
        return sum(1 for record in self.table.records
                   if self.record_is_visible(record)
                   and min_id <= record['id'] <= max_id)

    def fetch_all_records(self):
        """Return every visible row, in table order."""
        return [record for record in self.table.records
                if self.record_is_visible(record)]

    def fetch(self, expr):
        """Return visible rows for which the predicate *expr(record)* is true."""
        return [record for record in self.table.records
                if self.record_is_visible(record) and expr(record)]

    def commit(self):
        """Mark this transaction finished.

        Only removes the xid from ``table.active_xids``; locks in
        ``table.locks`` are left in place.
        """
        self.table.active_xids.discard(self.xid)

    def rollback(self):
        """Undo this transaction's writes (newest first), then finish it."""
        for action, index in reversed(self.rollback_actions):
            if action == 'add':
                self.table.records[index]['expired_xid'] = 0
            elif action == 'delete':
                # Expiring with our own (now inactive) xid hides the row.
                self.table.records[index]['expired_xid'] = self.xid
        self.table.active_xids.discard(self.xid)
class ReadUncommittedTransaction(Transaction):
    """Lowest isolation level: dirty reads are allowed.

    A row is visible from the moment it is appended until any transaction,
    committed or not, expires it; the creating xid is never checked.
    """

    def record_is_visible(self, record):
        """Visible while no transaction has expired the row."""
        return record['expired_xid'] == 0

    def record_is_locked(self, record):
        """Locked once any transaction has expired the row."""
        return not record['expired_xid'] == 0
class ReadCommittedTransaction(Transaction):
    """Sees only rows whose creator has finished, plus its own rows.

    A row expired by another still-active transaction stays visible, because
    that expiry is not committed yet.
    """

    def record_is_locked(self, record):
        """Return True if the row was expired by a still-active transaction.

        ``delete_record`` only asks this about rows visible to us, and rows
        we expired ourselves are invisible, so in practice the expiring
        transaction is another one.
        """
        expired_xid = record['expired_xid']
        return expired_xid != 0 and expired_xid in self.table.active_xids

    def record_is_visible(self, record):
        """Return True if *record* is visible under read-committed rules."""
        # Created by another transaction that is still active: uncommitted.
        if record['created_xid'] in self.table.active_xids and \
                record['created_xid'] != self.xid:
            return False
        # Expired, and the expiring transaction has either finished or is
        # ourselves: the row is gone from our point of view.
        if record['expired_xid'] != 0 and \
                (record['expired_xid'] not in self.table.active_xids or
                 record['expired_xid'] == self.xid):
            return False
        return True
class RepeatableReadTransaction(ReadCommittedTransaction):
    """Read committed plus read locks.

    Every row this transaction sees is locked in ``table.locks``.  Another
    transaction holding such a lock prevents us from expiring that row.
    """

    def record_is_locked(self, record):
        """Return True if another transaction blocks us from expiring *record*.

        That is the case when the row is already expired by an active
        transaction, or when a *different* transaction holds a lock on its
        id.  Our own lock must not count: ``record_is_visible`` takes it just
        before ``delete_record`` asks this question.  Counting it would make
        every visible row unwritable to its own reader.
        """
        if ReadCommittedTransaction.record_is_locked(self, record):
            return True
        # NOTE(review): commit()/rollback() never remove entries from
        # table.locks, so a lock persists after its holder finishes.
        return any(holder is not self and locked_id == record['id']
                   for holder, locked_id in self.table.locks.locks)

    def record_is_visible(self, record):
        """Read-committed visibility; every visible row gets locked by us."""
        is_visible = ReadCommittedTransaction.record_is_visible(self, record)
        if is_visible:
            self.table.locks.add(self, record['id'])
        return is_visible
class SerializableTransaction(RepeatableReadTransaction):
    """Snapshot visibility on top of repeatable read's locks.

    The snapshot is fixed when the transaction starts.  Only rows created by
    transactions that had already finished at that moment, or by this
    transaction itself, are visible.  Expiries by transactions that had not
    finished by then are ignored.  This also hides rows inserted later
    (phantoms).
    """

    def __init__(self, table, xid):
        # Explicit base call: classes here are old-style under Python 2.
        Transaction.__init__(self, table, xid)
        # Transactions in flight when we started.  When created through
        # Table.new_transaction this also contains our own xid.
        self.existing_xids = self.table.active_xids.copy()

    def _finished_before_start(self, xid):
        """Return True if *xid* had committed or rolled back before we began."""
        # Table.new_transaction hands out increasing xids, so a smaller xid
        # that was not active at our start must already have finished.
        return xid < self.xid and xid not in self.existing_xids

    def record_is_visible(self, record):
        """Return True if *record* is in our snapshot; if so, lock its id."""
        created_xid = record['created_xid']
        expired_xid = record['expired_xid']
        # Rows created by us, or by transactions finished before we started.
        if created_xid != self.xid and \
                not self._finished_before_start(created_xid):
            return False
        # Rows expired by us, or by a transaction finished before we started.
        # (A rolled-back insert is expired with its own xid, so it stays
        # hidden.)
        if expired_xid == self.xid or \
                (expired_xid != 0 and self._finished_before_start(expired_xid)):
            return False
        self.table.locks.add(self, record['id'])
        return True
class TransactionTest:
    """Fixture for one anomaly scenario at a given isolation level.

    Seeds a fresh table with committed rows 1 ("Joe") and 3 ("Jill").  It then
    opens two transactions of *transaction_type*: ``client1`` and
    ``client2``.  Subclasses implement ``run()``, returning a truthy value
    when the anomaly was observed.
    """

    def __init__(self, transaction_type):
        self.table = Table()
        seeder = self.table.new_transaction(ReadCommittedTransaction)
        for record_id, record_name in ((1, "Joe"), (3, "Jill")):
            seeder.add_record(id=record_id, name=record_name)
        seeder.commit()
        self.client1 = self.table.new_transaction(transaction_type)
        self.client2 = self.table.new_transaction(transaction_type)

    def run_test(self):
        """Run the scenario; a RollbackException counts as "not observed"."""
        try:
            outcome = self.run()
        except RollbackException:
            outcome = False
        return outcome

    def result(self):
        """Return a check mark if the anomaly was observed, else a cross."""
        return u'✔' if self.run_test() else u'✘'
class DirtyRead(TransactionTest):
    """Dirty read: client1 sees client2's uncommitted update."""

    def run(self):
        """Return True if client1's two reads of id 1 differ."""
        reader, writer = self.client1, self.client2
        first = reader.fetch_record(id=1)
        writer.update_record(id=1, name="Joe 2")
        second = reader.fetch_record(id=1)
        return first != second
class NonRepeatableRead(TransactionTest):
    """Non-repeatable read: a committed update changes client1's re-read."""

    def run(self):
        """Return True if client1's two reads of id 1 differ."""
        reader, writer = self.client1, self.client2
        first = reader.fetch_record(id=1)
        writer.update_record(id=1, name="Joe 2")
        writer.commit()
        second = reader.fetch_record(id=1)
        return first != second
class PhantomRead(TransactionTest):
    """Phantom read: a committed insert shows up in a range already read."""

    def run(self):
        """Return True if client1 sees a different row count in ids 1..3."""
        def in_range(row):
            return 1 <= row['id'] <= 3

        before = len(self.client1.fetch(in_range))
        self.client2.add_record(id=2, name="John")
        self.client2.commit()
        after = self.client1.count_records(min_id=1, max_id=3)
        return before != after
isolation_modes = [
    ['read uncommitted', ReadUncommittedTransaction],
    ['read committed ', ReadCommittedTransaction],
    ['repeatable read ', RepeatableReadTransaction],
    ['serializable ', SerializableTransaction]
]
possible_errors = [DirtyRead, NonRepeatableRead, PhantomRead]
# Column headings, one per entry of possible_errors, in the same order.
error_headings = ['Dirty', 'Repeat', 'Phantom']

# One row per isolation level and one column per anomaly.  A check mark means
# the anomaly was observed; a cross means it was not (or the write was
# rolled back).  Widths come from the data rather than hand-typed padding, so
# the marks always line up under their headings.
label_width = max(len(label.rstrip()) for label, _ in isolation_modes)
print(' '.join([' ' * label_width] + error_headings))
for label, transaction_type in isolation_modes:
    results = [possible_error(transaction_type).result()
               for possible_error in possible_errors]
    cells = [mark.ljust(len(heading))
             for mark, heading in zip(results, error_headings)]
    print(' '.join([label.rstrip().ljust(label_width)] + cells).rstrip())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment