Skip to content

Instantly share code, notes, and snippets.

@rfriedman22
Last active August 30, 2019 04:40
Show Gist options
  • Save rfriedman22/a19381272c4cc970f170662e9cf2dd2d to your computer and use it in GitHub Desktop.
Save rfriedman22/a19381272c4cc970f170662e9cf2dd2d to your computer and use it in GitHub Desktop.
Example code for performing hierarchical clustering and plotting the result as a heatmap with dendrograms and class labels.
#!/usr/bin/env python3
"""
Example code for plotting a heatmap, colorbar, dendrograms, and class labels with matplotlib.
Author: Ryan Z. Friedman (rfriedman22)
Email: ryanfriedman22@gmail.com
License: BSD 3 clause
Package versions:
Python 3.6.5
Matplotlib 3.0.2
Numpy 1.15.4
Pandas 0.23.4
Scipy 1.1.0
Scikit-learn 0.19.1
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
from scipy import stats
from scipy.cluster import hierarchy
from sklearn import datasets
from sklearn.preprocessing import MinMaxScaler
# Configure global matplotlib defaults (font sizes, figure size, colormap,
# marker and line sizes) in a single call
mpl.rcParams.update(
    {
        "axes.titlesize": 25,
        "axes.labelsize": 20,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
        "legend.fontsize": 15,
        "figure.figsize": (8, 8),
        "image.cmap": "viridis",
        "lines.markersize": 5,
        "lines.linewidth": 3,
    }
)
# Fetch the wine dataset bundled with scikit-learn
wine_data = datasets.load_wine()
feature_names = wine_data["feature_names"]
# Min-max scale the feature matrix, then wrap it in a DataFrame whose columns
# are named after the features
scaled_features = MinMaxScaler().fit_transform(wine_data["data"])
data = pd.DataFrame(scaled_features, columns=feature_names)
# Keep the class targets as a pandas Series
labels = pd.Series(wine_data["target"])
# Ward linkage over the rows and over the columns (transposed data)
row_link = hierarchy.linkage(data, method="ward")
col_link = hierarchy.linkage(data.transpose(), method="ward")
# Compute both dendrograms without drawing them; they are plotted later
row_dendro = hierarchy.dendrogram(row_link, no_plot=True)
col_dendro = hierarchy.dendrogram(col_link, no_plot=True)
# The row leaf order is reversed so the heatmap rows display correctly next
# to the row dendrogram; the column leaf order is used as-is
row_order = list(reversed(row_dendro["leaves"]))
col_order = col_dendro["leaves"]
# Make the heatmap
fig, ax = plt.subplots()
# Reorder rows and columns by position so they follow the dendrogram leaves
data = data.iloc[row_order, col_order]
# row_order holds positions, so index the labels positionally as well to keep
# them matched to the reordered rows of the data
labels = labels.iloc[row_order]
heatmap = ax.imshow(data, aspect="auto")
ax.set_yticks([])
# Show the feature names on the x ticks
ax.set_xticks(np.arange(data.columns.size))
ax.set_xticklabels(data.columns, rotation=90)
# Add heatmap colorbar to the right
divider = make_axes_locatable(ax)
cbar_ax = divider.append_axes("right", size="5%", pad="2%")
fig.colorbar(heatmap, cax=cbar_ax)
# Add bar indicating classes to the left
class_ax = divider.append_axes("left", size="5%")
# imshow needs a 2-D array, so turn the labels into a column vector. Index the
# underlying NumPy array rather than the Series: multi-dimensional indexing on
# a Series (labels[:, np.newaxis]) is not supported by recent pandas versions.
# Use a different color scheme for this axis
class_ax.imshow(labels.values[:, np.newaxis], aspect="auto", cmap="Set1")
class_ax.set_xticks([])
class_ax.set_yticks([])
# Add dendrogram for row clustering
row_ax = divider.append_axes("left", size="50%")
hierarchy.dendrogram(
    row_link,
    no_labels=True,
    orientation="left",
    ax=row_ax,
    color_threshold=0,
    above_threshold_color="black",
)
row_ax.axis("off")
# Add dendrogram for column clustering
col_ax = divider.append_axes("top", size="25%")
hierarchy.dendrogram(
    col_link,
    no_labels=True,
    ax=col_ax,
    color_threshold=0,
    above_threshold_color="black",
)
col_ax.axis("off")
fig.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment