Skip to content

Instantly share code, notes, and snippets.

@ryancollingwood
Created November 2, 2023 06:43
Show Gist options
  • Save ryancollingwood/1bc64520adac3f6c53edfaf330ff4922 to your computer and use it in GitHub Desktop.
Save ryancollingwood/1bc64520adac3f6c53edfaf330ff4922 to your computer and use it in GitHub Desktop.
For a column that is a composite of other columns in a SQL query (subqueries or CTEs), get the SQL that makes up that column.
from typing import Dict
import re
from dataclasses import dataclass, field
import sqlparse
@dataclass
class SQLColumnLineage:
    """Find the SQL text that defines a target column in a query.

    The query is parsed with ``sqlparse`` and walked recursively through its
    grouped tokens (identifiers, parentheses / subqueries, statements).  Every
    identifier whose real name or alias matches ``target_column``
    (case-insensitively) contributes the SQL that produces it.

    Attributes:
        sql: The SQL text to inspect.
        target_column: Name of the column whose defining SQL is wanted.
        aliased_columns: Lower-cased real name -> alias, filled in while the
            walk meets ``name AS alias`` identifiers.  Names in extracted
            snippets are rewritten through this map.  It is not reset between
            calls, so reusing an instance carries earlier aliases over.
    """

    sql: str
    target_column: str
    aliased_columns: Dict[str, str] = field(default_factory=dict)

    def get_columns_referenced(self, token) -> set:
        """Recursively collect the SQL defining ``target_column`` under ``token``.

        Args:
            token: Any ``sqlparse`` token or token group.

        Returns:
            A set of SQL snippets (possibly empty).  As a side effect,
            ``aliased_columns`` is updated for aliased identifiers encountered.
        """
        column_references = set()
        if token.is_whitespace:
            return column_references

        if isinstance(token, sqlparse.sql.IdentifierList):
            # Only the list items themselves are checked; unlike a lone
            # Identifier (below) their children are not walked.
            # NOTE(review): matches/aliases nested inside a multi-item select
            # list are therefore not seen -- confirm this is intended.
            for identifier in token.get_identifiers():
                column_references.update(self.extract_column_references(identifier))
            return column_references

        if isinstance(token, sqlparse.sql.Identifier):
            # Check the identifier itself before descending into it.
            column_references.update(self.extract_column_references(token))

        # Identifiers, parentheses (subqueries), statements and any other
        # grouped token: walk the children.  Leaf tokens have no ``tokens``.
        for child in getattr(token, "tokens", ()):
            if not child.is_whitespace:
                column_references.update(self.get_columns_referenced(child))
        return column_references

    def extract_column_references(self, token) -> set:
        """Return the SQL for ``token`` if it names or aliases ``target_column``.

        Also records ``real_name -> alias`` in ``aliased_columns`` when the
        token has a real alias that differs from its real name and its first
        child is a plain Name token.

        Args:
            token: A ``sqlparse`` token.  Tokens without a usable real name
                (no ``get_real_name`` or it returns ``None``) are ignored.

        Returns:
            A set holding at most one SQL snippet.
        """
        column_name = self.target_column.lower()
        result = set()
        try:
            referenced_column = token.get_real_name().lower()
        except AttributeError:
            return result

        alias = token.get_alias() or ""

        # Only register genuine aliases.  Mapping an alias-less column to ""
        # would erase that column's name from every snippet extracted later.
        # Requiring a Name-led token keeps functions/keywords from being
        # treated as aliased columns (a heuristic).
        # NOTE(review): the map is real name -> alias, so snippets get source
        # names rewritten to their aliases; the original intent ("we want the
        # actual source columns") suggests the opposite direction -- confirm.
        if (
            alias
            and alias.lower() != referenced_column
            and token.tokens[0].ttype is sqlparse.tokens.Name
        ):
            self.aliased_columns[referenced_column] = alias

        if column_name not in (referenced_column, alias.lower()):
            return result

        try:
            pieces = [str(piece) for piece in token.flatten()]
        except AttributeError:
            # No children to flatten: use the token's own SQL text.
            result.add(token.value)
            return result

        # Rewrite each piece through the (possibly chained) alias map.
        result.add("".join(self._resolve_alias(piece) for piece in pieces))
        return result

    def _resolve_alias(self, name: str) -> str:
        """Follow ``name`` through ``aliased_columns`` until it is unmapped.

        Lookups are case-insensitive.  A cyclic chain (e.g. from
        ``SELECT b AS a, a AS b``) stops at the first already-visited name
        instead of looping forever.
        """
        seen = set()
        current = name
        while current.lower() in self.aliased_columns and current.lower() not in seen:
            seen.add(current.lower())
            current = self.aliased_columns[current.lower()]
        return current

    def find_columns_in_query(self) -> "str | list[str]":
        """Return the SQL that defines ``target_column`` in ``self.sql``.

        Returns:
            The single snippet as a ``str`` when exactly one distinct
            definition is found, otherwise a sorted list of snippets (empty
            when nothing matched).  Whitespace runs in each snippet are
            collapsed to one space, and ``col AS col`` self references are
            dropped.
        """
        column_references = set()
        for statement in sqlparse.parse(self.sql):
            for token in statement.tokens:
                column_references.update(self.get_columns_referenced(token))

        column_name = self.target_column.lower()
        self_reference = f"{column_name} as {column_name}"
        # Normalise whitespace (which may merge near-duplicates), then sort so
        # the output does not depend on set iteration order.
        normalised = {re.sub(r"\s+", " ", ref).strip() for ref in column_references}
        result = sorted(ref for ref in normalised if ref.lower() != self_reference)
        if len(result) == 1:
            return result[0]
        return result
if __name__ == "__main__":
    # Demo: find the SQL that builds zz_meta_key in a simple SELECT.
    sql = """SELECT *, CONCAT( COALESCE( cast(trim(SiteID) as string) , '') ,'--' ,COALESCE( cast(trim(SystemOfOriginID) as string), '') ) as zz_meta_key FROM reference_sites"""
    target_column = "zz_meta_key"
    finder = SQLColumnLineage(
        sql=sql,
        target_column=target_column,
    )
    # Get columns referenced in the SQL query for the specified column
    referenced_columns = finder.find_columns_in_query()
    # find_columns_in_query returns a bare str for a single match; joining a
    # str directly would put ", " between every character.
    if isinstance(referenced_columns, str):
        referenced_columns = [referenced_columns]
    print(f"Columns that make up {target_column}: {', '.join(referenced_columns)}")
@ryancollingwood
Copy link
Author

Leverages https://pypi.org/project/sqlparse/ to do the heavy lifting

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment