Created
August 21, 2024 05:42
-
-
Save a1ea321/ab5e2030ef1bf665c5e71d2f9e3860dc to your computer and use it in GitHub Desktop.
Outputs the set of tables (or table-like objects such as views) that a SQL query reads from; CTEs defined inside the query itself are excluded. "dep" is short for "dependency".
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
from typing import Any, Iterator, Optional, Set, Tuple, Union

import mo_sql_parsing
def deps(query: str, db='', debug=False) -> Union[Set[str], Tuple[Any, Set[str]]]:
    """Return the tables (or table-like objects such as views) a query reads from.

    Every name found in a FROM or JOIN position is collected. Names declared
    as CTEs in a WITH clause are then removed, because they are defined inside
    the query rather than being external dependencies.

    Args:
        query: SQL text to analyse. A blank or whitespace-only query has no
            dependencies.
        db: ``'sqlserver'`` parses with ``mo_sql_parsing.parse_sqlserver``.
            Any other value uses ``mo_sql_parsing.parse`` (the path used for
            PostgreSQL).
        debug: when true, return ``(parsed_tree, names)`` instead of just
            ``names`` so the parse tree can be inspected. For a blank query
            the parse tree is reported as an empty dict.

    Returns:
        The set of dependency names, or ``(parsed, names)`` when ``debug`` is
        true.
    """
    if not query.strip():
        return ({}, set()) if debug else set()

    if db == 'sqlserver':
        parsed = mo_sql_parsing.parse_sqlserver(query)
    else:
        # PostgreSQL is parsed here.
        parsed = mo_sql_parsing.parse(query)

    # traverse() tracks its position in the module-level `context` list. Start
    # from a clean path so keys left behind by an earlier traversal that was
    # interrupted by an exception cannot affect the FROM/JOIN/WITH checks.
    context.clear()

    found, defined = set(), set()
    for dep, cte in traverse(parsed):
        # Each pair has exactly one non-None side. Adding the None side as
        # well would leak None into the result, e.g. for a query that declares
        # a CTE but has no FROM/JOIN target.
        if dep is not None:
            found.add(dep)
        if cte is not None:
            defined.add(cte)

    result = found - defined
    if debug:
        return parsed, result
    return result
### Everything below is internal: it is not meant to be called from outside
### this file. In your own code, import only the public entry point:
###     from sql_deps import deps [as something]

# Key path from the root of the parse tree down to the node that traverse()
# is currently visiting. traverse() pushes/pops dict keys on it, and the is_*
# helpers read its last entries to decide what role a string plays. It is
# shared module-level state, so interleaved traversals would interfere.
context = list()
def traverse(qp) -> Iterator[Tuple[Optional[str], Optional[str]]]:
    """Walk a parse tree and yield table references as ``(dep, cte)`` pairs.

    * ``(name, None)`` is yielded for every string in a FROM/JOIN position
      (a dependency).
    * ``(None, name)`` is yielded for every CTE name declared in a WITH clause,
      so the caller can subtract it from the dependencies.

    While descending into dicts, the key path from the root is kept in the
    module-level ``context`` list, which the ``is_*`` helpers inspect. Lists
    do not add a level to that path.

    Raises:
        ValueError: if a node of an unexpected type is encountered.
    """
    if qp is None or isinstance(qp, (bool, int, float)):
        # Scalar leaves (numbers, booleans, None) can never name a table.
        return

    if isinstance(qp, str):
        logging.info('%s:%s', '>'.join(context), qp)
        if is_join(qp) or is_from(qp):
            # This is a dependency.
            yield qp, None
        if is_with(qp):
            # This is also a dependency, but it is defined in the query itself
            # (as a CTE), so it has to be removed from the final list. This is
            # signified by yielding it on the other side of the pair.
            yield None, qp
    elif isinstance(qp, list):
        # Elements inherit the parent key, e.g. several FROM entries are all
        # seen under 'from'.
        for elem in qp:
            yield from traverse(elem)
    elif isinstance(qp, dict):
        for key, value in qp.items():
            context.append(key)
            try:
                yield from traverse(value)
            finally:
                # Always unwind, even if traversal raises or the generator is
                # closed early. Otherwise stale keys would corrupt later calls.
                context.pop()
    else:
        msg = f'Unexpected type {type(qp)} traversed. Query part:\n{qp}'
        raise ValueError(msg)
def is_from(qp):
    """Return True if the string being visited sits in a FROM position.

    That is the case when the innermost key in ``context`` ends in ``'from'``.
    It is also the case when the innermost key is ``'value'`` and its parent
    key ends in ``'from'``.

    ``qp`` is accepted for symmetry with the other ``is_*`` helpers but is not
    inspected. The decision is made entirely from the ``context`` key path.
    A path that is too short simply returns False.
    """
    if not context:
        return False
    if context[-1].endswith('from'):
        return True
    return (
        len(context) >= 2
        and context[-1] == 'value'
        and context[-2].endswith('from')
    )
def is_join(qp):
    """Return True if the string being visited sits in a JOIN position.

    That is the case when the innermost key in ``context`` ends in ``'join'``
    (so any key with that suffix counts). It is also the case when the
    innermost key is ``'value'`` and its parent key ends in ``'join'``.

    ``qp`` is accepted for symmetry with the other ``is_*`` helpers but is not
    inspected. A ``context`` path that is too short simply returns False.
    """
    if not context:
        return False
    if context[-1].endswith('join'):
        return True
    return (
        len(context) >= 2
        and context[-1] == 'value'
        and context[-2].endswith('join')
    )
def is_with(qp):
    """Return True if the string being visited is a CTE name in a WITH clause.

    That is the case when the last two keys in ``context`` are exactly
    ``'with'`` followed by ``'name'``.

    ``qp`` is accepted for symmetry with the other ``is_*`` helpers but is not
    inspected. A ``context`` path shorter than two keys simply returns False
    (the earlier indexing raised IndexError there).
    """
    return context[-2:] == ['with', 'name']
### FROM HERE ON IS NOT PART OF THE FILE. IT "RUNS" THE FILE.
if __name__ == '__main__':
    # Demo query: a CTE (tmp_tbl) that must be excluded, two JOINs, and a
    # subquery in the WHERE clause.
    demo_query = '''
        with tmp_tbl as (
            select * from schema2.table6
        )
        select *
        from table1 t1
        join schema1.table2 t2 on t1.x = t2.id
        join tmp_tbl t3 on t1.p = t3.p
        where table1.z > 10
          and table1.w in (
            select w
            from schema3.table4
            where z < 10
          )
    '''
    # deps() returns a set. Its iteration order for strings can change between
    # runs (hash randomization), so sort to make the demo output reproducible.
    for d in sorted(deps(query=demo_query)):
        print(d)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
`qp` stands for "query part".