Skip to content

Instantly share code, notes, and snippets.

@bbengfort
Created December 7, 2017 02:11
Show Gist options
  • Save bbengfort/936b4b3db9d81d27204a81f6ad816e5d to your computer and use it in GitHub Desktop.
Save bbengfort/936b4b3db9d81d27204a81f6ad816e5d to your computer and use it in GitHub Desktop.
Database transactions blog post.
#!/usr/bin/env python3
import os
import logging
import psycopg2 as pg
from decimal import Decimal
from functools import wraps
from psycopg2.pool import ThreadedConnectionPool
from contextlib import contextmanager
# Each log line carries only a timestamp followed by the message text.
LOG_FORMAT = "%(asctime)s %(message)s"

# Ceiling on the sum of one account's 'credit' ledger rows for a single day;
# checked by check_daily_deposit().
MAX_DEPOSIT_LIMIT = 1000.00

logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
log = logging.getLogger("balance")
def connect(env="DATABASE_URL", connections=2):
    """
    Build a threaded connection pool from a DSN held in an environment variable.

    The pool is created with ``connections`` as its minimum size and twice
    that as its maximum. Raises ValueError when the variable is unset or empty.
    """
    dsn = os.getenv(env)
    if dsn:
        return ThreadedConnectionPool(connections, connections * 2, dsn)
    raise ValueError("no database url specified")
# Shared connection pool used by transaction(). It is created at import time,
# so merely importing this module requires DATABASE_URL to be set (connect()
# raises ValueError otherwise).
pool = connect()
def createdb(conn, schema="schema.sql"):
    """
    Run every statement in the SQL file ``schema`` on ``conn`` and commit.

    The file is read in full and sent as a single execute() call. On any
    error (including a failed commit) the connection is rolled back and the
    original exception is re-raised.
    """
    with open(schema, 'r') as handle:
        statements = handle.read()

    try:
        with conn.cursor() as cursor:
            cursor.execute(statements)
        conn.commit()
    except Exception:
        # Leave the connection usable for the caller before propagating.
        conn.rollback()
        raise
@contextmanager
def transaction(name="transaction", **kwargs):
    """
    Check out a pooled connection and run the enclosed block as one transaction.

    Keyword arguments ``isolation_level``, ``readonly`` and ``deferrable`` are
    forwarded to ``conn.set_session``; any other keyword is ignored. The
    connection is committed if the block finishes and rolled back if it
    raises. Such errors are logged under ``name`` and NOT re-raised, so the
    caller's ``with`` statement completes normally after a failure. The
    connection is always reset and returned to the module-level ``pool``.

    Raises:
        Whatever ``pool.getconn()`` raises if no connection can be obtained.
    """
    # Get the session parameters from the kwargs
    options = {
        "isolation_level": kwargs.get("isolation_level", None),
        "readonly": kwargs.get("readonly", None),
        "deferrable": kwargs.get("deferrable", None),
    }

    # Acquire outside the try: if the pool cannot hand out a connection there
    # is nothing to roll back or return, and `conn` would otherwise be unbound
    # in the handlers below (masking the real error with UnboundLocalError).
    conn = pool.getconn()
    try:
        conn.set_session(**options)
        yield conn
        conn.commit()
    except Exception as e:
        conn.rollback()
        log.error("{} error: {}".format(name, e))
    finally:
        try:
            conn.reset()
        finally:
            # Return the connection even if reset() fails so it isn't leaked.
            pool.putconn(conn)
def transact(func):
    """
    Decorator that runs ``func`` inside its own ``transaction``.

    A pooled connection is checked out per call and passed as the first
    positional argument, ahead of the caller's own arguments. The transaction
    commits if ``func`` returns normally and rolls back if it raises; either
    way the connection is reset and handed back to the pool (it is not
    closed). The wrapper itself always returns None.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        txn_name = func.__name__
        with transaction(name=txn_name) as conn:
            func(conn, *args, **kwargs)

    return wrapper
def authenticate(conn, user, pin, account=None):
    """
    Check a user's PIN and, when ``account`` is given, that they own it.

    Returns True on success.

    Raises:
        ValueError: no user matches ``user``/``pin``, or ``account`` is given
            and is not owned by ``user``.
    """
    with conn.cursor() as curs:
        sql = "SELECT 1 AS authd FROM users WHERE username=%s AND pin=%s"
        curs.execute(sql, (user, pin))
        if curs.fetchone() is None:
            raise ValueError("could not validate user via PIN")

    # The ownership check must run after the PIN check. Previously an early
    # `return True` made it unreachable, so any authenticated user could act
    # on any account.
    if account is not None:
        verify_account(conn, user, account)
    return True
def verify_account(conn, user, account):
    """
    Confirm that account id ``account`` is owned by the user named ``user``.

    Returns True when such an owner/account pair exists; otherwise raises
    ValueError.
    """
    query = (
        "SELECT 1 AS verified FROM accounts a "
        "JOIN users u on u.id = a.owner_id "
        "WHERE u.username=%s AND a.id=%s"
    )
    with conn.cursor() as curs:
        curs.execute(query, (user, account))
        owned = curs.fetchone() is not None

    if not owned:
        raise ValueError("account belonging to user not found")
    return True
def ledger(conn, account, record, amount):
    """
    Append a ledger row of type ``record`` for ``amount`` against ``account``.

    ``record`` is expected to be a value of the schema's ledger_type enum
    ('credit' or 'debit'). Credits additionally trigger the daily deposit
    check, which raises if the limit is exceeded.
    """
    insert = "INSERT INTO ledger (account_id, type, amount) VALUES (%s, %s, %s)"
    with conn.cursor() as curs:
        curs.execute(insert, (account, record, amount))

    # Only credits count toward the daily deposit cap.
    if record != "credit":
        return
    check_daily_deposit(conn, account)
def check_daily_deposit(conn, account):
    """
    Raise if today's credits to ``account`` exceed MAX_DEPOSIT_LIMIT.

    Totals every ledger row of type 'credit' dated today (by the database
    clock) for the account, including rows written earlier in the current
    transaction.

    Raises:
        ValueError: the daily deposit limit has been exceeded. (A subclass of
            Exception, so existing ``except Exception`` handlers still apply.)
    """
    with conn.cursor() as curs:
        # Aggregate in the database instead of fetching every row; COALESCE
        # turns "no credits today" into 0 rather than NULL.
        sql = (
            "SELECT COALESCE(SUM(amount), 0) FROM ledger "
            "WHERE date=now()::date AND type='credit' AND account_id=%s"
        )
        curs.execute(sql, (account,))
        total = curs.fetchone()[0]

    if total > MAX_DEPOSIT_LIMIT:
        raise ValueError("daily deposit limit has been exceeded!")
def update_balance(conn, account, amount):
    """
    Add ``amount`` to the account balance (subtract it if negative).

    The adjustment is a single UPDATE evaluated by the database, so two
    transactions adjusting the same account cannot overwrite each other's
    change (the earlier SELECT-then-UPDATE could lose an update when the
    threaded demo ran deposits and withdrawals concurrently).

    Raises:
        ValueError: no account with id ``account`` exists.
    """
    # Convert through str() so a float such as 220.23 becomes Decimal('220.23')
    # instead of its full binary expansion (220.2299999...), which a NUMERIC
    # column would otherwise store verbatim.
    amount = Decimal(str(amount))
    with conn.cursor() as curs:
        curs.execute(
            "UPDATE accounts SET balance = balance + %s WHERE id = %s",
            (amount, account),
        )
        if curs.rowcount == 0:
            raise ValueError("account {} not found".format(account))
def balance(conn, account):
    """
    Return the stored balance of account id ``account``.

    Raises:
        ValueError: no account with id ``account`` exists.
    """
    with conn.cursor() as curs:
        curs.execute("SELECT balance FROM accounts WHERE id=%s", (account,))
        row = curs.fetchone()

    if row is None:
        # Previously surfaced as an opaque "'NoneType' object is not
        # subscriptable" TypeError.
        raise ValueError("account {} not found".format(account))
    return row[0]
@transact
def withdraw(conn, user, pin, account, amount):
    """
    Withdraw ``amount`` from ``account`` on behalf of ``user``.

    ``conn`` is supplied by @transact, so callers pass only
    (user, pin, account, amount). The whole withdrawal runs in one
    transaction; any error rolls it back and is logged rather than raised.
    """
    # Reject non-positive amounts: a negative withdrawal would otherwise act
    # as a deposit that never goes through the daily deposit check (only
    # 'credit' ledger rows are checked).
    if amount <= 0:
        raise ValueError("withdrawal amount must be positive")

    # Step 1: authenticate the user via pin and verify account ownership
    authenticate(conn, user, pin, account)
    # Step 2: add the ledger record with the debit
    ledger(conn, account, "debit", amount)
    # Step 3: update the account value by subtracting the amount
    update_balance(conn, account, amount * -1)
    # Fetch the current balance in the account and log it
    record = "withdraw ${:0.2f} from account {} | current balance: ${:0.2f}"
    log.info(record.format(amount, account, balance(conn, account)))
@transact
def deposit(conn, user, pin, account, amount):
    """
    Deposit ``amount`` into ``account`` on behalf of ``user``.

    ``conn`` is supplied by @transact, so callers pass only
    (user, pin, account, amount). The whole deposit runs in one transaction;
    any error (including exceeding the daily deposit limit) rolls it back and
    is logged rather than raised.
    """
    # Reject non-positive amounts: a negative credit would lower today's
    # credit total and let later deposits slip past the daily limit.
    if amount <= 0:
        raise ValueError("deposit amount must be positive")

    # Step 1: authenticate the user via pin and verify account ownership
    authenticate(conn, user, pin, account)
    # Step 2: add the ledger record with the credit
    ledger(conn, account, "credit", amount)
    # Step 3: update the account value by adding the amount
    update_balance(conn, account, amount)
    # Fetch the current balance in the account and log it (the message
    # previously said "withdraw ... from" due to a copy-paste error)
    record = "deposit ${:0.2f} into account {} | current balance: ${:0.2f}"
    log.info(record.format(amount, account, balance(conn, account)))
def sequential():
    """
    Run a fixed series of operations one after another.

    The comments give the expected outcome of each call against the seed
    data loaded by createdb(): alice owns accounts 1 and 2, bob owns
    accounts 3 and 4.
    """
    # Successful deposit
    deposit('alice', 1234, 1, 785.0)
    # Unsuccessful authenticate (wrong PIN for bob, on his own account)
    withdraw('bob', 8881, 3, 180.00)
    # Successful withdrawal
    withdraw('alice', 1234, 1, 230.0)
    # Unsuccessful deposit (785 + 489 exceeds the daily deposit limit)
    deposit('alice', 1234, 1, 489.0)
    # Successful deposit into bob's own account. Account 2 belongs to alice,
    # so account-ownership verification would reject it.
    deposit('bob', 9999, 3, 220.23)
if __name__ == '__main__':
    import random
    import threading
    import time

    # Recreate the schema and seed rows on a connection borrowed from the
    # module-level pool, then hand it back.
    setup_conn = pool.getconn()
    createdb(setup_conn)
    pool.putconn(setup_conn)

    def op1():
        # A single withdrawal from alice's account 1 after a random delay.
        time.sleep(random.random())
        withdraw('alice', 1234, 1, 300.0)

    def op2():
        # A deposit followed by a withdrawal on the same account, after a
        # random delay.
        time.sleep(random.random())
        deposit('alice', 1234, 1, 75.0)
        withdraw('alice', 1234, 1, 25.0)

    # Run both operations concurrently; the random sleeps vary their
    # interleaving from run to run.
    workers = [threading.Thread(target=op) for op in (op1, op2)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
--
-- Reset the database by dropping existing tables, then create the tables
-- required by the balance.py Python script.
--
DROP TABLE IF EXISTS users CASCADE;
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    username VARCHAR(255) UNIQUE,
    pin SMALLINT NOT NULL
);

DROP TYPE IF EXISTS account_type CASCADE;
CREATE TYPE account_type AS ENUM ('checking', 'savings');

DROP TABLE IF EXISTS accounts CASCADE;
CREATE TABLE accounts (
    id SERIAL PRIMARY KEY,
    type account_type,
    owner_id INTEGER NOT NULL,
    -- NOT NULL: a NULL balance would otherwise pass the CHECK below.
    balance NUMERIC NOT NULL DEFAULT 0.0,
    CONSTRAINT positive_balance CHECK (balance >= 0),
    FOREIGN KEY (owner_id) REFERENCES users (id)
);

DROP TYPE IF EXISTS ledger_type CASCADE;
CREATE TYPE ledger_type AS ENUM ('credit', 'debit');

DROP TABLE IF EXISTS ledger;
CREATE TABLE ledger (
    id SERIAL PRIMARY KEY,
    account_id INTEGER NOT NULL,
    date DATE NOT NULL DEFAULT CURRENT_DATE,
    type ledger_type NOT NULL,
    -- Direction is carried by `type`; the amount itself must be positive so a
    -- negative "credit" cannot reduce the daily deposit total.
    amount NUMERIC NOT NULL,
    CONSTRAINT positive_amount CHECK (amount > 0),
    FOREIGN KEY (account_id) REFERENCES accounts (id)
);

--
-- Seed the database with some owner and account information
--
INSERT INTO users (id, username, pin) VALUES
    (1, 'alice', 1234),
    (2, 'bob', 9999);

INSERT INTO accounts (id, type, owner_id, balance) VALUES
    (1, 'checking', 1, 250.0),
    (2, 'savings', 1, 5.00),
    (3, 'checking', 2, 100.0),
    (4, 'savings', 2, 2342.13);

-- Inserting explicit ids does not advance the SERIAL sequences; move them
-- past the seeded ids so later inserts using the default id do not collide.
SELECT setval(pg_get_serial_sequence('users', 'id'), (SELECT MAX(id) FROM users));
SELECT setval(pg_get_serial_sequence('accounts', 'id'), (SELECT MAX(id) FROM accounts));
from balance import connect, createdb
def verify_transaction(conn, context=False):
    """
    Run the failed-statement demo on ``conn``, choosing the cursor style.

    A truthy ``context`` selects the variant that opens cursors in ``with``
    blocks; otherwise cursors are created directly. Returns whatever the
    selected variant returns.
    """
    runner = (
        verify_transaction_with_context if context
        else verify_transaction_no_context
    )
    return runner(conn)
def verify_transaction_with_context(conn):
    """
    Show that one failed statement poisons the rest of the transaction.

    The UPDATE sets every balance negative, violating the ``positive_balance``
    CHECK constraint; the follow-up SELECT on the same connection then fails
    too, because PostgreSQL refuses further commands in an aborted
    transaction until it is rolled back. Both errors are printed, not raised.
    Cursors are opened with ``with`` blocks.
    """
    # `pg` was referenced below but never imported by this script, so the
    # except clause itself raised NameError once the SELECT failed.
    import psycopg2 as pg

    with conn.cursor() as curs:
        # Execute a command that violates a constraint
        try:
            curs.execute("UPDATE accounts SET balance=%s", (-130.935,))
        except Exception as e:
            print(e)  # Constraint exception
    with conn.cursor() as curs:
        # Fails because the transaction is now aborted
        try:
            curs.execute("SELECT id, type FROM accounts WHERE owner_id=%s", (1,))
        except pg.InternalError as e:
            print(e)
def verify_transaction_no_context(conn):
    """
    Show that one failed statement poisons the rest of the transaction.

    The UPDATE sets every balance negative, violating the ``positive_balance``
    CHECK constraint; the follow-up SELECT on the same connection then fails
    too, because PostgreSQL refuses further commands in an aborted
    transaction until it is rolled back. Both errors are printed, not raised.
    Cursors are created directly and not explicitly closed.
    """
    # `pg` was referenced below but never imported by this script, so the
    # except clause itself raised NameError once the SELECT failed.
    import psycopg2 as pg

    curs = conn.cursor()
    try:
        # Execute a command that will violate a constraint
        curs.execute("UPDATE accounts SET balance=%s", (-130.935,))
    except Exception as e:
        print(e)  # Constraint exception
    # Execute another command; because of the previous exception the
    # transaction is aborted and this fails as well.
    curs = conn.cursor()
    try:
        curs.execute("SELECT id, type FROM accounts WHERE owner_id=%s", (1,))
    except pg.InternalError as e:
        print(e)
if __name__ == '__main__':
    # connect() returns a ThreadedConnectionPool, not a connection; calling
    # createdb() on the pool failed with AttributeError. Borrow a real
    # connection from it and always give it back.
    conn_pool = connect()
    conn = conn_pool.getconn()
    try:
        createdb(conn)
        # Verify transactions
        verify_transaction(conn, True)
    finally:
        conn_pool.putconn(conn)
        conn_pool.closeall()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment