Skip to content

Instantly share code, notes, and snippets.

@zhangyuan
Created August 2, 2021 08:28
Show Gist options
  • Save zhangyuan/5a1e9243b9b6f03b9c13e456eb26d861 to your computer and use it in GitHub Desktop.
Save zhangyuan/5a1e9243b9b6f03b9c13e456eb26d861 to your computer and use it in GitHub Desktop.
Render a Sankey diagram with Plotly
import pandas as pd
import plotly.graph_objects as go
def sankey(df, columns, measurement, title='Sankey Diagram', *, show=True):
    """Render a Plotly Sankey diagram of ``measurement`` flowing across ``columns``.

    Every pair of adjacent columns (columns[0] -> columns[1],
    columns[1] -> columns[2], ...) contributes one link per (source value,
    target value) group found in ``df``. The link's width is the number of
    distinct ``measurement`` values in that group (``nunique``). Values that
    appear in more than one column share a single node, as before.

    Args:
        df: DataFrame containing ``columns`` and ``measurement``.
        columns: Ordered sequence of at least two column names.
        measurement: Column whose distinct values are counted for each link.
        title: Figure title.
        show: If True (default), display the figure with ``fig.show()``.

    Returns:
        None when ``show`` is True. That way a notebook cell ending in
        ``sankey(...)`` does not render the diagram a second time. Otherwise
        it returns the built ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If fewer than two columns are given; no link can be drawn.
    """
    # Normalise to a list: slicing/zip below and groupby() all want a list,
    # and a tuple passed to groupby() is treated as a single column key.
    columns = list(columns)
    if len(columns) < 2:
        raise ValueError(
            f"sankey() needs at least two columns to draw links, got {columns!r}"
        )

    labels, source, target, value = _sankey_links(df, columns, measurement)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=10,
            line=dict(color="black", width=1),
            label=labels,
            color="blue",
        ),
        link=dict(source=source, target=target, value=value),
    )])
    fig.update_layout(title_text=title, font_size=10, width=1000, height=600)
    if show:
        fig.show()
        return None
    return fig


def _sankey_links(df, columns, measurement):
    """Return (labels, source, target, value) lists for sankey()'s node/link data."""
    # Labels are kept in first-appearance order, column by column. A set's
    # iteration order is not stable between runs, so node order (and layout)
    # would change. NaN/None values are skipped: groupby() drops them below,
    # so they could only ever appear as orphan nodes.
    labels = list(dict.fromkeys(
        label for column in columns for label in df[column].dropna()
    ))
    # Use a dict lookup instead of list.index(), which is an O(len(labels)) scan.
    index_of = {label: i for i, label in enumerate(labels)}

    source, target, value = [], [], []
    for left, right in zip(columns, columns[1:]):
        # observed=True: for categorical columns, only emit value pairs that
        # actually occur. Unused categories would otherwise yield zero-width
        # links whose labels are not in `labels`.
        counts = (
            df.groupby([left, right], observed=True)[measurement]
            .nunique()
            .reset_index()
        )
        source.extend(index_of[x] for x in counts[left])
        target.extend(index_of[x] for x in counts[right])
        value.extend(counts[measurement].tolist())
    return labels, source, target, value
# Usage example. It is guarded so that importing this module does not try to
# draw a figure from an undefined `df`.
if __name__ == "__main__":
    # Illustrative data: each link's width is the number of distinct users
    # going from a Region to a Role.
    df = pd.DataFrame({
        "Region": ["APAC", "APAC", "EMEA", "EMEA", "AMER"],
        "Role": ["Dev", "QA", "Dev", "Dev", "PM"],
        "User": ["u1", "u2", "u3", "u4", "u5"],
    })
    sankey(df, ['Region', 'Role'], 'User')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment