Skip to content

Instantly share code, notes, and snippets.

@YuanfengZhang
Last active January 18, 2024 09:04
Show Gist options
  • Save YuanfengZhang/69205779257d34a1f6698a79339d086a to your computer and use it in GitHub Desktop.
Save YuanfengZhang/69205779257d34a1f6698a79339d086a to your computer and use it in GitHub Desktop.
Draw a heatmap for discrete values in Python
# -*- coding: utf-8 -*-
"""
A quick guide for beginners to draw a heatmap for discrete values.
"""
from collections import OrderedDict
from typing import Tuple
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.axes import Axes
from matplotlib.figure import Figure
def discrete_heatmap(df: pd.DataFrame,
                     x_column: str,
                     y_column: str,
                     value_column: str,
                     colormap: dict[str, str],
                     string_for_nan: str = 'NA') -> Tuple[Figure, Axes]:
    """Draw an annotated heatmap for a column of discrete (categorical) values.

    The long-format dataframe is pivoted into a y-by-x grid, every discrete
    value is encoded as an integer so that seaborn can colour the cells, and
    the original values are written into the cells as annotations.

    Args:
        df (pd.DataFrame): the long-format dataframe containing the data to
            plot, for example::

                country  month  most_liked
                U.S.     Dec    Andy
                Russian  Dec    Sasha
                China    Dec    Quan
                ...      ...    ...

        x_column (str): the column shown along the x-axis of the heatmap.
        y_column (str): the column shown along the y-axis of the heatmap.
        value_column (str): the column whose values fill the heatmap cells.
        colormap (dict[str, str]): every unique discrete value mapped to the
            colour used for it, for example::

                {'Andy': '#BBD844', 'Sasha': '#448CA2', 'Quan': '#E77FCA'}

            If ``string_for_nan`` is not a key of ``colormap``, the entry
            ``{string_for_nan: '#DAD8D8'}`` is added in front of the others.
        string_for_nan (str): the label that replaces missing cells in the
            pivoted grid. Defaults to ``'NA'``.

    Returns:
        Tuple[Figure, Axes]: the created figure and the axes holding the
        heatmap.

    Raises:
        KeyError: if the pivoted data contains a value that has no colour
            in ``colormap``.
    """
    # Pivot the long dataframe into a (y, x) grid of discrete values.
    try:
        _pivot = df.pivot(index=y_column,
                          columns=x_column,
                          values=value_column)
    except ValueError:
        # ``pivot`` refuses duplicated (y, x) pairs; keep the first value
        # seen for each pair instead.
        _pivot = df.pivot_table(index=y_column,
                                columns=x_column,
                                values=value_column,
                                aggfunc='first')
    _pivot = _pivot.fillna(string_for_nan)

    # Make sure the NaN placeholder actually in use has a colour. The check
    # must use ``string_for_nan`` rather than a hard-coded 'NA', otherwise a
    # custom placeholder can be left without an entry.
    if string_for_nan in colormap:
        _ordered_cmap = OrderedDict(colormap)
    else:
        _ordered_cmap = OrderedDict({string_for_nan: '#DAD8D8', **colormap})
    _index_dict = {_k: _i for _i, _k in enumerate(_ordered_cmap)}

    # Fail early, and before any figure is created, with a readable message.
    _unknown = [_v for _v in pd.unique(_pivot.to_numpy().ravel())
                if _v not in _index_dict]
    if _unknown:
        raise KeyError(f'values without a colour in colormap: {_unknown}')

    # seaborn can only colour numeric data, so the cells are drawn from the
    # integer codes while the original labels are used as annotations.
    _codes = _pivot.map(lambda _v: _index_dict[_v]).astype(int)
    _n_levels = len(_ordered_cmap)

    # Plot
    _fig, _ax = plt.subplots(figsize=(8, 8), dpi=200)
    sns.heatmap(data=_codes,
                cmap=LinearSegmentedColormap.from_list(
                    name=value_column,
                    colors=list(_ordered_cmap.values()),
                    N=_n_levels),
                # Half-integer limits give every code its own colour band,
                # even when some categories do not occur in the data.
                vmin=-0.5,
                vmax=_n_levels - 0.5,
                annot=_pivot,
                fmt='',
                linewidths=.5,
                ax=_ax)

    # One colorbar tick per category, centred on its colour band; the tick
    # positions must be fixed before the labels are assigned.
    _colorbar = _ax.collections[0].colorbar
    _colorbar.set_ticks(list(range(_n_levels)))
    _colorbar.set_ticklabels(list(_ordered_cmap.keys()))
    _ax.tick_params(axis='both', length=0)
    return _fig, _ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment