Skip to content

Instantly share code, notes, and snippets.

@ian-whitestone
Last active May 3, 2021 12:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ian-whitestone/efaa62f21a265b4bf8c79f831ca302e3 to your computer and use it in GitHub Desktop.
Save ian-whitestone/efaa62f21a265b4bf8c79f831ca302e3 to your computer and use it in GitHub Desktop.
Code for the "Testing SQL" blog post: http://ianwhitestone.work/testing-sql
import os
from jinja2 import Environment, meta, Template
import pandas as pd
from pandas.testing import assert_frame_equal
import pandas.io.sql as psql
import psycopg2
# All these constants would likely live in separate files.
#
# Shared psycopg2 connection used by run_pipeline() and run_sql_tests().
# It is opened as soon as this module is imported, and every setting is read
# from a PG_* environment variable; a missing variable raises KeyError here.
CONNECTION = psycopg2.connect(
host=os.environ['PG_HOST'],
port=os.environ['PG_PORT'],
database=os.environ['PG_DBNAME'],
user=os.environ['PG_USER'],
password=os.environ['PG_PASSWORD']
)
# Lookup from each Jinja placeholder used in BASE_SQL to the name of the
# production table that placeholder stands for.
TABLE_XREF = dict(
    transactions='transactions',
    users='users',
)
# Jinja-templated query under test. {{ transactions }} and {{ users }} are
# placeholders that render_sql() replaces with table names (or, in test mode,
# with the names of the injected fake-data CTEs). The query has to begin with
# WITH because inject_cte() inserts the test CTEs right after that keyword.
BASE_SQL = """
WITH
results AS (
SELECT
DATE_PART('doy', t.processed_at) AS day_of_year,
u.country,
SUM(amount) AS trxn_volume
FROM
{{ transactions }} AS t
INNER JOIN {{ users }} AS u
ON t.user_id=u.id
GROUP BY 1,2
ORDER BY 1,2
)
SELECT
*
FROM
results
"""
# Fake rows for each templated table, keyed by Jinja table reference.
# Every row is kept as a ready-made SQL tuple string because that was the
# quickest way to get a working proof of concept. Dictionaries or CSV files
# may be friendlier to create and edit, but they would need extra code to
# ingest them and convert them into the matching SQL text and data types.
TEST_DATA = {
    "users": {
        "column_names": ["id", "country"],
        "values": [
            "(1, 'US')",
            "(2, 'CA')",
            "(3, 'CA')",
        ],
    },
    "transactions": {
        "column_names": ["id", "user_id", "amount", "processed_at"],
        "values": [
            "(1, 1, 15.0, TIMESTAMP'2020-01-01 12:05')",
            "(2, 1, 10.49, TIMESTAMP'2020-01-01 12:10')",
            "(3, 1, -10.49, TIMESTAMP'2020-01-01 12:15')",
            "(4, 2, 25.99, TIMESTAMP'2020-01-02 15:25')",
            "(5, 2, 5.45, TIMESTAMP'2020-01-05 14:01')",
            "(6, 2, 50.5, TIMESTAMP'2020-01-07 03:45')",
            "(7, 3, 49.5, TIMESTAMP'2020-01-07 22:45')",
        ],
    },
}
# Column-oriented result that BASE_SQL should produce when it is run against
# the rows in TEST_DATA; run_sql_tests() turns this into a DataFrame.
EXPECTED_RESULTS = dict(
    day_of_year=[1, 2, 5, 7],
    country=['US', 'CA', 'CA', 'CA'],
    trxn_volume=[15, 25.99, 5.45, 100],
)
def build_cte(table_ref, table_name):
    """
    Render the fake rows registered under ``table_ref`` as a SQL CTE.

    The rows and column names are read from ``TEST_DATA[table_ref]``. The
    returned text defines a CTE called ``table_name`` whose body is a
    VALUES list, and it ends with a comma so more CTEs can follow it.
    """
    fixture = TEST_DATA[table_ref]
    rows = ",\n".join(fixture['values'])
    columns = ",".join(fixture['column_names'])
    # The empty first and last entries give the leading and trailing newline.
    pieces = [
        "",
        f"{table_name} AS (",
        "SELECT * FROM (",
        "VALUES ",
        rows,
        f") AS t ({columns})",
        "),",
        "",
    ]
    return "\n".join(pieces)
def inject_cte(sql, cte):
    """
    Add the CTE directly after the leading WITH keyword of ``sql``.

    Args:
        sql: SQL text whose first non-whitespace characters are ``WITH``.
        cte: CTE definition text to insert. It is expected to end with a
            comma (as the output of ``build_cte`` does) so that the CTEs
            already present in ``sql`` can follow it.

    Returns:
        ``WITH`` + ``cte`` + everything that followed the first ``WITH`` in
        ``sql``. Any whitespace in front of that first ``WITH`` is dropped.

    Raises:
        AssertionError: if ``sql`` does not start with ``WITH``.
            Could add handling if SQL does not already have a WITH.
    """
    assert sql.strip().startswith('WITH'), "SQL must start with a WITH clause"
    # Split on the first WITH only. Later occurrences (a 'WITH' inside a
    # string literal, "TIMESTAMP WITH TIME ZONE", ...) are part of the query
    # body; splitting on them would silently truncate the statement.
    remainder = sql.split('WITH', 1)[1]
    return f"WITH{cte}{remainder}"
def render_sql(mode):
    """
    Render BASE_SQL into executable SQL.

    Every undeclared Jinja variable found in BASE_SQL is looked up in
    TABLE_XREF and replaced by the resulting table name. When ``mode`` is
    ``'test'`` the names get a ``test_`` prefix and one CTE of fake data per
    table (see ``build_cte``) is injected after the query's WITH keyword, so
    the query reads the fake rows. Any other ``mode`` renders the plain
    table names.
    """
    template = Template(BASE_SQL)
    placeholders = meta.find_undeclared_variables(Environment().parse(BASE_SQL))

    testing = mode == 'test'
    # consider generating a random string in case there could actually be
    # a table named test_users or test_transactions
    prefix = 'test_' if testing else ''

    # Jinja table reference -> actual table (or CTE) name
    resolved_names = {}
    for placeholder in placeholders:
        resolved_names[placeholder] = f"{prefix}{TABLE_XREF[placeholder]}"

    rendered = template.render(**resolved_names)
    if not testing:
        return rendered

    # build the fake-data CTEs and splice each one into the rendered SQL
    for placeholder, cte_name in resolved_names.items():
        rendered = inject_cte(rendered, build_cte(placeholder, cte_name))
    return rendered
def run_pipeline():
    """
    Render the production SQL, echo it, and load the result over CONNECTION.

    The resulting DataFrame is where the remainder of the pipeline would
    pick up; that part is not implemented here.
    """
    production_sql = render_sql(mode='production')
    print(f"Executing SQL:\n{production_sql}")
    df = psql.read_sql(production_sql, CONNECTION)
    # the rest of the pipeline, which relies on the data in df, goes here
    # ...
def run_sql_tests():
    """
    Run the test-mode SQL over CONNECTION and check it against EXPECTED_RESULTS.

    The rendered SQL and both dataframes are printed. ``assert_frame_equal``
    raises AssertionError if the actual result differs from the expected one
    (dtypes are not compared).
    """
    test_sql = render_sql(mode='test')
    print(f"Executing SQL:\n{test_sql}")

    actual_df = psql.read_sql(test_sql, CONNECTION)
    expected_df = pd.DataFrame(EXPECTED_RESULTS)
    print(f'Actual dataframe is:\n {actual_df}\nExpected dataframe is:\n{expected_df}')

    # dtype checking is switched off: it would need the expected results
    # to declare their column types
    assert_frame_equal(actual_df, expected_df, check_dtype=False)
    print('Matchy matchy ✨!')
    # For richer comparisons between two dataframes, see:
    # https://capitalone.github.io/datacompy/
@ian-whitestone
Copy link
Author

ian-whitestone commented May 2, 2021

If you have a Postgres database and the required environment variables (PG_HOST, etc.), you can run the tests on your database with python -c "from main import run_sql_tests as run; run()".

I was running this script on a Mac with macOS Big Sur:

> sw_vers
ProductName:	macOS
ProductVersion:	11.2.3
BuildVersion:	20D91

and:

  • Python 3.7.6
  • pandas==1.1.5
  • psycopg2==2.8.6
  • psycopg2-binary==2.8.6
  • Jinja2==2.11.2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment