Last active
December 1, 2023 22:55
-
-
Save kmuthukk/e5a9180fd7bf7ae43ea27d3535413db0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dependencies: | |
# On CentOS you can install psycopg2 thus: | |
# | |
# sudo yum install postgresql-libs | |
# sudo yum install python-psycopg2 | |
import psycopg2; | |
from psycopg2 import OperationalError, errorcodes, errors | |
import datetime; | |
import random; | |
import time; | |
from multiprocessing.dummy import Pool as ThreadPool | |
# ---- Workload sizing --------------------------------------------------------
num_users = 1000                  # rows seeded into `users`
num_accts_per_user = 10           # accounts seeded for every user
print_stats_every_n_rows = 1000   # per-thread progress/latency report interval
num_update_threads = 3            # concurrent UPDATE workers
num_updates_per_thread = 5000     # business transactions run by each worker
num_select_threads = 0            # only sizes the pool; the SELECT launch in perform_work() is commented out

# ---- Cluster endpoints ------------------------------------------------------
# With several hosts (e.g. hosts = ['ip1', 'ip2', 'ip3']) the worker threads
# round-robin their connections across them.
hosts = ['127.0.0.1']

# Database that create_tables() drops/recreates and the workload runs against.
db_name = "my_db"

# Connection strings; "{}" is filled in with one entry of `hosts`.
# The admin string targets the `yugabyte` database so db_name can be
# dropped and created from there.
admin_connect_string = "host={} dbname=yugabyte user=yugabyte port=5433"
connect_string = "".join(["host={} dbname=", db_name, " user=yugabyte port=5433"])
def create_tables():
    """(Re)create the benchmark database and its schema.

    Via an admin connection to the ``yugabyte`` database, drops ``db_name``
    if it exists and recreates it with COLOCATED=true. Then connects to the
    new database and creates the ``users``, ``accounts`` and ``audit_table``
    tables plus the ``user_accounts`` unique index on
    accounts(user_id, account_id). Both connections are closed before
    returning, including on error.
    """
    admin_conn = psycopg2.connect(admin_connect_string.format(hosts[0]))
    try:
        # DROP/CREATE DATABASE cannot run inside a transaction block, so the
        # session must be in autocommit mode. (`with admin_conn:` is avoided
        # for the same reason: it would open a transaction.)
        admin_conn.set_session(autocommit=True)
        with admin_conn.cursor() as cur:
            # db_name is a module constant, not user input, so plain
            # formatting of the identifier is acceptable here.
            cur.execute("DROP DATABASE IF EXISTS {}".format(db_name))
            print("dropped database if exists {}".format(db_name))
            cur.execute("CREATE DATABASE {} COLOCATED=true".format(db_name))
            print("created database {} COLOCATED=true".format(db_name))
    finally:
        admin_conn.close()

    # now connect to the created database
    conn = psycopg2.connect(connect_string.format(hosts[0]))
    try:
        conn.set_session(autocommit=True)
        with conn.cursor() as cur:
            # The database was just recreated, so these drops are normally no-ops.
            cur.execute("""DROP TABLE IF EXISTS users""")
            cur.execute("""DROP TABLE IF EXISTS accounts""")
            cur.execute("""DROP TABLE IF EXISTS audit_table""")
            # NOTE(review): uuid-ossp is not used by the visible workload;
            # presumably kept for variants that generate UUIDs - confirm.
            cur.execute("""CREATE EXTENSION IF NOT EXISTS "uuid-ossp" """)
            cur.execute("""CREATE TABLE IF NOT EXISTS users(
                user_id character varying(24),
                user_name text,
                addr text,
                bal numeric(16, 2),
                PRIMARY KEY(user_id ASC))
            """)
            print("Created users table")
            cur.execute("""CREATE TABLE IF NOT EXISTS accounts(
                account_id character varying(24),
                user_id character varying(24),
                bal numeric(16, 2),
                PRIMARY KEY(account_id ASC))
            """)
            print("Created accounts table")
            cur.execute("""CREATE UNIQUE INDEX user_accounts ON accounts(user_id ASC, account_id ASC)""")
            print("Created user_accounts index")
            # One row per balance change, keyed by account and event time.
            cur.execute("""CREATE TABLE IF NOT EXISTS audit_table(
                account_id character varying(24),
                event_id timestamp,
                user_id character varying(24),
                delta numeric(16, 2),
                PRIMARY KEY(account_id ASC, event_id ASC))
            """)
            print("Created audit_table table")
    finally:
        conn.close()
    print("====================")
def load_sample_data():
    """Seed ``users`` and ``accounts`` with zero balances.

    Inserts ``num_users`` users with ids ``'user-0'`` .. ``'user-<num_users-1>'``.
    For each user, inserts ``num_accts_per_user`` accounts with ids
    ``'user-<u>-account-<a>'``. All ids are generated server-side with
    generate_series, in the same format update_data_worker_inner() uses to
    address rows. The connection is closed before returning, including on
    error.
    """
    conn = psycopg2.connect(connect_string.format(hosts[0]))
    try:
        conn.set_session(autocommit=True)
        with conn.cursor() as cur:
            # create a bunch of users with 0 balance
            cur.execute("""
                INSERT INTO users(user_id, user_name, addr, bal)
                SELECT 'user-' || idx, 'username-' || idx, 'useraddr-' || idx, 0
                FROM generate_series(0, %s - 1) idx
            """, (num_users, ))

            # create a bunch of accounts with 0 balance; idx is an integer,
            # so idx / n is the owning user's index and mod(idx, n) the
            # account's index within that user.
            total_accts = num_users * num_accts_per_user
            cur.execute("""
                INSERT INTO accounts(account_id, user_id, bal)
                SELECT 'user-' || (idx / %s) || '-account-' || mod(idx, %s), 'user-' || (idx / %s), 0
                FROM generate_series(0, %s - 1) idx
            """, (num_accts_per_user, num_accts_per_user, num_accts_per_user, total_accts, ))
    finally:
        conn.close()
def update_data_worker(thread_num):
    """Thread-pool entry point: run one UPDATE worker and report how it ended.

    Any Exception raised by update_data_worker_inner() is caught here and
    reported rather than re-raised. The report is a one-line message on
    stdout plus the full traceback on stderr.

    Args:
        thread_num: worker index, forwarded to update_data_worker_inner().
    """
    import traceback

    try:
        update_data_worker_inner(thread_num)
        print("Thread {} exited cleanly".format(thread_num))
    except Exception as e:
        print("Thread {} had an exception: {}".format(thread_num, e))
        # The message alone hides where the failure happened; keep the
        # traceback so worker failures can actually be diagnosed.
        traceback.print_exc()
def _attempt_business_txn(cur, acct_id, user_id, delta):
    """Run one attempt of the balance-increment business transaction.

    The transaction inserts an audit_table row, moves the account balance
    from the value just read to value + delta (optimistic check on the old
    value), adds delta to the user's balance, and commits.

    Returns True if the transaction committed. Returns False if the
    optimistic check matched no row (a concurrent transaction changed the
    balance); the transaction has been rolled back in that case.
    OperationalError from any statement propagates, and the caller is then
    responsible for issuing ROLLBACK.
    """
    cur.execute("""BEGIN TRANSACTION isolation level read committed""")
    # cur.execute("""BEGIN TRANSACTION isolation level repeatable read""")
    # cur.execute("""BEGIN TRANSACTION isolation level serializable""")
    # do insert into audit_table early on as best practice
    cur.execute("""INSERT INTO audit_table(account_id, event_id, user_id, delta) VALUES(%s, now(), %s, %s)""",
                (acct_id, user_id, delta))
    cur.execute("""SELECT bal FROM accounts WHERE account_id = %s""",
                (acct_id, ))
    cur_balance = cur.fetchone()[0]
    new_balance = cur_balance + delta
    # Only succeeds if the balance is still the value we read above.
    cur.execute("""UPDATE accounts SET bal = %s WHERE account_id = %s and bal = %s""",
                (new_balance, acct_id, cur_balance))
    if cur.rowcount == 0:
        cur.execute("""ROLLBACK""")
        return False
    cur.execute("""UPDATE users SET bal = bal + %s WHERE user_id = %s""",
                (delta, user_id))
    cur.execute("""COMMIT""")
    return True


def update_data_worker_inner(thread_num):
    """Run ``num_updates_per_thread`` balance-increment business transactions.

    Each business transaction picks a random user and one of that user's
    accounts, then runs _attempt_business_txn() with delta 1. It is retried
    when the optimistic balance check fails or the server reports a
    serialization failure (SQLSTATE 40001). Any other OperationalError
    abandons that business transaction and moves on to the next one.
    Exceptions other than OperationalError propagate to the caller. The
    connection is always closed.

    Args:
        thread_num: worker index; selects the host (round-robin over
            ``hosts``) and labels the progress output.
    """
    thread_id = str(thread_num)
    conn = psycopg2.connect(connect_string.format(hosts[thread_num % len(hosts)]))
    try:
        # With autocommit on, psycopg2 issues no implicit BEGIN, so the
        # explicit BEGIN/COMMIT/ROLLBACK statements define each transaction.
        conn.set_session(autocommit=True)
        with conn.cursor() as cur:
            print("Thread-" + thread_id + ": Going to update %d rows..." % (num_updates_per_thread, ))
            start_time = time.time()
            retry_cnt = 0
            committed = 0
            abandoned = 0
            for idx in range(num_updates_per_thread):
                # pick a random user and a random account of that user
                user_idx = random.randint(0, num_users - 1)
                acct_idx = random.randint(0, num_accts_per_user - 1)
                user_id = 'user-' + str(user_idx)
                acct_id = user_id + '-account-' + str(acct_idx)
                while True:
                    try:
                        # business logic here; for now assume we always increment by 1
                        if _attempt_business_txn(cur, acct_id, user_id, 1):
                            committed += 1
                            break
                        retry_cnt += 1
                        print("Unexpected balance due to concurrent txn. Retrying: retry_cnt: {}; idx: {}".format(retry_cnt, idx))
                    except OperationalError as err:
                        cur.execute("""ROLLBACK""")
                        print("Exception PGCODE={}, PGERROR={}".format(err.pgcode, err.pgerror))
                        if err.pgcode != errorcodes.SERIALIZATION_FAILURE:
                            # Not retryable: this business txn is dropped.
                            abandoned += 1
                            break
                        retry_cnt += 1
                        print("Going to retry; retry_cnt={}".format(retry_cnt))
                if (idx + 1) % print_stats_every_n_rows == 0:
                    # Separate name so it does not clobber the business delta.
                    elapsed = time.time() - start_time
                    print("Thread-" + thread_id + ": Updated %d rows" % (idx + 1))
                    print("Thread-" + thread_id + ": Avg business txn latency: %s ms" % (elapsed * 1000 / print_stats_every_n_rows))
                    start_time = time.time()
            if abandoned:
                print("Thread-{}: Abandoned {} transactions after non-retryable errors".format(thread_id, abandoned))
            # Report what actually committed, not what was attempted.
            print("Thread-{}: Retries Needed: {}; Updated {} rows".format(thread_id, retry_cnt, committed))
    finally:
        conn.close()
def check_work():
    """Print actual vs. expected totals after the workload has run.

    Expected values assume every business transaction committed exactly once
    with delta 1. Under that assumption, the account-balance sum, the
    user-balance sum, and the audit_table row count should each equal
    ``num_update_threads * num_updates_per_thread``. Transactions a worker
    abandoned after a non-retryable error show up as a shortfall. The
    connection is closed before returning, including on error.
    """
    # total number of updates expected in the audit_table
    expected_updates = num_update_threads * num_updates_per_thread
    # Since each update increments balance by 1...
    expected_balance = expected_updates

    # (query, report template, expected value)
    checks = (
        ("""SELECT sum(bal) FROM accounts""",
         "Sum of balances in accounts table is {}, should be {}",
         expected_balance),
        ("""SELECT sum(bal) FROM users""",
         "Sum of balances in users table is {}, should be {}",
         expected_balance),
        ("""SELECT count(*) FROM audit_table""",
         "Number of rows in audit_table is {}, should be {}",
         expected_updates),
    )

    conn = psycopg2.connect(connect_string.format(hosts[0]))
    try:
        conn.set_session(autocommit=True)
        with conn.cursor() as cur:
            for query, template, expected in checks:
                cur.execute(query)
                print(template.format(cur.fetchone()[0], expected))
    finally:
        conn.close()
def perform_work():
    """Run the concurrent workload and block until every worker finishes.

    Launches ``num_update_threads`` update_data_worker() calls on a thread
    pool (multiprocessing.dummy), waits for all of them, then re-raises any
    exception a worker let escape.
    """
    # ThreadPool(0) raises ValueError; keep at least one thread so an
    # all-zero configuration simply does nothing.
    pool = ThreadPool(max(1, num_update_threads + num_select_threads))
    pending = []
    if num_update_threads > 0:
        pending.append(pool.map_async(update_data_worker, range(num_update_threads)))
        print("Launched {} UPDATE workload threads...".format(num_update_threads))
    # SELECT workload is disabled; select_data_worker is not defined in this gist.
    # if num_select_threads > 0:
    #     pending.append(pool.map_async(select_data_worker, range(num_select_threads)))
    #     print("Launched {} SELECT workload threads...".format(num_select_threads))
    print("Launched {} workload threads...".format(num_update_threads + num_select_threads))
    pool.close()
    pool.join()
    # map_async silently keeps worker exceptions inside the AsyncResult;
    # get() surfaces them instead of losing them.
    for result in pending:
        result.get()
    print("All threads finished execution")
# Main: rebuild the schema, seed the data, run the concurrent workload, then
# report actual vs. expected totals. These steps run at module load, in order.
for _step in (create_tables, load_sample_data, perform_work, check_work):
    _step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment