Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
"""
This is an example Python class that generates SQL to pivot data in any
SQL-queryable database (such as Redshift).
Note that it uses an Anomalo library, dquality.db, which is not open source.
However, you can replace the calls to self.db and DB with the appropriate calls
for your warehouse to construct and execute SQL queries.
"""
import logging
import pandas as pd
import dquality.db as DB
class Pivot:
    """
    Pivot a table from "long" into "wide" format.

    For a given table, reduce it to one row per "index" value, with N
    columns defined by the distinct values in "column", each column
    containing the values in "value" (aggregated with ``max``).
    In addition, one value per index is retained for each column listed
    in "first".

    Parameters
    ----------
    db
        Database handle; must provide ``make_sql_from_template(sql, params)``
        and ``query(sql)``.
    table
        Source table, substituted into the generated ``from`` clauses.
    indexes
        Column name, or list/tuple of column names, identifying one output row.
    column
        Column whose distinct values become the new pivot columns.
    value
        Column (or expression) whose values fill the pivot columns.
    first
        Optional column name, or list/tuple of names, carried through
        alongside the index columns.

    Raises
    ------
    ValueError
        If no index column is given.
    """

    def __init__(self, db, table, indexes, column, value, first=None):
        self.db = db
        self.table = table
        self.column = column
        self.value = value
        # Accept a single name or a list/tuple of names; copy so later
        # mutation of the caller's list cannot change the generated SQL.
        self.indexes = self._as_list(indexes)
        if not self.indexes:
            raise ValueError("Pivot requires at least one index column")
        self.first = self._as_list(first)
        # Distinct pivot values, fetched lazily once and then cached so every
        # query built from this object uses the same values in the same order.
        self._items = None

    @staticmethod
    def _as_list(value):
        """Normalize None / a single name / a list or tuple of names to a new list."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _quote_literal(item):
        """Render ``item`` as a single-quoted SQL string literal, doubling embedded quotes."""
        return "'{}'".format(str(item).replace("'", "''"))

    @property
    def items(self):
        """
        Distinct values of ``column`` in ``table``.

        The query runs on first access only; the result is cached on the
        instance. ``max_sql`` pairs these items positionally with ``names``
        (which also reads ``items``), and an unordered ``select distinct``
        re-run per access could return rows in a different order, mismatching
        values and column names.
        """
        if self._items is None:
            sql = """
            select distinct {column}
            from {table}
            """
            query = self.db.make_sql_from_template(
                sql, {"table": self.table, "column": self.column}
            )
            result = DB.rows_to_pandas(self.db.query(query))
            self._items = result[self.column].values
        return self._items

    @property
    def names(self):
        """
        Output column names for the pivot columns, one per entry in ``items``.

        Each item is passed through ``DB.safe_column_name`` and lower-cased.

        Raises
        ------
        ValueError
            If two items map to the same column name.
        """
        names = [DB.safe_column_name(v, log=logging).lower() for v in self.items]
        # Explicit dtype so an empty list does not depend on pandas' default dtype.
        counts = pd.Series(names, dtype=object).value_counts()
        dupes = set(counts[counts > 1].index.values)
        if dupes:
            raise ValueError(
                "pivot column names will not be distinct: {}".format(dupes)
            )
        return names

    @property
    def index_sql(self):
        """Index columns as a comma-separated SQL list."""
        return ", ".join(self.indexes)

    def base_sql(self):
        """
        SQL selecting each distinct index combination, plus one value per
        column in ``first`` (computed with ``first_value`` over the index
        partition).

        NOTE(review): the window uses ``order by 1``, which does not express a
        meaningful ordering, so which row's value is "first" should be treated
        as arbitrary — confirm against the target warehouse.
        """
        first_template = """
            , first_value({col}) over (
                partition by {index_sql}
                order by 1
                rows between unbounded preceding and current row
            ) as {col}"""
        first_sql = "".join(
            self.db.make_sql_from_template(
                first_template, {"col": col, "index_sql": self.index_sql}
            )
            for col in self.first
        )
        sql = "select distinct {index_sql}{first_sql}\nfrom {table}"
        return self.db.make_sql_from_template(
            sql,
            {"table": self.table, "index_sql": self.index_sql, "first_sql": first_sql},
        )

    def max_sql(self):
        """
        Comma-separated ``max(case when ...)`` expressions, one per pivot item.

        Raises
        ------
        ValueError
            If ``column`` has no distinct values (the SQL would otherwise
            contain a dangling comma), or if pivot column names collide.
        """
        items = self.items
        if len(items) == 0:
            raise ValueError(
                "no distinct values in column {!r} of {}; nothing to pivot".format(
                    self.column, self.table
                )
            )
        template = (
            "max(case when {column} = {item} then {value} else NULL end) as {name}"
        )
        return ",".join(
            self.db.make_sql_from_template(
                template,
                {
                    "name": name,
                    # Escape embedded single quotes so values such as O'Brien
                    # still yield a valid string literal.
                    "item": self._quote_literal(item),
                    "value": self.value,
                    "column": self.column,
                },
            )
            for item, name in zip(items, self.names)
        )

    def pivot_sql(self):
        """SQL that groups by the index columns and emits one column per pivot item."""
        sql = """\
        select
            {index_sql},
            {max_sql}
        from
            {table}
        group by
            {index_sql}
        """
        return self.db.make_sql_from_template(
            sql,
            {
                "index_sql": self.index_sql,
                "max_sql": self.max_sql(),
                "table": self.table,
            },
        )

    @property
    def sql(self):
        """Final query: ``base`` rows joined to ``pivot`` rows on every index column."""
        sql = """
        with base as (
            {base}
        ), pivot as (
            {pivot}
        )
        select
            base.*,
            {pivot_select}
        from
            base, pivot
        where
            {join}
        """
        join_sql = " and ".join(
            "base.{i} = pivot.{i}".format(i=i) for i in self.indexes
        )
        pivot_select_sql = ", ".join("pivot.{}".format(n) for n in self.names)
        query = self.db.make_sql_from_template(
            sql,
            {
                "base": self.base_sql(),
                "pivot": self.pivot_sql(),
                "pivot_select": pivot_select_sql,
                "join": join_sql,
            },
        )
        logging.debug("final query: %s", query)
        return query
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment