Created May 8, 2019 07:57
Wrapper Function to create Sankey Diagram from DataFrame
import pandas as pd

def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # one color per categorical column; the palette has 5 entries,
    # so more than 5 cat_cols raises an IndexError below
    colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464']
    labelList = []
    colorNumList = []
    for catCol in cat_cols:
        labelListTemp = list(set(df[catCol].values))
        # remember how many distinct labels each column contributes
        colorNumList.append(len(labelListTemp))
        labelList = labelList + labelListTemp

    # remove duplicates from labelList
    # (if a label occurs in several columns, colorList ends up longer than labelList)
    labelList = list(dict.fromkeys(labelList))

    # define colors based on number of levels
    colorList = []
    for idx, colorNum in enumerate(colorNumList):
        colorList = colorList + [colorPalette[idx]]*colorNum

    # transform df into a source-target pair
    for i in range(len(cat_cols)-1):
        if i==0:
            sourceTargetDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            sourceTargetDf.columns = ['source','target','count']
        else:
            tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            tempDf.columns = ['source','target','count']
            sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
        # sum the values of identical source-target pairs
        sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()

    # add index for source-target pair
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
    sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))

    # creating the sankey diagram
    data = dict(
        type = 'sankey',  # trace type, needed so plotly draws a Sankey diagram
        node = dict(
            pad = 15,
            thickness = 20,
            line = dict(
                color = "black",
                width = 0.5
            ),
            label = labelList,
            color = colorList
        ),
        link = dict(
            source = sourceTargetDf['sourceID'],
            target = sourceTargetDf['targetID'],
            value = sourceTargetDf['count']
        )
    )

    layout = dict(
        title = title,
        font = dict(
            size = 10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig
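
The core of the function is the reshaping of adjacent category columns into source-target links. A minimal sketch of that step on toy data follows. The column names and values are made up for illustration, and the list comprehension is a compact restatement of the loop rather than the gist's exact code.

```python
import pandas as pd

# Hypothetical toy data: one row per (region, product, channel) combination.
df = pd.DataFrame({
    "region":  ["North", "North", "South", "South"],
    "product": ["A", "B", "A", "A"],
    "channel": ["Web", "Web", "Web", "Store"],
    "count":   [10, 5, 7, 3],
})
cat_cols = ["region", "product", "channel"]

# Build one (source, target, count) frame per pair of adjacent columns.
pairs = [
    df[[src, tgt, "count"]].rename(columns={src: "source", tgt: "target"})
    for src, tgt in zip(cat_cols, cat_cols[1:])
]

# Stack the frames and sum the counts of identical source-target pairs.
links = (
    pd.concat(pairs)
    .groupby(["source", "target"], as_index=False)["count"]
    .sum()
)
print(links)
```

Each row of `links` corresponds to one band in the diagram: the two South/A rows collapse into a single link, as do the two A/Web rows.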

Thanks! I use it for a diagnostic plot in a modelling pipeline we're building.

ken333135 commented Sep 26, 2019 via email

jpsteege commented Nov 5, 2020

Nice piece of code and very usable Medium article!
For my own use, I added a code snippet that creates the needed number of colors based on the number of cat_cols using Seaborn color palettes (here: 'Spectral'). For more palette options, check

import pandas as pd
import seaborn as sns

def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # Source:
    # Create colors based on the number of categorical columns
    colorPalette = sns.color_palette("Spectral", len(cat_cols)).as_hex()
    labelList = []
    colorNumList = []
    # ... the rest of genSankey continues as in the gist
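
If Seaborn is not installed, a similar per-column palette can be sketched with the standard library. This swaps in evenly spaced HSV hues via `colorsys` in place of Seaborn's 'Spectral' palette. `make_palette` is a hypothetical helper name, not part of the gist or of Seaborn.

```python
import colorsys

def make_palette(n):
    """Return n hex colors with evenly spaced hues (hypothetical helper)."""
    colors = []
    for i in range(n):
        # Spread hues over the color wheel; saturation and value are arbitrary choices.
        r, g, b = colorsys.hsv_to_rgb(i / n, 0.65, 0.9)
        colors.append("#{:02x}{:02x}{:02x}".format(
            round(r * 255), round(g * 255), round(b * 255)))
    return colors

# One color per categorical column, e.g. for three columns:
print(make_palette(3))
```

Inside genSankey, `colorPalette = make_palette(len(cat_cols))` would then stand in for the hard-coded list.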


rrosasl commented Dec 2, 2020

Hi Ken,

I have used this Sankey so many times :D
Most recently here

There I also have some useful code for converting many DataFrames into the necessary format for the Sankey Diagram :)
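
The linked code is not reproduced here. As a rough illustration of the kind of conversion involved, and assuming the raw data has one row per observation, the rows can be aggregated into the count column that genSankey expects. The column names below are hypothetical.

```python
import pandas as pd

# Hypothetical raw data: one row per observation, no count column yet.
raw = pd.DataFrame({
    "region":  ["North", "North", "North", "South"],
    "product": ["A", "A", "B", "A"],
})
cat_cols = ["region", "product"]

# Count how often each combination of categories occurs.
counted = raw.groupby(cat_cols).size().reset_index(name="count")
print(counted)
```

The resulting frame can be passed on with `cat_cols=cat_cols` and `value_cols='count'`.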

Thanks for the code!
