Skip to content

Instantly share code, notes, and snippets.

@kmuthukk
Last active December 1, 2023 22:55
Show Gist options
  • Save kmuthukk/e5a9180fd7bf7ae43ea27d3535413db0 to your computer and use it in GitHub Desktop.
# Dependencies:
# On CentOS you can install psycopg2 thus:
#
# sudo yum install postgresql-libs
# sudo yum install python-psycopg2
import psycopg2;
from psycopg2 import OperationalError, errorcodes, errors
import datetime;
import random;
import time;
from multiprocessing.dummy import Pool as ThreadPool
# ---- Benchmark configuration ----
num_users = 1000                 # rows loaded into the users table
num_accts_per_user = 10          # accounts created per user
print_stats_every_n_rows = 1000  # progress/latency reporting interval
num_update_threads = 3           # concurrent UPDATE workload threads
num_updates_per_thread = 5000    # business transactions per update thread
num_select_threads = 0           # SELECT workload (currently disabled)

# pass multiple hosts this way, and the parallel worker threads will round-robin their connections between these.
# hosts = ['ip1', 'ip2', 'ip3']
hosts = ['127.0.0.1']

db_name = "my_db"

# Connection-string templates; the remaining {} placeholder is filled with a
# host address at connect time.
admin_connect_string = "host={} dbname=yugabyte user=yugabyte port=5433"
connect_string = "host={{}} dbname={} user=yugabyte port=5433".format(db_name)
def create_tables():
    """(Re)create the benchmark database and its schema.

    Drops and recreates ``db_name`` through an admin connection to the
    default 'yugabyte' database, then connects to the fresh database and
    creates the users, accounts and audit_table tables plus a unique
    index on accounts(user_id, account_id).
    """
    # Admin connection is only needed for DROP/CREATE DATABASE; close it
    # once done (bug fix: the original leaked this connection).
    admin_conn = psycopg2.connect(admin_connect_string.format(hosts[0]))
    # CREATE/DROP DATABASE cannot run inside a transaction block.
    admin_conn.set_session(autocommit=True)
    admin_cur = admin_conn.cursor()
    # db_name is a module-level constant, not user input, so building the
    # DDL with .format() is safe here.
    admin_cur.execute("DROP DATABASE IF EXISTS {}".format(db_name))
    print("dropped database if exists {}".format(db_name))
    # COLOCATED=true is YugabyteDB-specific syntax.
    admin_cur.execute("CREATE DATABASE {} COLOCATED=true".format(db_name))
    print("created database {} COLOCATED=true".format(db_name))
    admin_cur.close()
    admin_conn.close()

    # now connect to the created database
    conn = psycopg2.connect(connect_string.format(hosts[0]))
    conn.set_session(autocommit=True)
    cur = conn.cursor()
    cur.execute("""DROP TABLE IF EXISTS users""")
    cur.execute("""DROP TABLE IF EXISTS accounts""")
    cur.execute("""DROP TABLE IF EXISTS audit_table""")
    cur.execute("""CREATE EXTENSION IF NOT EXISTS "uuid-ossp" """)
    cur.execute("""CREATE TABLE IF NOT EXISTS users(
                       user_id character varying(24),
                       user_name text,
                       addr text,
                       bal numeric(16, 2),
                       PRIMARY KEY(user_id ASC))
                """)
    print("Created users table")
    cur.execute("""CREATE TABLE IF NOT EXISTS accounts(
                       account_id character varying(24),
                       user_id character varying(24),
                       bal numeric(16, 2),
                       PRIMARY KEY(account_id ASC))
                """)
    print("Created accounts table")
    # Unique index lets us enumerate/verify a user's accounts efficiently.
    cur.execute("""CREATE UNIQUE INDEX user_accounts ON accounts(user_id ASC, account_id ASC)""")
    print("Created user_accounts index")
    cur.execute("""CREATE TABLE IF NOT EXISTS audit_table(
                       account_id character varying(24),
                       event_id timestamp,
                       user_id character varying(24),
                       delta numeric(16, 2),
                       PRIMARY KEY(account_id ASC, event_id ASC))
                """)
    print("Created audit_table table")
    print("====================")
def load_sample_data():
    """Seed users and accounts tables, all with zero balances.

    Uses generate_series so each INSERT is a single server-side statement
    rather than num_users/num_accts round trips. Account ids embed the
    owning user: 'user-<u>-account-<a>' for a in [0, num_accts_per_user).
    """
    conn = psycopg2.connect(connect_string.format(hosts[0]))
    conn.set_session(autocommit=True)
    cur = conn.cursor()
    # create a bunch of users with 0 balance
    cur.execute("""
        INSERT INTO users(user_id, user_name, addr, bal)
        SELECT 'user-' || idx, 'username-' || idx, 'useraddr-' || idx, 0
        FROM generate_series(0, %s - 1) idx
        """, (num_users, ))
    # create a bunch of accounts with 0 balance
    total_accts = num_users * num_accts_per_user
    cur.execute("""
        INSERT INTO accounts(account_id, user_id, bal)
        SELECT 'user-' || (idx / %s) || '-account-' || mod(idx, %s), 'user-' || (idx / %s), 0
        FROM generate_series(0, %s - 1) idx
        """, (num_accts_per_user, num_accts_per_user, num_accts_per_user, total_accts, ))
    cur.close()
    conn.close()  # bug fix: the original leaked this connection
def update_data_worker(thread_num):
    """Thread entry point: run the update workload and report the outcome.

    Exceptions are printed rather than propagated so one failing worker
    does not vanish silently inside the thread pool.
    """
    try:
        update_data_worker_inner(thread_num)
    except Exception as exc:
        print("Thread {} had an exception: {}".format(thread_num, exc))
    else:
        print("Thread {} exited cleanly".format(thread_num))
def update_data_worker_inner(thread_num):
    """Run num_updates_per_thread business transactions on one connection.

    Each transaction picks a random user and one of that user's accounts,
    inserts an audit row, and increments both the account balance and the
    user's aggregate balance. Uses optimistic concurrency on the account
    balance and retries on serialization failures (pgcode 40001).

    Fixes vs. original: the elapsed-time local no longer shadows the txn
    `delta`; pgcode is compared against errorcodes.SERIALIZATION_FAILURE
    instead of the magic string "40001" (same value); the connection is
    closed on exit.
    """
    thread_id = str(thread_num)
    # Round-robin worker connections across the configured hosts.
    conn = psycopg2.connect(connect_string.format(hosts[thread_num % len(hosts)]))
    conn.set_session(autocommit=True)
    cur = conn.cursor()
    print("Thread-" + thread_id + ": Going to update %d rows..." % (num_updates_per_thread, ))
    start_time = time.time()
    retry_cnt = 0
    for idx in range(num_updates_per_thread):
        # pick a random user and a random account of that user
        user_idx = random.randint(0, num_users - 1)
        acct_idx = random.randint(0, num_accts_per_user - 1)
        user_id = 'user-' + str(user_idx)
        acct_id = user_id + '-account-' + str(acct_idx)
        # start business transaction of incrementing the account balance,
        # the balance for the user in the user table and add a row in the
        # audit table.
        retry = True
        while retry:
            try:
                cur.execute("""BEGIN TRANSACTION isolation level read committed""")
                # cur.execute("""BEGIN TRANSACTION isolation level repeatable read""")
                # cur.execute("""BEGIN TRANSACTION isolation level serializable""")
                # business logic here; for now assume we always increment by 1
                delta = 1
                # do insert into audit_table early on as best practice
                cur.execute("""INSERT INTO audit_table(account_id, event_id, user_id, delta) VALUES(%s, now(), %s, %s)""",
                            (acct_id, user_id, delta))
                cur.execute("""SELECT bal FROM accounts WHERE account_id = %s""",
                            (acct_id, ))
                cur_balance = cur.fetchone()[0]
                # for now business logic is simply to increment the balance
                new_balance = cur_balance + delta
                # Optimistic lock: the UPDATE only matches if the balance is
                # still what we just read.
                cur.execute("""UPDATE accounts SET bal = %s WHERE account_id = %s and bal = %s""",
                            (new_balance, acct_id, cur_balance))
                if cur.rowcount == 0:
                    # A concurrent txn changed the balance; roll back and retry.
                    cur.execute("""ROLLBACK""")
                    retry = True
                    retry_cnt = retry_cnt + 1
                    print("Unexpected balance due to concurrent txn. Retrying: retry_cnt: {}; idx: {}".format(retry_cnt, idx))
                else:
                    cur.execute("""UPDATE users SET bal = bal + %s WHERE user_id = %s""",
                                (delta, user_id))
                    cur.execute("""COMMIT""")
                    retry = False
            except OperationalError as err:
                cur.execute("""ROLLBACK""")
                print("Exception PGCODE={}, PGERROR={}".format(err.pgcode, err.pgerror))
                # Only serialization failures ("40001") are safe to retry;
                # any other operational error gives up on this transaction.
                if err.pgcode != errorcodes.SERIALIZATION_FAILURE:
                    retry = False
                else:
                    retry_cnt = retry_cnt + 1
                    print("Going to retry; retry_cnt={}".format(retry_cnt))
        if (idx + 1) % print_stats_every_n_rows == 0:
            end_time = time.time()
            elapsed = end_time - start_time
            print("Thread-" + thread_id + ": Updated %d rows" % (idx + 1))
            print("Thread-" + thread_id + ": Avg business txn latency: %s ms" % (elapsed * 1000 / print_stats_every_n_rows))
            start_time = time.time()
    print("Thread-{}: Retries Needed: {}; Updated {} rows".format(thread_id, retry_cnt, num_updates_per_thread))
    cur.close()
    conn.close()  # bug fix: the original leaked this connection
def check_work():
    """Verify workload invariants and print the results.

    Every business transaction adds exactly 1 to one account balance, 1 to
    the owning user's balance, and one audit row — so all three totals
    should equal num_update_threads * num_updates_per_thread.
    """
    conn = psycopg2.connect(connect_string.format(hosts[0]))
    conn.set_session(autocommit=True)
    cur = conn.cursor()
    # total number of updates expected in the audit_table
    expected_updates = num_update_threads * num_updates_per_thread
    # Since each update increments balance by 1...
    expected_balance = expected_updates
    cur.execute("""SELECT sum(bal) FROM accounts""")
    result = cur.fetchone()[0]
    print("Sum of balances in accounts table is {}, should be {}".format(result, expected_balance))
    cur.execute("""SELECT sum(bal) FROM users""")
    result = cur.fetchone()[0]
    print("Sum of balances in users table is {}, should be {}".format(result, expected_balance))
    cur.execute("""SELECT count(*) FROM audit_table""")
    result = cur.fetchone()[0]
    print("Number of rows in audit_table is {}, should be {}".format(result, expected_updates))
    cur.close()
    conn.close()  # bug fix: the original leaked this connection
def perform_work():
    """Launch the workload thread pool and block until all threads finish."""
    pool = ThreadPool(num_update_threads + num_select_threads)
    if num_update_threads > 0:
        # map_async so the launch message prints immediately; completion is
        # awaited via pool.join() below (the unused AsyncResult binding from
        # the original is dropped).
        pool.map_async(update_data_worker, range(num_update_threads))
        print("Launched {} UPDATE workload threads...".format(num_update_threads))
    # NOTE: the SELECT workload is currently disabled (num_select_threads == 0).
    # if (num_select_threads > 0):
    #     results2 = pool.map_async(select_data_worker, range(num_select_threads))
    #     print("Launched {} SELECT workload threads...".format(num_select_threads))
    print("Launched {} workload threads...".format(num_update_threads + num_select_threads))
    pool.close()
    pool.join()
    print("All threads finished execution")
def main():
    """Run the full benchmark: schema setup, data load, workload, verification."""
    create_tables()
    load_sample_data()
    perform_work()
    check_work()


# Guard the entry point so importing this module does not run the benchmark;
# behavior when executed as a script is unchanged.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment