Skip to content

Instantly share code, notes, and snippets.

@jtrive84
Created January 28, 2020 20:50
Show Gist options
  • Save jtrive84/954732b91bd697f19edc0599094ce647 to your computer and use it in GitHub Desktop.
Save jtrive84/954732b91bd697f19edc0599094ce647 to your computer and use it in GitHub Desktop.
Assess categorical association between nominal predictors.
import datetime
import itertools
import os
import os.path
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from matplotlib.backends.backend_pdf import PdfPages
# Module-wide Matplotlib styling, applied once at import time: thin black
# axes borders, small titles, light-grey axes on a white figure background.
mpl.rcParams.update({
    "axes.edgecolor": "black",
    "axes.linewidth": .5,
    "axes.titlesize": 7,
    "axes.facecolor": "#E8E8E8",
    "figure.facecolor": "#FFFFFF",
})
def variable_association(df, catvars, outdir, cmap="cividis"):
    """
    Generate correlation grid for categorical variables.

    Builds the pairwise association matrix with `_get_corr_grid`, draws it
    as an annotated heatmap, and writes three exhibits into `outdir`:
    ``Categorical_Correlation_Matrix.csv`` (the matrix itself) plus
    ``.pdf`` and ``.png`` renderings of the heatmap.

    Parameters
    ----------
    df
        Data passed to `_get_corr_grid`; indexed by the names in `catvars`.
    catvars
        Names of the categorical variables to compare pairwise.
    outdir
        Directory receiving the exports. Created if it does not exist.
    cmap
        Color map forwarded to the heatmap.

    Returns
    -------
    Path of the exported PDF file.
    """
    dfcorr = _get_corr_grid(df=df, catvars=catvars)
    # The grid is square, so rows and columns share the same labels.
    labels = dfcorr.columns.values.tolist()

    exhibit_name = "Categorical_Correlation_Matrix"
    csv_path = os.path.join(outdir, exhibit_name + ".csv")
    png_path = os.path.join(outdir, exhibit_name + ".png")
    pdf_path = os.path.join(outdir, exhibit_name + ".pdf")

    # An empty `outdir` means "current directory"; makedirs("") would raise.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    fig, ax = plt.subplots(1, 1, tight_layout=False)
    try:
        im = _heatmap(
            data=dfcorr, row_labels=labels, col_labels=list(labels),
            ax=ax, cmap=cmap
        )
        # Annotate correlation grid units.
        _annotate_heatmap(im=im, fontsize=7, valfmt="{x:.2f}")

        # Export the matrix to csv and the figure to pdf/png.
        dfcorr.to_csv(csv_path)
        fig.savefig(pdf_path, format="pdf")
        fig.savefig(png_path, format="png")
    finally:
        # Always release the figure, even when drawing or an export fails;
        # pyplot otherwise keeps it alive for the life of the process.
        plt.close(fig=fig)
    return pdf_path
def _cramersv(marr):
"""
Calculate Cramers V statistic for categorial-categorial association.
Uses correction from Bergsma and Wicher,
Journal of the Korean Statistical Society 42 (2013): 323-328.
"""
r, k = marr.shape
chi2 = stats.chi2_contingency(marr)[0]
n = marr.sum()
phi2 = chi2 / n
phi2corr = max(0, phi2 - ((k - 1) * (r - 1)) / (n - 1))
rcorr, kcorr = r - ((r - 1)**2) / (n - 1), k - ((k - 1)**2) / (n - 1)
return(np.sqrt(phi2corr / np.min([(kcorr - 1), (rcorr - 1)])))
def _get_corr_grid(df, catvars):
    """
    Generate correlation DataFrame for categorical variables.

    Parameters
    ----------
    df
        Data indexed by the names in `catvars`; each pair of columns is
        cross-tabulated with `pandas.crosstab`.
    catvars
        Names of the categorical variables. Any iterable is accepted.

    Returns
    -------
    A symmetric float DataFrame indexed and labelled by `catvars`, with
    1.0 on the diagonal and the `_cramersv` value of each variable pair
    off the diagonal.
    """
    # Materialise once so a one-shot iterable is not exhausted before the
    # frame is built.
    catvars = list(catvars)

    # Start from a float frame filled with 1.0: the diagonal keeps that
    # value, and every off-diagonal cell is overwritten in the loop below.
    # This avoids an object-dtype frame and a trailing fillna(1), which
    # would also report any NaN association as 1.0.
    dfcorr = pd.DataFrame(1.0, index=catvars, columns=catvars)

    for row_name, col_name in itertools.combinations(catvars, 2):
        dfxtab = pd.crosstab(df[row_name], df[col_name])
        var_pair_corr = _cramersv(dfxtab.values)
        # The statistic is symmetric; mirror it across the diagonal.
        dfcorr.at[col_name, row_name] = var_pair_corr
        dfcorr.at[row_name, col_name] = var_pair_corr
    return dfcorr
def _heatmap(data, row_labels, col_labels, ax, cmap, **kwargs):
    """
    Draw a heatmap on `ax` with tick labels along the top and left edges.

    Parameters
    ----------
    data
        2D array-like of shape (N, M); converted to a float32 numpy array.
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        Matplotlib axes object to draw on.
    cmap
        Color map.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    The image object returned by `ax.imshow`.
    """
    grid = np.asarray(data, dtype=np.float32)
    image = ax.imshow(grid, cmap=cmap, **kwargs)

    # One tick per column / row, labelled with the supplied names.
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels, color="#dc143c")
    ax.set_yticklabels(row_labels, color="#dc143c")

    # Ticks and labels on the top and left only.
    tick_options = dict(
        axis="both",
        left=True, top=True, right=False, bottom=False,
        labelleft=True, labeltop=True, labelright=False, labelbottom=False,
        labelsize=6, color="#000000",
    )
    ax.tick_params(**tick_options)

    # Rotate the column labels and centre them over their ticks.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(-90)
        tick_label.set_ha("center")
        tick_label.set_rotation_mode("default")
    return image
def _annotate_heatmap(im, valfmt="{x:.2f}", **kwargs):
    """
    Annotate each cell of a heatmap with its value.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`. Optional.
    **kwargs
        All other arguments are forwarded to each call to `text` used to create
        the text labels. A `color` entry is overridden per cell.

    Returns
    -------
    List of the text objects created, in row-major order.
    """
    # Index 0: cells at or below the threshold; index 1: cells above it.
    textcolors = ("white", "black")
    data = im.get_array()
    # Switch text color at half of the normalized maximum.
    thresh = im.norm(data.max()) / 2.

    kw = dict(ha="center", va="center")
    kw.update(kwargs)

    # Only wrap format strings; a ready-made Formatter is used as given,
    # as the docstring promises.
    if isinstance(valfmt, str):
        valfmt = mpl.ticker.StrMethodFormatter(valfmt)

    textlist = []
    n_rows, n_cols = data.shape[0], data.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            above = int(im.norm(data[i, j]) > thresh)
            kw.update(color=textcolors[above])
            # Image x is the column index and y is the row index.
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            textlist.append(text)
    return textlist
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment