Created
February 25, 2021 01:20
-
-
Save jeremystan/d09e0054237675ac59c8e68d4c34b002 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is an example python class that can be used to create SQL that will pivot data SQL | |
query-able database (such as Redshift). | |
Note that it uses an Anomalo library, dquality.db, which is not open-sourced. However, | |
you can replace the calls to self.db and DB with the appropriate calls for your warehouse | |
to construct and execute SQL queries. | |
""" | |
import logging | |
import pandas as pd | |
import dquality.db as DB | |
class Pivot:
    """
    Pivot a table from "long" into "wide" format.

    For a given table, reduce it to one row per "index" value,
    with N columns defined by the distinct values in "column",
    each column containing the values in "value".

    In addition, will retain the first element of columns in
    "first" (along with the "index").

    Args:
        db: Database wrapper. This class only calls its
            ``make_sql_from_template(sql, params)`` and ``query(sql)`` methods.
        table: Table (or table expression) to read from.
        indexes: Column name, or list of column names, identifying one
            output row.
        column: Column whose distinct values become the new column names.
        value: Column whose values populate the new columns.
        first: Optional column name, or list of column names, to carry
            through via ``first_value`` per index.
    """

    def __init__(self, db, table, indexes, column, value, first=None):
        self.db = db
        self.table = table
        self.column = column
        self.value = value
        # A bare column name is accepted for both `indexes` and `first`;
        # both are stored as lists.
        self.indexes = indexes if isinstance(indexes, list) else [indexes]
        if first is None:
            first = []
        elif not isinstance(first, list):
            first = [first]
        self.first = first

    @property
    def items(self):
        """
        Distinct values of the pivot column.

        Runs a ``select distinct`` query on EVERY access and the query has no
        ordering clause, so two accesses are not guaranteed to return the
        values in the same order. Callers that need items and names to line
        up should read this once and reuse the result.
        """
        sql = """
            select distinct {column}
            from {table}
        """
        query = self.db.make_sql_from_template(
            sql, {"table": self.table, "column": self.column}
        )
        # presumably rows_to_pandas returns a DataFrame keyed by column name
        result = DB.rows_to_pandas(self.db.query(query))
        return result[self.column].values

    @property
    def names(self):
        """
        Output column names, one per item (queries the database via `items`).

        Raises:
            ValueError: if two items map to the same column name.
        """
        return self._names_for(self.items)

    def _names_for(self, items):
        """Map `items` to lower-cased safe column names, rejecting duplicates."""
        names = [DB.safe_column_name(v, log=logging).lower() for v in items]
        counts = pd.Series(names).value_counts()
        if any(counts > 1):
            dupes = set(counts[counts > 1].index.values)
            raise ValueError(
                "pivot column names will not be distinct: {}".format(dupes)
            )
        return names

    @staticmethod
    def _escape_literal(item):
        """
        Make `item` safe to place between single quotes in a SQL literal.

        String values get embedded single quotes doubled; any other type is
        returned unchanged so its rendering by the template is unaffected.
        """
        if isinstance(item, str):
            return item.replace("'", "''")
        return item

    @property
    def index_sql(self):
        """Comma-separated list of the index columns."""
        return ", ".join(self.indexes)

    def base_sql(self):
        """
        SQL selecting one row per index, plus the "first" columns.

        NOTE(review): the window uses `order by 1`, which does not appear to
        impose a meaningful row order, so which row counts as "first" is
        presumably up to the engine -- confirm against the warehouse.
        """
        first_template = """
            , first_value({col}) over (
                partition by {index_sql}
                order by 1
                rows between unbounded preceding and current row
            ) as {col}"""
        first_sql = "".join(
            self.db.make_sql_from_template(
                first_template, {"col": col, "index_sql": self.index_sql}
            )
            for col in self.first
        )
        sql = """\
            select distinct {index_sql}\
            {first_sql}
            from {table}"""
        return self.db.make_sql_from_template(
            sql,
            {"table": self.table, "index_sql": self.index_sql, "first_sql": first_sql},
        )

    def max_sql(self):
        """
        Comma-joined ``max(case when ...) as name`` expressions, one per item.

        Items are fetched once and names are derived from that same result,
        so each value is always paired with its own column name.
        """
        items = self.items
        return self._max_sql_for(items, self._names_for(items))

    def _max_sql_for(self, items, names):
        """Build the aggregate expressions for already-aligned items/names."""
        template = (
            "max(case when {column} = '{item}' then {value} else NULL end) as {name}"
        )
        return ",".join(
            self.db.make_sql_from_template(
                template,
                {
                    "name": name,
                    # Items are data values, not identifiers: escape quotes so
                    # a value such as O'Brien cannot terminate the literal.
                    "item": self._escape_literal(item),
                    "value": self.value,
                    "column": self.column,
                },
            )
            for item, name in zip(items, names)
        )

    def pivot_sql(self):
        """SQL that aggregates the table to one row per index with pivoted columns."""
        items = self.items
        return self._pivot_sql_for(items, self._names_for(items))

    def _pivot_sql_for(self, items, names):
        """Build the grouped pivot query for already-aligned items/names."""
        sql = """\
            select
                {index_sql},
                {max_sql}
            from
                {table}
            group by
                {index_sql}
        """
        return self.db.make_sql_from_template(
            sql,
            {
                "index_sql": self.index_sql,
                "max_sql": self._max_sql_for(items, names),
                "table": self.table,
            },
        )

    @property
    def sql(self):
        """
        Final query: the base rows joined to the pivoted columns on the indexes.

        The distinct items are queried once per access and reused for both the
        pivot CTE and the outer select list, so the two always agree.

        Raises:
            ValueError: if two items map to the same column name.
        """
        items = self.items
        names = self._names_for(items)

        sql = """
            with base as (
                {base}
            ), pivot as (
                {pivot}
            )
            select
                base.*,
                {pivot_select}
            from
                base, pivot
            where
                {join}
        """
        join_sql = " and ".join(
            "base.{i} = pivot.{i}".format(i=i) for i in self.indexes
        )
        pivot_select_sql = ", ".join("pivot.{}".format(n) for n in names)
        query = self.db.make_sql_from_template(
            sql,
            {
                "base": self.base_sql(),
                "pivot": self._pivot_sql_for(items, names),
                "pivot_select": pivot_select_sql,
                "join": join_sql,
            },
        )
        # Lazy %-args: the (potentially large) query is only formatted when
        # debug logging is enabled.
        logging.debug("final query: %s", query)
        return query
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment