Skip to content

Instantly share code, notes, and snippets.

@kmuthukk
Last active August 15, 2022 20:54
Show Gist options
  • Save kmuthukk/6280d3e0a88ee7ef4b352912d7720a49 to your computer and use it in GitHub Desktop.
Save kmuthukk/6280d3e0a88ee7ef4b352912d7720a49 to your computer and use it in GitHub Desktop.
# Centos 7 dependencies:
# sudo yum install -y postgresql-libs python-psycopg2
#
# For help:
# % python wide_rows.sql -h
#
# Before running, edit connect_str below so that it points at your YugabyteDB instance.
import psycopg2
import argparse
import time
import random
import sys
from multiprocessing.dummy import Pool as ThreadPool
from functools import partial
# Connection string used by every connection this script opens; adjust it to
# match the YugabyteDB instance under test.
connect_str = "host=localhost dbname=yugabyte user=yugabyte port=5433"

# Command-line options that control the shape of the load.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_write_threads",
    type=int,
    default=4,
    help="Number of writer threads",
)
parser.add_argument(
    "--num_rows_per_thread",
    type=int,
    default=500,
    help="Number of rows to be inserted per thread",
)
parser.add_argument(
    "--num_columns_per_row",
    type=int,
    default=100,
    help="Number of columns per row",
)
parser.add_argument(
    "--use_udt",
    action="store_true",
    default=False,
    help="Use UDT instead of individual columns",
)
args = parser.parse_args()

# Load-phase parameters, exposed as module globals because the table-creation
# and loader functions below read them directly.
num_write_threads = args.num_write_threads
num_rows = args.num_rows_per_thread
num_columns = args.num_columns_per_row
use_udt = args.use_udt

# Echo the effective configuration before doing any work.
banner = "Threads: {}, Rows per thread: {}, Num columnns per row: {}, Using UDT: {}"
print(banner.format(num_write_threads, num_rows, num_columns, use_udt))
def create_table():
    """Drop and recreate the ``my_table`` test table.

    Any existing ``my_table`` table and ``my_udt`` type are dropped first.
    The table is then created with a text primary key ``cid`` and either:

    * a single ``udt_col`` column of composite type ``my_udt`` whose fields
      are ``col0 .. col<num_columns-1>`` (when ``use_udt`` is true), or
    * ``num_columns`` individual integer columns named ``col0``, ``col1``, ...

    Reads the module-level ``connect_str``, ``num_columns`` and ``use_udt``.
    The connection is opened in autocommit mode and is always closed before
    returning, even if one of the statements fails.
    """
    # One "colN integer" definition per requested column.
    column_defs = ["col{} integer".format(idx) for idx in range(num_columns)]

    conn = psycopg2.connect(connect_str)
    try:
        # Autocommit so that each DDL statement takes effect immediately.
        conn.set_session(autocommit=True)
        cur = conn.cursor()

        print("dropping table and index")
        # The table must go first: in UDT mode it depends on the type.
        cur.execute("""DROP TABLE IF EXISTS my_table""")
        cur.execute("""DROP TYPE IF EXISTS my_udt""")

        if use_udt:
            cur.execute("CREATE TYPE my_udt AS (" + ", ".join(column_defs) + ")")
            cur.execute(
                "CREATE TABLE my_table(cid text, udt_col my_udt, "
                "PRIMARY KEY(cid)) SPLIT INTO 1 TABLETS"
            )
        else:
            # Joining the full item list (rather than gluing the columns
            # between fixed commas) keeps the statement well-formed even
            # when num_columns is 0.
            table_items = ["cid text"] + column_defs + ["PRIMARY KEY(cid)"]
            cur.execute(
                "CREATE TABLE my_table("
                + ", ".join(table_items)
                + ") SPLIT INTO 1 TABLETS"
            )
    finally:
        # Release the server-side session regardless of success or failure.
        conn.close()
def load_data_worker(thread_num):
    """Insert ``num_rows`` rows into ``my_table`` over a dedicated connection.

    Each row's key is ``"user-<thread_num>-<row index>"`` and is passed as a
    bound parameter. The integer payload for column ``j`` of row ``i`` is
    ``i * i + j`` and is inlined into the statement text. In UDT mode the
    payload is wrapped in a single ``ROW(...)`` expression stored in
    ``udt_col``; otherwise it is spread over ``col0 .. col<num_columns-1>``.

    Reads the module-level ``connect_str``, ``num_rows``, ``num_columns`` and
    ``use_udt``. The connection runs in autocommit mode and is always closed
    before returning, even if an insert fails.

    Args:
        thread_num: Index of this worker; only used to build unique row keys
            and for progress output.
    """
    thread_id = str(thread_num)

    # The target column list is the same for every row, so build the
    # statement prefix once instead of once per insert.
    if use_udt:
        target_columns = ["cid", "udt_col"]
    else:
        target_columns = ["cid"] + ["col{}".format(c) for c in range(num_columns)]
    insert_prefix = (
        "INSERT INTO my_table (" + ", ".join(target_columns) + ") VALUES ("
    )

    conn = psycopg2.connect(connect_str)
    try:
        # Autocommit: every INSERT is its own transaction.
        conn.set_session(autocommit=True)
        cur = conn.cursor()
        print("loading data on thread: {}".format(thread_num))

        for idx in range(num_rows):
            # NOTE(review): values grow as idx*idx, so a very large
            # num_rows may exceed what an "integer" column can hold --
            # confirm the intended upper bound.
            payload = [str((idx * idx) + jdx) for jdx in range(num_columns)]

            # "%s" is the placeholder for the row key; joining the item list
            # keeps the statement valid even when there are no payload values.
            if use_udt:
                value_items = ["%s", "ROW(" + ", ".join(payload) + ")"]
            else:
                value_items = ["%s"] + payload

            cur.execute(
                insert_prefix + ", ".join(value_items) + ")",
                ("user-" + thread_id + "-" + str(idx),),
            )

            # Progress report every 500 rows.
            if (idx + 1) % 500 == 0:
                print("Loaded {} rows.".format(idx + 1))
    finally:
        # Release the server-side session regardless of success or failure.
        conn.close()
def load_data():
    """Run ``load_data_worker`` on ``num_write_threads`` threads and wait.

    Worker ``i`` (for ``i`` in ``0 .. num_write_threads-1``) receives ``i``
    as its thread number. ``pool.map`` blocks until every worker has
    finished and re-raises an exception raised by a worker. The pool is
    closed and joined in all cases so its threads do not outlive the call.
    """
    pool = ThreadPool(num_write_threads)
    try:
        pool.map(load_data_worker, range(num_write_threads))
    finally:
        # close() stops new submissions; join() waits for the pool threads
        # to exit. Spelled out explicitly rather than via "with" so the
        # script does not depend on Pool being a context manager.
        pool.close()
        pool.join()
    print("Loaded total of {} rows.".format(num_rows * num_write_threads))
# Script entry point. Order matters: create_table() (re)creates my_table,
# and load_data() then starts the worker threads that insert into it.
create_table()
load_data()
@kmuthukk
Copy link
Author

Use packed.sh as a sample driver script.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment