Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
This is an example Python class that can be used to generate SQL that pivots data in any SQL
query-able database (such as Redshift).
Note that it uses an Anomalo library, dquality.db, which is not open-sourced. However,
you can replace the calls to self.db and DB with the appropriate calls for your warehouse
to construct and execute SQL queries.
import logging
import pandas as pd
import dquality.db as DB
class Pivot:
    """
    Pivot a table from "long" into "wide" format.

    For a given table, reduce it to one row per "index" value,
    with N columns defined by the distinct values in "column",
    each column containing the values in "value".

    In addition, will retain the first element of columns in
    "first" (along with the "index").

    Parameters
    ----------
    db : object
        Warehouse handle. Must provide ``make_sql_from_template(sql, params)``
        and ``query(sql)``; the result of ``query`` is passed to
        ``DB.rows_to_pandas``.
    table : str
        Table to pivot (interpolated into the SQL as-is).
    indexes : str or list of str
        Column(s) identifying one output row. A bare string is treated as a
        single-element list.
    column : str
        Column whose distinct values become the pivoted output columns.
    value : str
        Column whose values populate the pivoted output columns.
    first : str or list of str, optional
        Extra columns to carry through via ``first_value`` per index.
    """

    def __init__(self, db, table, indexes, column, value, first=None):
        self.db = db
        self.table = table
        self.column = column
        self.value = value
        # Wrap a single index column so ``index_sql`` joins column names,
        # not the characters of one name.
        self.indexes = indexes if isinstance(indexes, list) else [indexes]
        if first is None:
            first = []
        elif not isinstance(first, list):
            first = [first]
        self.first = first
        # Lazily-filled caches for ``items`` / ``names``. The distinct-values
        # query then runs once, and item/name pairs always come from the same
        # result set.
        self._items = None
        self._names = None

    @property
    def items(self):
        """Distinct values of ``self.column`` in ``self.table``.

        Queried on first access and cached afterwards. The query orders by the
        value, so the pivoted column order is the same from run to run.
        """
        if self._items is None:
            sql = """
                select distinct {column}
                from {table}
                order by {column}
            """
            query = self.db.make_sql_from_template(
                sql, {"table": self.table, "column": self.column}
            )
            result = DB.rows_to_pandas(self.db.query(query))
            self._items = result[self.column].values
        return self._items

    @property
    def names(self):
        """Output column names, one per entry of ``items`` and in the same order.

        Each item is passed through ``DB.safe_column_name`` and lowercased.
        The result is cached after the first successful computation.

        Raises
        ------
        ValueError
            If two items would map to the same output column name.
        """
        if self._names is None:
            names = [DB.safe_column_name(v, log=logging).lower() for v in self.items]
            counts = pd.Series(names).value_counts()
            if any(counts > 1):
                dupes = set(counts[counts > 1].index.values)
                raise ValueError(
                    "pivot column names will not be distinct: {}".format(dupes)
                )
            self._names = names
        return self._names

    @property
    def index_sql(self):
        """Index columns as a comma-separated SQL column list."""
        return ", ".join(self.indexes)

    def base_sql(self):
        """Return SQL selecting one row per index, plus any ``first`` columns."""
        first_template = """
            , first_value({col}) over (
                partition by {index_sql}
                order by 1
                rows between unbounded preceding and current row
            ) as {col}"""
        # NOTE(review): ``order by 1`` inside a window appears to order by a
        # constant, so which row counts as "first" is not guaranteed — confirm
        # whether a real ordering column is intended.
        first_sql = "".join(
            self.db.make_sql_from_template(
                first_template, {"col": col, "index_sql": self.index_sql}
            )
            for col in self.first
        )
        sql = (
            "select distinct {index_sql}{first_sql}\n"
            "from {table}"
        )
        return self.db.make_sql_from_template(
            sql,
            {"table": self.table, "index_sql": self.index_sql, "first_sql": first_sql},
        )

    def max_sql(self):
        """Return the comma-separated ``max(case ...)`` aggregates, one per item."""
        sql = "max(case when {column} = '{item}' then {value} else NULL end) as {name}"
        final = []
        for item, name in zip(self.items, self.names):
            # Double embedded single quotes so a value such as "o'brien" is
            # still a valid SQL string literal.
            literal = str(item).replace("'", "''")
            final.append(
                self.db.make_sql_from_template(
                    sql,
                    {
                        "name": name,
                        "item": literal,
                        "value": self.value,
                        "column": self.column,
                    },
                )
            )
        return ",".join(final)

    def pivot_sql(self):
        """Return SQL aggregating one row per index with one column per item."""
        sql = (
            "select {index_sql}, {max_sql}\n"
            "from {table}\n"
            "group by {index_sql}"
        )
        return self.db.make_sql_from_template(
            sql,
            {
                "index_sql": self.index_sql,
                "max_sql": self.max_sql(),
                "table": self.table,
            },
        )

    def sql(self):
        """Return the final pivot query.

        Joins ``base`` (one row per index plus ``first`` columns) to ``pivot``
        (the aggregated columns) on equality of every index column.
        """
        # Final query
        sql = """
            with base as (
            {base}
            ), pivot as (
            {pivot}
            )
            select base.*, {pivot_select}
            from base, pivot
            where {join}
        """
        # NOTE(review): plain ``=`` never matches NULL, so rows whose index
        # columns are NULL drop out of the join — confirm that is intended.
        join_sql = " and ".join(
            "base.{i} = pivot.{i}".format(i=i) for i in self.indexes
        )
        pivot_select_sql = ", ".join("pivot.{}".format(n) for n in self.names)
        query = self.db.make_sql_from_template(
            sql,
            {
                "base": self.base_sql(),
                "pivot": self.pivot_sql(),
                "pivot_select": pivot_select_sql,
                "join": join_sql,
            },
        )
        logging.debug("final query: %s", query)
        return query
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment