Skip to content

Instantly share code, notes, and snippets.

@henhuy
Created April 17, 2024 11:39
Show Gist options
Save henhuy/6caf28bd5e4414ecb546ef7f059ecb67 to your computer and use it in GitHub Desktop.
Sankey
import itertools
import pandas as pd
import numpy as np
import plotly.graph_objects as go
# Semicolon-separated results table. The script below reads the columns
# "process", "input_commodity", "output_commodity" and "value" from it.
# NOTE(review): the path is absolute at the filesystem root — confirm this is
# intended and not meant to be relative to the working directory.
RESULTS_FILE = "/industry_scratch.csv"
data = pd.read_csv(RESULTS_FILE, sep=";")
# Every distinct Sankey node name: the processes plus the commodities flowing
# into and out of them.
#  * dropna() removes missing cells reliably. set.discard(np.nan) only works
#    when the NaN is the very same object as np.nan, which is not guaranteed
#    for NaNs coming out of a float64 (e.g. entirely empty) column.
#  * Sorting makes the node order reproducible. Iterating a set of strings is
#    not stable across interpreter runs because of hash randomisation.
#    key=str keeps the sort from failing if a non-string name ever appears.
labels = sorted(
    pd.concat(
        [data["process"], data["input_commodity"], data["output_commodity"]],
        ignore_index=True,
    )
    .dropna()
    .unique(),
    key=str,
)
# Split the node labels into Sankey columns, preserving their current order
# within each column. Precedence matters: a label containing "import" counts
# as an import even if it also starts with "pri" or "sec".
imports = [name for name in labels if "import" in name]
non_imports = [name for name in labels if "import" not in name]
primary = [name for name in non_imports if name.startswith("pri")]
secondary = [name for name in non_imports if name.startswith("sec")]
others = [name for name in non_imports if not name.startswith(("pri", "sec"))]

# Re-order the node list column by column; x/y and the link indices below
# are all positions in this ordering.
labels = imports + primary + secondary + others
def _column_positions(count):
    """Return ``count`` evenly spaced coordinates strictly between 0 and 1.

    Used for the vertical placement of the nodes within one Sankey column.
    ``count == 0`` yields an empty list (no division by zero).
    """
    # Take count + 2 points over [0, 1] and drop both end points, which leaves
    # exactly ``count`` interior values (1 -> [0.5], 2 -> [1/3, 2/3]).
    return np.linspace(0, 1, count + 2)[1:-1].tolist()


# Node coordinates, aligned index-for-index with ``labels`` (imports, then
# primary, then secondary, then others). Each column must contribute exactly
# one value per node; an extra value shifts every following node's y out of
# step with its label. Only the import and primary columns get explicit
# coordinates; the remaining nodes get None.
# NOTE(review): confirm how plotly places None-coordinate nodes under
# arrangement="fixed".
unplaced = len(secondary) + len(others)
x = [0.1] * len(imports) + [0.2] * len(primary) + [None] * unplaced
y = (
    _column_positions(len(imports))
    + _column_positions(len(primary))
    + [None] * unplaced
)
# Sankey links, one entry per list index:
#   source/target -> node positions in ``labels``
#   value         -> the row's "value" cell
#   label         -> the commodity name shown for the link
source = []
target = []
value = []
label = []
color = []  # NOTE(review): never filled nor passed to the figure — remove?

# Position of every node in ``labels``. A dict lookup replaces the linear
# ``labels.index`` scan that was repeated for every row.
node_index = {name: position for position, name in enumerate(labels)}


def _node(name):
    """Return the position of *name* in ``labels``.

    Raises:
        ValueError: if *name* is not a node (e.g. a missing/NaN cell), the
            same exception type ``list.index`` raised before.
    """
    try:
        return node_index[name]
    except KeyError:
        raise ValueError(f"{name!r} is not a known Sankey node") from None


for _, flow in data.iterrows():
    # A cell counts as "missing" when it is not a string (pandas reads empty
    # cells as NaN).
    if not isinstance(flow["input_commodity"], str):
        # Producing row: process -> output commodity.
        start, end = flow["process"], flow["output_commodity"]
        link_name = flow["output_commodity"]
    elif not isinstance(flow["output_commodity"], str):
        # Consuming row: input commodity -> process.
        start, end = flow["input_commodity"], flow["process"]
        link_name = flow["input_commodity"]
    else:
        # NOTE(review): rows with both an input and an output commodity are
        # dropped entirely — confirm this is intended.
        continue
    source.append(_node(start))
    target.append(_node(end))
    label.append(link_name)
    value.append(flow["value"])
# Assemble the Sankey diagram from the node and link lists built above and
# display it.
node_spec = {
    "pad": 15,
    "thickness": 15,
    "line": {"color": "black", "width": 0.5},
    "label": labels,
    "x": x,
    "y": y,
}
link_spec = {
    "source": source,
    "target": target,
    "value": value,
    "label": label,
}
sankey = go.Sankey(
    arrangement="fixed",
    # Values are formatted without decimals and suffixed with "TWh"
    # (no separating space).
    valueformat=".0f",
    valuesuffix="TWh",
    node=node_spec,
    link=link_spec,
)

fig = go.Figure(data=[sankey])
fig.update_layout(title_text="Industriesektor", font_size=10)
fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment