Created
December 7, 2017 02:11
-
-
Save bbengfort/936b4b3db9d81d27204a81f6ad816e5d to your computer and use it in GitHub Desktop.
Database transactions blog post.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import logging | |
import psycopg2 as pg | |
from decimal import Decimal | |
from functools import wraps | |
from psycopg2.pool import ThreadedConnectionPool | |
from contextlib import contextmanager | |
# Root log line layout: timestamp, then the message text.
LOG_FORMAT = "%(asctime)s %(message)s"

# Ceiling on the sum of one account's "credit" ledger rows for a single day;
# compared against in check_daily_deposit().
MAX_DEPOSIT_LIMIT = 1000.00

logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
log = logging.getLogger("balance")
def connect(env="DATABASE_URL", connections=2):
    """
    Build a thread-safe psycopg2 connection pool from an environment variable.

    :param env: name of the environment variable holding the database URL.
    :param connections: minimum number of pooled connections; the pool may
        grow to twice this many.
    :returns: a ``ThreadedConnectionPool`` connected to that URL.
    :raises ValueError: if the variable is unset or empty.
    """
    database_url = os.getenv(env)
    if database_url:
        return ThreadedConnectionPool(connections, 2 * connections, database_url)
    raise ValueError("no database url specified")
# Shared connection pool used by transaction(). It is created at import time,
# so importing this module raises ValueError unless the DATABASE_URL
# environment variable is set (see connect()).
pool = connect()
def createdb(conn, schema="schema.sql"):
    """
    Run the SQL in a schema file against ``conn`` and commit it.

    The whole file is read and sent in a single ``execute()`` call. If
    anything fails, the transaction is rolled back and the original exception
    propagates.

    :param conn: an open psycopg2 connection.
    :param schema: path of the SQL file to run (default ``schema.sql``).
    """
    with open(schema, 'r') as schema_file:
        statements = schema_file.read()

    try:
        with conn.cursor() as cursor:
            cursor.execute(statements)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
@contextmanager
def transaction(name="transaction", **kwargs):
    """
    Check out a pooled connection and run the ``with`` body as one transaction.

    The connection is configured with ``conn.set_session()`` using the
    optional ``isolation_level``, ``readonly`` and ``deferrable`` keyword
    arguments (each defaults to None). The transaction is committed when the
    body finishes. If the body or the commit raises, the transaction is
    rolled back and the error is logged under ``name`` rather than
    propagated. The connection is always reset and handed back to the pool.

    :param name: label used as the prefix of the error log message.
    """
    # Get the session parameters from the kwargs
    options = {
        "isolation_level": kwargs.get("isolation_level", None),
        "readonly": kwargs.get("readonly", None),
        "deferrable": kwargs.get("deferrable", None),
    }

    # Acquire outside the try. If the pool cannot hand out a connection,
    # there is nothing to roll back or return. Referencing an unbound ``conn``
    # in the handlers below would raise UnboundLocalError and hide the real
    # error.
    conn = pool.getconn()
    try:
        conn.set_session(**options)
        yield conn
        conn.commit()
    except Exception as e:
        conn.rollback()
        log.error("{} error: {}".format(name, e))
    finally:
        try:
            conn.reset()
        finally:
            # Give the connection back even if reset() fails, so it is not
            # leaked from the pool.
            pool.putconn(conn)
def transact(func):
    """
    Decorator: run ``func`` in its own pooled transaction.

    The wrapped function receives the connection yielded by ``transaction()``
    as its first argument, followed by whatever the caller passed. The
    transaction is named after ``func`` for logging. It is committed on
    success and rolled back on error, and the connection is then returned to
    the pool rather than closed. The wrapper always returns None.
    """
    def inner(*args, **kwargs):
        with transaction(name=func.__name__) as conn:
            func(conn, *args, **kwargs)

    return wraps(func)(inner)
def authenticate(conn, user, pin, account=None):
    """
    Check a user's PIN and, optionally, that the user owns an account.

    :param conn: an open psycopg2 connection.
    :param user: username to look up in ``users``.
    :param pin: PIN that must match the stored ``users.pin``.
    :param account: optional account id. When given, ownership is checked
        with ``verify_account``.
    :returns: True when every check passes.
    :raises ValueError: if the username/PIN pair does not match, or (from
        ``verify_account``) the account does not belong to the user.
    """
    with conn.cursor() as curs:
        sql = "SELECT 1 AS authd FROM users WHERE username=%s AND pin=%s"
        curs.execute(sql, (user, pin))
        if curs.fetchone() is None:
            raise ValueError("could not validate user via PIN")

    # Verify account ownership before reporting success. Returning first would
    # skip this check and let any authenticated user touch any account.
    if account:
        verify_account(conn, user, account)
    return True
def verify_account(conn, user, account):
    """
    Confirm that account ``account`` is owned by the user named ``user``.

    :returns: True if an ``accounts`` row with that id joins to that user.
    :raises ValueError: if no such row exists.
    """
    query = (
        "SELECT 1 AS verified FROM accounts a "
        "JOIN users u on u.id = a.owner_id "
        "WHERE u.username=%s AND a.id=%s"
    )
    with conn.cursor() as cursor:
        cursor.execute(query, (user, account))
        found = cursor.fetchone() is not None

    if not found:
        raise ValueError("account belonging to user not found")
    return True
def ledger(conn, account, record, amount):
    """
    Insert a ``ledger`` row recording a credit or debit of ``amount``.

    After a ``"credit"`` row is written, ``check_daily_deposit`` runs so that
    the day's deposit limit is evaluated including this deposit.

    :param account: id of the account the entry belongs to.
    :param record: ledger entry type, ``"credit"`` or ``"debit"``.
    :param amount: value stored in ``ledger.amount``.
    """
    insert = "INSERT INTO ledger (account_id, type, amount) VALUES (%s, %s, %s)"
    with conn.cursor() as cursor:
        cursor.execute(insert, (account, record, amount))

    if record != "credit":
        return
    check_daily_deposit(conn, account)
def check_daily_deposit(conn, account):
    """
    Enforce the per-account daily deposit limit.

    Sums today's ``credit`` ledger rows for ``account`` and raises when the
    total exceeds ``MAX_DEPOSIT_LIMIT``.

    :raises ValueError: if the daily deposit limit has been exceeded.
        ValueError subclasses Exception, so handlers catching Exception still
        see it.
    """
    # Let the database aggregate instead of shipping every row back.
    # COALESCE turns "no credits today" into 0 rather than NULL.
    sql = (
        "SELECT COALESCE(SUM(amount), 0) FROM ledger "
        "WHERE date=now()::date AND type='credit' AND account_id=%s"
    )
    with conn.cursor() as curs:
        curs.execute(sql, (account,))
        total = curs.fetchone()[0]

    if total > MAX_DEPOSIT_LIMIT:
        raise ValueError("daily deposit limit has been exceeded!")
def update_balance(conn, account, amount):
    """
    Add ``amount`` to an account's balance (a negative amount subtracts).

    :param conn: an open psycopg2 connection.
    :param account: id of the ``accounts`` row to change.
    :param amount: a number or numeric string, converted to Decimal.
    :raises ValueError: if no account with that id exists.
    """
    # Decimal(str(x)) keeps the value the caller wrote. Decimal(220.23) would
    # capture the binary float's full expansion (220.2300000000000039...) and
    # store it in the NUMERIC column.
    amount = Decimal(str(amount))
    with conn.cursor() as curs:
        # Apply the change in a single UPDATE so concurrent transactions cannot
        # overwrite each other. A separate SELECT followed by writing back the
        # computed value allows lost updates.
        curs.execute(
            "UPDATE accounts SET balance = balance + %s WHERE id = %s",
            (amount, account),
        )
        if curs.rowcount == 0:
            raise ValueError("account {} not found".format(account))
def balance(conn, account):
    """
    Return the current balance of an account.

    :param conn: an open psycopg2 connection.
    :param account: id of the ``accounts`` row.
    :returns: the ``accounts.balance`` value as returned by the driver.
    :raises ValueError: if no account with that id exists.
    """
    with conn.cursor() as curs:
        curs.execute("SELECT balance FROM accounts WHERE id=%s", (account,))
        row = curs.fetchone()

    if row is None:
        # Report a missing account explicitly instead of failing with
        # "'NoneType' object is not subscriptable".
        raise ValueError("account {} not found".format(account))
    return row[0]
@transact
def withdraw(conn, user, pin, account, amount):
    """
    Debit ``amount`` from ``account`` inside a single transaction.

    ``conn`` is injected by the ``transact`` decorator, so callers pass only
    ``(user, pin, account, amount)``. The user is authenticated (the account
    id is passed to ``authenticate`` for the ownership check), a "debit"
    ledger row is written, and the balance is reduced. The resulting balance
    is then logged.
    """
    authenticate(conn, user, pin, account)
    ledger(conn, account, "debit", amount)
    update_balance(conn, account, -1 * amount)

    message = "withdraw ${:0.2f} from account {} | current balance: ${:0.2f}"
    current = balance(conn, account)
    log.info(message.format(amount, account, current))
@transact
def deposit(conn, user, pin, account, amount):
    """
    Credit ``amount`` to ``account`` inside a single transaction.

    ``conn`` is injected by the ``transact`` decorator, so callers pass only
    ``(user, pin, account, amount)``. If any step raises (bad PIN, daily
    deposit limit, constraint violation), the transaction wrapper rolls the
    whole deposit back.
    """
    # Step 1: authenticate the user via PIN (account is passed for the
    # ownership check)
    authenticate(conn, user, pin, account)
    # Step 2: add the ledger record with the credit; ledger() also runs the
    # daily deposit limit check for credits
    ledger(conn, account, "credit", amount)
    # Step 3: update the account value by adding the amount
    update_balance(conn, account, amount)
    # Fetch the current balance in the account and log it
    record = "deposit ${:0.2f} to account {} | current balance: ${:0.2f}"
    log.info(record.format(amount, account, balance(conn, account)))
def sequential():
    """
    Run a fixed series of deposits and withdrawals one after another.

    The comments give the outcome expected against a freshly seeded database
    (schema.sql seeds alice as owner of accounts 1 and 2, and bob as owner of
    accounts 3 and 4). Failed operations are logged and rolled back by the
    ``transact`` wrapper rather than raised.
    """
    # Successful deposit
    deposit('alice', 1234, 1, 785.0)
    # Unsuccessful authenticate: bob's PIN is 9999, not 8881
    withdraw('bob', 8881, 2, 180.00)
    # Successful withdrawal
    withdraw('alice', 1234, 1, 230.0)
    # Unsuccessful deposit: 785.00 + 489.00 of credits today exceeds
    # MAX_DEPOSIT_LIMIT
    deposit('alice', 1234, 1, 489.0)
    # Successful deposit into bob's own checking account. Account 2 belongs
    # to alice, so a deposit there by bob must fail the ownership check.
    deposit('bob', 9999, 3, 220.23)
if __name__ == '__main__':
    import time
    import random
    import threading

    # Rebuild the schema and seed data with a connection borrowed from the pool.
    conn = pool.getconn()
    createdb(conn)
    pool.putconn(conn)

    def op1():
        # One withdrawal after a random delay of under a second.
        time.sleep(random.random())
        withdraw('alice', 1234, 1, 300.0)

    def op2():
        # A deposit and then a withdrawal on the same account, after its own
        # random delay, running concurrently with op1.
        time.sleep(random.random())
        deposit('alice', 1234, 1, 75.0)
        withdraw('alice', 1234, 1, 25.0)

    threads = [threading.Thread(target=worker) for worker in (op1, op2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | |
-- Reset the database by dropping existing tables, then create the tables | |
-- required by the balance.py Python script. | |
-- | |
DROP TABLE IF EXISTS users CASCADE;
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    -- NOT NULL: UNIQUE alone still admits any number of NULL usernames, and
    -- users are looked up by username when authenticating.
    username VARCHAR(255) NOT NULL UNIQUE,
    pin SMALLINT NOT NULL
);
DROP TYPE IF EXISTS account_type CASCADE;
CREATE TYPE account_type AS ENUM ('checking', 'savings');
DROP TABLE IF EXISTS accounts CASCADE;
CREATE TABLE accounts (
    id SERIAL PRIMARY KEY,
    type account_type,
    owner_id INTEGER NOT NULL,
    -- NOT NULL is required: a CHECK constraint is satisfied when its
    -- expression is NULL, so a NULL balance would slip past positive_balance.
    balance NUMERIC NOT NULL DEFAULT 0.0,
    CONSTRAINT positive_balance CHECK (balance >= 0),
    FOREIGN KEY (owner_id) REFERENCES users (id)
);
DROP TYPE IF EXISTS ledger_type CASCADE;
CREATE TYPE ledger_type AS ENUM ('credit', 'debit');
DROP TABLE IF EXISTS ledger;
CREATE TABLE ledger (
    id SERIAL PRIMARY KEY,
    account_id INTEGER NOT NULL,
    date DATE NOT NULL DEFAULT CURRENT_DATE,
    type ledger_type NOT NULL,
    -- The direction of an entry is carried by "type". A negative amount
    -- would turn a "credit" into a withdrawal and lower the day's credit
    -- total used for the deposit limit, so amounts must be positive.
    amount NUMERIC NOT NULL,
    CONSTRAINT positive_amount CHECK (amount > 0),
    FOREIGN KEY (account_id) REFERENCES accounts (id)
);
--
-- Seed the database with some owner and account information
--
INSERT INTO users (id, username, pin) VALUES
    (1, 'alice', 1234),
    (2, 'bob', 9999);
INSERT INTO accounts (id, type, owner_id, balance) VALUES
    (1, 'checking', 1, 250.0),
    (2, 'savings', 1, 5.00),
    (3, 'checking', 2, 100.0),
    (4, 'savings', 2, 2342.13);
-- Inserting explicit ids does not advance the SERIAL sequences. Move them
-- past the seeded ids so later inserts that rely on the default id do not
-- collide with id 1.
SELECT setval(pg_get_serial_sequence('users', 'id'), (SELECT MAX(id) FROM users));
SELECT setval(pg_get_serial_sequence('accounts', 'id'), (SELECT MAX(id) FROM accounts));
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import psycopg2 as pg

from balance import connect, createdb
def verify_transaction(conn, context=False):
    """
    Run the failed-statement demonstration on ``conn``.

    :param context: True to use cursors as ``with`` context managers, False
        to create them directly with ``conn.cursor()``.
    """
    runner = (
        verify_transaction_with_context if context
        else verify_transaction_no_context
    )
    return runner(conn)
def verify_transaction_with_context(conn):
    """
    Show how a failed statement affects later ones in the same transaction,
    using cursors as context managers.

    An UPDATE setting every balance negative is run first and its error is
    printed (the ``positive_balance`` CHECK is expected to reject it). A
    SELECT is then run on a new cursor in the same transaction; an
    ``InternalError`` from it (expected, since the transaction was aborted)
    is printed as well.
    """
    # First statement: intentionally violates the balance constraint.
    with conn.cursor() as cursor:
        try:
            cursor.execute("UPDATE accounts SET balance=%s", (-130.935,))
        except Exception as err:
            print(err)  # Constraint exception

    # Second statement: a fresh cursor on the same connection/transaction.
    with conn.cursor() as cursor:
        try:
            cursor.execute("SELECT id, type FROM accounts WHERE owner_id=%s", (1,))
        except pg.InternalError as err:
            print(err)
def verify_transaction_no_context(conn):
    """
    Show how a failed statement affects later ones in the same transaction,
    creating cursors directly with ``conn.cursor()`` (no ``with`` block).

    The UPDATE's error is printed (the ``positive_balance`` CHECK is
    expected to reject it). Then any ``InternalError`` raised by the
    follow-up SELECT on a second cursor is printed.
    """
    first = conn.cursor()
    try:
        # Negative balance: expected to violate the positive_balance CHECK.
        first.execute("UPDATE accounts SET balance=%s", (-130.935,))
    except Exception as err:
        print(err)  # Constraint exception

    # A new cursor still shares the connection's transaction, so this is
    # expected to fail because the previous statement aborted it.
    second = conn.cursor()
    try:
        second.execute("SELECT id, type FROM accounts WHERE owner_id=%s", (1,))
    except pg.InternalError as err:
        print(err)
if __name__ == '__main__':
    # connect() returns a ThreadedConnectionPool, not a connection. Check a
    # connection out of the pool for createdb() and the demo, then return it
    # and close the pool.
    pool = connect()
    conn = pool.getconn()
    try:
        createdb(conn)
        # Verify transactions
        verify_transaction(conn, True)
    finally:
        pool.putconn(conn)
        pool.closeall()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment