Skip to content

Instantly share code, notes, and snippets.

@selwyth
Created January 18, 2018 00:24
Show Gist options
  • Save selwyth/7a6bba3951177dc9599900066031aef5 to your computer and use it in GitHub Desktop.
Save selwyth/7a6bba3951177dc9599900066031aef5 to your computer and use it in GitHub Desktop.
Rewrite a PostgreSQL query that uses WITH (common table expression) syntax into a more efficient script that materializes each CTE as a temporary table.
import argparse
import os
import re
def main(sql_file):
    """Print a temporary-table rewrite of the WITH query stored in *sql_file*.

    The output is the statements produced by create_temp_tables (one per CTE
    found), a newline, and then the statement that follows the CTE list.

    Args:
        sql_file: path of the file containing the query.
    """
    import string

    with open(sql_file, 'r') as f:
        sql = f.read()
    # CTE discovery runs on an ASCII-uppercased copy so the ``AS (`` pattern
    # matches however the keyword was cased.  Unlike str.upper(), translating
    # only ASCII letters never changes the length of the text, so every index
    # found in ``search_sql`` is valid in ``sql``.  All emitted SQL is sliced
    # from the original text, so string literals keep their case, and no
    # ``WITH`` substring (e.g. ``WITHIN GROUP``) is rewritten.
    to_upper = str.maketrans(string.ascii_lowercase, string.ascii_uppercase)
    search_sql = sql.translate(to_upper)
    opens, closes = parse_sql_into_chunks(search_sql)
    subquery_data = find_subquery_data(search_sql, opens, closes)
    temp_table_queries = create_temp_tables(sql, subquery_data)
    last_subquery_idx = find_select_query_start(search_sql, opens, closes)
    print('{}\n{}'.format(''.join(temp_table_queries),
                          sql[last_subquery_idx:]))
def find_outer_parentheses(sql):
    """Return the span of the first parenthesised group in *sql*.

    Finds the first '(' and the ')' that balances it, counting nested pairs
    in between.  Parentheses inside string literals or comments are counted
    like any others.

    Returns:
        ``(open_idx, close_idx)``, both indices into *sql*.  If the group is
        never closed, ``close_idx`` is the index of the last character.

    Raises:
        KeyError: if *sql* contains no '('.  parse_sql_into_chunks catches
            this exception type to detect that no groups remain.
    """
    # str.find instead of re.search('\(', ...): that pattern's non-raw '\('
    # is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    open_idx = sql.find('(')
    if open_idx == -1:
        raise KeyError('no opening parenthesis in SQL')
    depth = 0
    # The first character visited is the '(' itself, so the loop body runs
    # at least once and close_idx is always bound.
    for close_idx, char in enumerate(sql[open_idx:], start=open_idx):
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                break
    return open_idx, close_idx
def parse_sql_into_chunks(sql):
    """Return the positions of every top-level parenthesised group in *sql*.

    Returns:
        Two parallel lists ``(opens_idx, closes_idx)``: the index of each
        outermost '(' and of the ')' that closes it, in order of appearance.
        Parentheses inside string literals or comments are not special-cased.
    """
    opens_idx = []
    closes_idx = []
    cursor = 0
    while cursor < len(sql):
        try:
            o, c = find_outer_parentheses(sql[cursor:])
        except KeyError:
            # No '(' left in the remaining text.
            break
        opens_idx.append(o + cursor)
        closes_idx.append(c + cursor)
        # Resume one character past this group's ')'.  That position is
        # always beyond the current cursor, so the scan is guaranteed to
        # advance; restarting at the ')' itself loops forever when an
        # unmatched '(' is the last character of the text.
        cursor += c + 1
    return opens_idx, closes_idx
def find_subquery_data(sql, opens, closes):
    """Yield ``(name, open_idx, close_idx)`` for each CTE in *sql*.

    *opens* and *closes* are parallel lists of top-level parenthesis
    positions, as returned by parse_sql_into_chunks.  A group is a CTE when
    the text between the previous group's ')' (or the start of *sql*) and
    its own '(' contains ``name AS (``, matched case-insensitively.

    Scanning stops at the first group that is not a CTE: the CTE list is a
    contiguous prefix, and later groups belong to the final statement.
    """
    # Pair each group with the close of the group before it.  closes[:-1]
    # keeps one entry per group; truncating further would make zip drop the
    # last CTE whenever the final statement contains no parentheses.
    previous_closes = [0] + closes[:-1]
    for o, c1, c in zip(opens, previous_closes, closes):
        w = re.search(r'(\w+) AS \(', sql[c1:(o + 1)], flags=re.IGNORECASE)
        if not w:
            break
        yield w.group(1), o, c
def find_select_query_start(sql, opens, closes):
    """Return the index in *sql* where the statement after the CTEs begins.

    *opens* and *closes* are parallel lists of top-level parenthesis
    positions, as returned by parse_sql_into_chunks.  Groups are walked in
    order while each is preceded by ``name AS (`` (case-insensitive).

    Returns:
        The index just past the last CTE's closing ')'.  If the first group
        is not a CTE, or there are no groups at all, returns 0, so
        ``sql[result:]`` is then the whole text.
    """
    start = 0
    # Pair each group with the close of the group before it (0 for the first).
    for o, c1, c in zip(opens, [0] + closes[:-1], closes):
        w = re.search(r'(\w+) AS \(', sql[c1:(o + 1)], flags=re.IGNORECASE)
        if not w:
            break
        start = c + 1
    return start
def create_temp_tables(sql, subquery_data):
    """Yield one temporary-table script per CTE.

    Args:
        sql: the full query text that the indices refer to.
        subquery_data: iterable of ``(name, open_idx, close_idx)`` tuples,
            where the indices are the CTE's enclosing '(' and ')'.

    Yields:
        A ``DROP TABLE IF EXISTS ... CASCADE`` statement followed by a
        ``CREATE LOCAL TEMPORARY TABLE ... AS`` whose body is the text
        strictly between the two parentheses.

    NOTE(review): the DROP is not schema-qualified.  In PostgreSQL, if no
    temporary table of that name exists, the name resolves through
    search_path and can match a permanent table, and CASCADE would also drop
    its dependents.  Confirm CTE names cannot collide with real tables.
    """
    template = ('DROP TABLE IF EXISTS {subquery_name} CASCADE;\n'
                'CREATE LOCAL TEMPORARY TABLE {subquery_name} ON COMMIT PRESERVE ROWS AS\n'
                '{subquery}\n'
                ';\n')
    for name, open_idx, close_idx in subquery_data:
        body = sql[open_idx + 1:close_idx]
        yield template.format(subquery_name=name, subquery=body)
if __name__ == '__main__':
    # Command-line entry point: takes one positional argument, the path of
    # the SQL file to rewrite, and hands it to main(), which prints the result.
    cli = argparse.ArgumentParser()
    cli.add_argument('file', type=str, help='Path to the SQL file')
    main(cli.parse_args().file)
@selwyth
Copy link
Author

selwyth commented Jan 18, 2018

Usage: python with_query_rewriter.py /path/to/long_sql_file.sql

For example, the query

WITH hello AS (
    SELECT 1 FROM users
)

, world AS (
    SELECT 2 FROM addresses
)

SELECT 'boom'
FROM hello
CROSS JOIN world

becomes:

DROP TABLE IF EXISTS hello CASCADE;
CREATE LOCAL TEMPORARY TABLE hello ON COMMIT PRESERVE ROWS AS

    SELECT 1 FROM users

;

DROP TABLE IF EXISTS world CASCADE;
CREATE LOCAL TEMPORARY TABLE world ON COMMIT PRESERVE ROWS AS

    SELECT 2 FROM addresses

;

SELECT 'boom'
FROM hello
CROSS JOIN world

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment