Created
August 18, 2020 01:40
-
-
Save muness/9506def0f96c28194e9cae73a1a2e3a7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Install the sqlparse package as a Redshift Python UDF library from S3.
-- SECURITY: the original statement embedded a live AWS access key id and
-- secret key inline. Those credentials are public in this gist and MUST be
-- rotated immediately. Prefer an IAM role over literal keys; never commit
-- secrets to SQL scripts.
CREATE LIBRARY sqlparse
LANGUAGE plpythonu
FROM 's3://zapier-data-packages/sqlparse.zip'
CREDENTIALS 'aws_iam_role=arn:aws:iam::<account-id>:role/<redshift-s3-read-role>';
-- Drop any previous version of the test UDF.
-- IF EXISTS prevents this script from failing on a cluster where the
-- function was never created.
DROP FUNCTION IF EXISTS sqlparse_test(ANYELEMENT);
-- Smoke-test UDF: parse the given SQL string with sqlparse and return the
-- reconstructed text of the first parsed statement. Verifies the sqlparse
-- library is importable inside the plpythonu runtime.
CREATE OR REPLACE FUNCTION sqlparse_test(sql_string ANYELEMENT) RETURNS varchar
IMMUTABLE
language plpythonu
AS
$$
import sqlparse

# Return the normalized source text of the first statement.
first_statement = sqlparse.parse(sql_string)[0]
return first_statement.value
$$;
-- Sanity check: should echo the input query back unchanged.
select sqlparse_test('select * from x'::varchar(255));
-- Extract the table names referenced by a SQL query and return them as a
-- comma-separated varchar. Returns '<parse error>' if parsing fails.
-- The Python body is adapted from the open-source sql_metadata project.
-- NOTE: the pasted original had lost all Python indentation (it would not
-- compile); this version restores it and narrows the bare `except:`.
CREATE OR REPLACE FUNCTION public.sql_metadata_get_query_tables(sql_string ANYELEMENT) RETURNS varchar(max)
IMMUTABLE
language plpythonu
AS
$$
import re

import sqlparse
from sqlparse.sql import TokenList
from sqlparse.tokens import Name, Whitespace


def preprocess_query(query):
    """
    Perform initial query cleanup.

    :type query str
    :rtype str
    """
    # Remove Looker prefix comment (e.g. "-- looker ... }'").
    query = re.sub(r"--\s*looker.*}'", ' ', query)
    # Flatten the query onto a single line.
    query = query.replace('\n', ' ')
    # `database`.`table` notation -> database.table
    query = re.sub(r'`([^`]+)`\.`([^`]+)`', r'\1.\2', query)
    return query


def get_query_tokens(query):
    """
    Tokenize a query, dropping whitespace tokens.

    :type query str
    :rtype: list[sqlparse.sql.Token]
    """
    query = preprocess_query(query)
    parsed = sqlparse.parse(query)

    # Handle empty queries gracefully.
    if not parsed:
        return []

    tokens = TokenList(parsed[0].tokens).flatten()
    return [token for token in tokens if token.ttype is not Whitespace]


def _update_table_names(tables, tokens, index, last_keyword):
    """
    Append or patch table names, handling database.table and
    database.schema.table notation.

    :type tables list[str]
    :type tokens list[sqlparse.sql.Token]
    :type index int
    :type last_keyword str
    :rtype: list[str]
    """
    token = tokens[index]
    last_token = tokens[index - 1].value.upper() if index > 0 else None
    next_token = tokens[index + 1].value.upper() if index + 1 < len(tokens) else None

    if last_keyword in ['FROM', 'JOIN', 'INNER JOIN', 'FULL JOIN', 'FULL OUTER JOIN',
                        'LEFT JOIN', 'RIGHT JOIN',
                        'LEFT OUTER JOIN', 'RIGHT OUTER JOIN',
                        'INTO', 'UPDATE', 'TABLE'] \
            and last_token not in ['AS'] \
            and token.value not in ['AS', 'SELECT']:
        if last_token == '.' and next_token != '.':
            # database.table notation: merge with the name we already stored.
            table_name = '{}.{}'.format(tokens[index - 2], tokens[index])
            if len(tables) > 0:
                tables[-1] = table_name
            else:
                tables.append(table_name)

        # Detect database.schema.table (Name . Name . Name) at this position.
        schema_notation_match = (Name, '.', Name, '.', Name)
        schema_notation_tokens = (tokens[index - 4].ttype,
                                  tokens[index - 3].value,
                                  tokens[index - 2].ttype,
                                  tokens[index - 1].value,
                                  tokens[index].ttype) if len(tokens) > 4 else None
        if schema_notation_tokens == schema_notation_match:
            # database.schema.table notation: replace the previous entry.
            table_name = '{}.{}.{}'.format(
                tokens[index - 4], tokens[index - 2], tokens[index])
            if len(tables) > 0:
                tables[-1] = table_name
            else:
                tables.append(table_name)
        elif tokens[index - 1].value.upper() not in [',', last_keyword]:
            # Not a comma-separated list of tables (SELECT * FROM foo, bar);
            # likely an alias without AS (SELECT * FROM foo bar) -- skip it.
            pass
        else:
            table_name = str(token.value.strip('`'))
            tables.append(table_name)

    return tables


def get_query_tables(query):
    """
    :type query str
    :rtype: str  -- comma-separated table names
    """
    tables = []
    last_keyword = None

    # Keywords after which a table name may appear.
    table_syntax_keywords = [
        # SELECT queries
        'FROM', 'WHERE', 'JOIN', 'INNER JOIN', 'FULL JOIN', 'FULL OUTER JOIN',
        'LEFT OUTER JOIN', 'RIGHT OUTER JOIN',
        'LEFT JOIN', 'RIGHT JOIN', 'ON',
        # INSERT queries
        'INTO', 'VALUES',
        # UPDATE queries
        'UPDATE', 'SET',
        # Hive queries
        'TABLE',  # INSERT TABLE
    ]

    query = query.replace('"', '')
    tokens = get_query_tokens(query)

    for index, token in enumerate(tokens):
        if token.is_keyword and token.value.upper() in table_syntax_keywords:
            # Remember the last keyword; the next token can be a table name.
            last_keyword = token.value.upper()
        elif str(token) == '(':
            # Reset for INSERT `foo` VALUES(id, bar) ...
            last_keyword = None
        elif token.is_keyword and str(token) in ['FORCE', 'ORDER', 'GROUP BY']:
            # Reset for queries like:
            #   "SELECT x FORCE INDEX"
            #   "SELECT x ORDER BY"
            #   "SELECT x FROM y GROUP BY x"
            last_keyword = None
        elif token.is_keyword and str(token) == 'SELECT' and last_keyword in ['INTO', 'TABLE']:
            # Reset for "INSERT INTO SELECT" / "INSERT TABLE SELECT" queries.
            last_keyword = None
        elif token.ttype is Name or token.is_keyword:
            tables = _update_table_names(tables, tokens, index, last_keyword)

    return ','.join(tables)


try:
    # Truncate to fit Redshift's varchar(max) (65535 bytes).
    return get_query_tables(sql_string)[0:65535]
except Exception:
    # Never let a parse failure abort the calling query.
    return '<parse error>'
$$;
-- Extract the column names referenced by a SQL query and return them as a
-- comma-separated varchar. Returns '<parse error>' if parsing fails.
-- The Python body is adapted from the open-source sql_metadata project.
-- NOTE: the pasted original had lost all Python indentation (it would not
-- compile); this version restores it and narrows the bare `except:`.
CREATE OR REPLACE FUNCTION public.sql_metadata_get_query_columns(sql_string ANYELEMENT) RETURNS varchar(max)
IMMUTABLE
language plpythonu
AS
$$
import re

import sqlparse
from sqlparse.sql import TokenList
from sqlparse.tokens import Name, Whitespace, Wildcard, Punctuation


def preprocess_query(query):
    """
    Perform initial query cleanup.

    :type query str
    :rtype str
    """
    # Remove Looker prefix comment (e.g. "-- looker ... }'").
    query = re.sub(r"--\s*looker.*}'", ' ', query)
    # Flatten the query onto a single line.
    query = query.replace('\n', ' ')
    # `database`.`table` notation -> database.table
    query = re.sub(r'`([^`]+)`\.`([^`]+)`', r'\1.\2', query)
    return query


def get_query_tokens(query):
    """
    Tokenize a query, dropping whitespace tokens.

    :type query str
    :rtype: list[sqlparse.sql.Token]
    """
    query = preprocess_query(query)
    parsed = sqlparse.parse(query)

    # Handle empty queries gracefully.
    if not parsed:
        return []

    tokens = TokenList(parsed[0].tokens).flatten()
    return [token for token in tokens if token.ttype is not Whitespace]


def get_query_columns(query):
    """
    :type query str
    :rtype: str  -- comma-separated column names
    """
    columns = []
    last_keyword = None
    last_token = None

    # These keywords should not change the state of the parser
    # and must not "reset" a previously seen SELECT keyword.
    keywords_ignored = ['AS', 'AND', 'OR', 'IN', 'IS', 'NOT', 'NOT NULL',
                        'LIKE', 'CASE', 'WHEN']
    # These function names are ignored so their arguments are still
    # treated as columns of the enclosing SELECT.
    functions_ignored = ['COUNT', 'MIN', 'MAX', 'FROM_UNIXTIME', 'DATE_FORMAT',
                         'CAST', 'CONVERT', 'CONVERT_TIMEZONE', 'COALESCE',
                         'NVL', 'SUM', 'MOD', 'EXTRACT', 'TO_CHAR',
                         'DATE_TRUNC', 'RANK', 'REGEXP_REPLACE', 'DENSE_RANK',
                         'Z___MIN_RANK', 'Z___RANK']

    for token in get_query_tokens(query):
        if token.is_keyword and token.value.upper() not in keywords_ignored:
            # Remember the last keyword, e.g. SELECT, FROM, WHERE, (ORDER) BY.
            last_keyword = token.value.upper()
        elif token.ttype is Name:
            # Analyze name tokens: column names and WHERE condition values.
            if last_keyword in ['SELECT', 'WHERE', 'ORDER BY', 'ON'] \
                    and last_token.value.upper() not in ['AS']:
                if token.value.upper() not in functions_ignored:
                    if str(last_token) == '.':
                        # table.column notation: the previous entry is in
                        # fact the table name -- fold it into this column.
                        table_name = columns[-1]
                        columns[-1] = '{}.{}'.format(table_name, token)
                    else:
                        columns.append(str(token.value))
            elif last_keyword in ['INTO'] and last_token.ttype is Punctuation:
                # INSERT INTO `foo` (col1, `col2`) VALUES (..)
                columns.append(str(token.value).strip('`'))
        elif token.ttype is Wildcard:
            # Handle the * wildcard in the SELECT part, but ignore count(*).
            if last_keyword == 'SELECT' and last_token.value != '(':
                if str(last_token) == '.':
                    # Handle SELECT foo.*
                    table_name = columns[-1]
                    columns[-1] = '{}.{}'.format(table_name, str(token))
                else:
                    columns.append(str(token.value))

        last_token = token

    return ','.join(columns)


try:
    # Truncate to fit Redshift's varchar(max) (65535 bytes).
    return get_query_columns(sql_string)[0:65535]
except Exception:
    # Never let a parse failure abort the calling query.
    return '<parse error>'
$$;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment