
@Ailuropoda1864
Created September 8, 2017 06:15
A wrapper function for seaborn.heatmap that plots the lower triangle of a DataFrame's correlation matrix.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
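def corr_heatmap(dataframe, cmap=sns.diverging_palette(220, 10, as_cmap=True),
                 **kwargs):
    """
    Plot the lower triangle of a DataFrame's correlation matrix as a heatmap.

    :param dataframe: a pandas DataFrame
    :param cmap: color map to feed into sns.heatmap
    :param kwargs: extra kwargs to feed into sns.heatmap; they override the
                   defaults set below (e.g. pass vmax=1.0)
    :return: fig, ax
    """
    # Pairwise correlations; numeric_only=True (pandas >= 1.5) skips
    # non-numeric columns, which raise an error by default on pandas >= 2.0
    corr = dataframe.corr(numeric_only=True)

    # Generate a mask for the upper triangle, diagonal included.
    # The builtin bool replaces np.bool, which was removed in NumPy 1.24.
    mask = np.zeros_like(corr, dtype=bool)
    # np.triu_indices_from returns the indices for the upper triangle of mask
    mask[np.triu_indices_from(mask)] = True

    # Set up the matplotlib figure
    fig, ax = plt.subplots(figsize=(11, 9))

    # Draw the heatmap with the mask and correct aspect ratio.
    # vmax=0.3 caps the colour scale at 0.3; override it through kwargs.
    masked_heatmap = partial(sns.heatmap, mask=mask, cmap=cmap, vmax=0.3,
                             center=0, linewidths=0.5, square=True,
                             cbar_kws={"shrink": 0.5}, ax=ax)
    ax = masked_heatmap(corr, **kwargs)
    return fig, ax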
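A correlation matrix is symmetric, so the wrapper hides the upper triangle and the diagonal (always 1) and keeps only the cells strictly below it. The sketch below checks that masking logic with NumPy alone, so it runs without pandas or seaborn. The random data and the `np.corrcoef` call are assumptions standing in for `dataframe.corr()`, and `np.triu(np.ones_like(...))` is shown as an equivalent alternative, not as what the gist uses.

```python
import numpy as np

# Hypothetical data: 200 samples of 4 variables, fixed seed for reproducibility
rng = np.random.default_rng(0)
data = rng.normal(size=(200, 4))
corr = np.corrcoef(data, rowvar=False)  # 4 x 4 correlation matrix

# Mask built the same way as in corr_heatmap: True = hidden by seaborn
mask = np.zeros_like(corr, dtype=bool)
mask[np.triu_indices_from(mask)] = True

# Equivalent one-liner: np.triu keeps the upper triangle, diagonal included
mask_alt = np.triu(np.ones_like(corr, dtype=bool))

# Cells seaborn would still draw (mask == False): strictly below the diagonal
visible = corr[~mask]

print(mask.astype(int))
print(np.array_equal(mask, mask_alt))
print(visible.size)
```

For n variables, n(n - 1)/2 cells remain visible, which is 6 here, and no information is lost because the hidden cells mirror them.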
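The wrapper freezes its defaults with `functools.partial` and then calls the partial object with `**kwargs`. Keyword arguments supplied at call time extend and override the ones stored in the partial, so a caller can replace a default such as `vmax` or add new options such as `annot`. The sketch below demonstrates this with `fake_heatmap`, a hypothetical stand-in for `sns.heatmap` that simply returns the keyword arguments it received.

```python
from functools import partial


def fake_heatmap(data, **kwargs):
    """Hypothetical stand-in for sns.heatmap: return the kwargs it was given."""
    return kwargs


# Freeze defaults, as corr_heatmap does with sns.heatmap
masked = partial(fake_heatmap, vmax=0.3, center=0, square=True)

defaults = masked("corr")                          # frozen defaults only
overridden = masked("corr", vmax=1.0, annot=True)  # vmax replaced, annot added

print(defaults)
print(overridden)
```

In this design a caller never has to touch the wrapper's body to change a styling choice; the trade-off is that a misspelled keyword is passed straight through to `sns.heatmap` rather than caught by the wrapper.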