Skip to content

Instantly share code, notes, and snippets.

@kmuthukk
Created August 23, 2020 03:26
Show Gist options
  • Save kmuthukk/07b88a99c4657e30d0ee17beda0157c4 to your computer and use it in GitHub Desktop.
Save kmuthukk/07b88a99c4657e30d0ee17beda0157c4 to your computer and use it in GitHub Desktop.
Loading CSV file to YB YSQL in smaller batches (rather than one large ACID transaction)
# sudo pip3 install pandas
# sudo pip3 install sqlalchemy
# sudo pip3 install psycopg2
# Imports
import pandas as pd
from sqlalchemy import create_engine
import psycopg2;
import time
# Connection parameters for the target YSQL database.
dbname = "yugabyte"
user = "yugabyte"
password = "yugabyte"
host = "localhost"
port = 5433

# Path of the CSV file to load. The sample data file can be downloaded with:
# curl -s -O https://raw.githubusercontent.com/yugabyte/yugabyte-db/master/src/postgres/src/test/regress/data/airport-codes.csv
csv_file = "airport-codes.csv"
def execute_ddl(host, dbname, user, password, port, ddl):
    """Run a single DDL statement on a fresh autocommit connection.

    A new connection is opened for the statement and always closed again,
    even if the statement fails, so repeated calls do not leak connections.

    Args:
        host: Database server host name.
        dbname: Name of the database to connect to.
        user: User name used for authentication.
        password: Password used for authentication.
        port: Port the database server listens on.
        ddl: The DDL statement to execute.

    Raises:
        psycopg2.Error: If connecting or executing the statement fails.
    """
    # Keyword arguments (rather than a hand-formatted DSN string) let the
    # driver handle quoting, so values containing spaces or quotes are safe.
    conn = psycopg2.connect(
        host=host, dbname=dbname, user=user, password=password, port=port
    )
    try:
        # Autocommit: the statement is not wrapped in an explicit
        # transaction that would need a separate commit.
        conn.set_session(autocommit=True)
        with conn.cursor() as cur:
            cur.execute(ddl)
    finally:
        conn.close()
    print("EXECUTED: {}".format(ddl))
    print("====================")
def load_data(host, dbname, user, password, port, csv_file, chunksize):
    """Append the rows of a CSV file to the ``airports`` table in batches.

    The file is read ``chunksize`` rows at a time and each batch is written
    with its own ``to_sql`` call, so the load proceeds in smaller batches
    rather than as one large write.

    Args:
        host: Database server host name.
        dbname: Name of the database to connect to.
        user: User name used for authentication.
        password: Password used for authentication.
        port: Port the database server listens on.
        csv_file: Path of the CSV file to load (first row is the header).
        chunksize: Number of CSV rows to read and insert per batch.
    """
    from urllib.parse import quote

    # Percent-encode the credentials so characters such as '@', ':' or '/'
    # cannot corrupt the connection URL.
    engine = create_engine(
        "postgresql://{}:{}@{}:{}/{}".format(
            quote(str(user), safe=""),
            quote(str(password), safe=""),
            host,
            port,
            dbname,
        )
    )
    try:
        # read_csv with chunksize yields DataFrames of at most `chunksize`
        # rows each.
        # Note: if the CSV doesn't have a header, pass column names as an
        # argument. For example:
        #   columns = ["sepal_length", "sepal_width", "petal_length",
        #              "petal_width", "class"]
        #   for data_frame in pd.read_csv(csv_file, names=columns,
        #                                 chunksize=1000):
        #       ...
        for data_frame in pd.read_csv(csv_file, chunksize=chunksize):
            data_frame.to_sql(
                "airports",  # table name
                engine,
                index=False,
                if_exists="append",  # the table already exists; append to it
            )
    finally:
        # Release the engine's pooled connections once loading is done.
        engine.dispose()
# --- Schema definition -------------------------------------------------

# Start from a clean slate on every run.
drop_table_ddl = """DROP TABLE IF EXISTS airports"""

# Columns of the airports table, keyed by `ident`.
create_table_ddl = """CREATE TABLE airports(
ident TEXT,
type TEXT,
name TEXT,
elevation_ft INT,
continent TEXT,
iso_country CHAR(2),
iso_region CHAR(7),
municipality TEXT,
gps_code TEXT,
iata_code TEXT,
local_code TEXT,
coordinates TEXT,
PRIMARY KEY (ident))"""

# Secondary index on (type, iso_region), then ident.
idx_ddl = "CREATE INDEX airport_type_region_idx ON airports((type, iso_region) HASH, ident ASC)"

# --- Run: recreate the table and index, then time the batched load ------

for ddl in (drop_table_ddl, create_table_ddl, idx_ddl):
    execute_ddl(host, dbname, user, password, port, ddl)

start = time.time()
load_data(host, dbname, user, password, port, csv_file, chunksize=256)
now = time.time()
elapsed = now - start
print("Time: %s secs" % elapsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment