Skip to content

Instantly share code, notes, and snippets.

@ryanorsinger
Created August 17, 2021 15:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanorsinger/e5b121b410edf17a8e6d9d1de3b28da7 to your computer and use it in GitHub Desktop.
Save ryanorsinger/e5b121b410edf17a8e6d9d1de3b28da7 to your computer and use it in GitHub Desktop.
Explore in a Box (Maggie's explore.py module)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from scipy import stats
def train_validate_test_split(df, target, seed=123):
    '''
    Split ``df`` into stratified train, validate and test dataframes.

    The split happens in two stages, both stratified on ``df[target]`` and
    both seeded with ``seed``:

    1. 20% of the rows are held out as the test set.
    2. The remaining 80% is divided 70/30 into train and validate, so
       train is .70 * .80 = 56% and validate is .30 * .80 = 24% of the
       original rows.

    Parameters
    ----------
    df : pandas.DataFrame
        The full dataset.
    target : str
        Name of the column used for stratification.
    seed : int, default 123
        ``random_state`` passed to both ``train_test_split`` calls.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train, validate, test)``, in that order.
    '''
    # Stage 1: carve off the test set.
    remainder, test = train_test_split(
        df,
        test_size=0.2,
        random_state=seed,
        stratify=df[target],
    )
    # Stage 2: split what is left into train and validate.
    train, validate = train_test_split(
        remainder,
        test_size=0.3,
        random_state=seed,
        stratify=remainder[target],
    )
    return train, validate, test
def explore_univariate(train, cat_vars, quant_vars):
    '''
    Run univariate exploration on each listed column of ``train``.

    Each categorical column goes through ``explore_univariate_categorical``,
    which shows a frequency bar chart and prints the frequency table. Each
    quantitative column goes through ``explore_univariate_quant``; its
    histogram/boxplot figure is shown and its ``describe()`` stats printed.

    Parameters
    ----------
    train : pandas.DataFrame
    cat_vars : iterable of str
        Categorical column names.
    quant_vars : iterable of str
        Quantitative column names.

    Returns
    -------
    None. Output is only displayed/printed.
    '''
    for var in cat_vars:
        explore_univariate_categorical(train, var)
        print('_________________________________________________________________')
    for col in quant_vars:
        _, descriptive_stats = explore_univariate_quant(train, col)
        # plt.show() does not take an artist/figure positionally (its only
        # parameter is the keyword ``block``); show the current figure.
        plt.show()
        print(descriptive_stats)
def explore_bivariate(train, target, cat_vars, quant_vars):
    '''
    Explore every feature against ``target``.

    All categorical columns are handled first, each by
    ``explore_bivariate_categorical``. Then every quantitative column is
    handled by ``explore_bivariate_quant``. All output is displayed or
    printed; nothing is returned.
    '''
    for cat_col in cat_vars:
        explore_bivariate_categorical(train, target, cat_col)
    for quant_col in quant_vars:
        explore_bivariate_quant(train, target, quant_col)
def explore_multivariate(train, target, cat_vars, quant_vars):
    '''
    Show multivariate plots of the features, colored by ``target``.

    Produces, in order:

    1. a swarm-plot grid: one figure per quantitative variable, with one
       subplot per categorical variable;
    2. the matching split-violin grid;
    3. a seaborn pairplot of the quantitative variables;
    4. a log-scaled boxen plot of all quantitative variables.

    Parameters
    ----------
    train : pandas.DataFrame
    target : str
        Column used for hue.
    cat_vars : list of str
        Categorical column names.
    quant_vars : list of str
        Quantitative column names.

    Returns
    -------
    None. Everything is displayed.
    '''
    plot_swarm_grid_with_color(train, target, cat_vars, quant_vars)
    plt.show()
    # The grid helper returns None, so there is nothing worth keeping.
    plot_violin_grid_with_color(train, target, cat_vars, quant_vars)
    plt.show()
    sns.pairplot(data=train, vars=quant_vars, hue=target)
    plt.show()
    plot_all_continuous_vars(train, target, quant_vars)
    plt.show()
### Univariate
def explore_univariate_categorical(train, cat_var):
    '''
    Display a frequency bar chart for ``cat_var`` and print its table.

    The table comes from ``freq_table``. The 2x2 bar chart plots the
    table's ``Count`` column against the class labels. Nothing is
    returned.
    '''
    table = freq_table(train, cat_var)

    plt.figure(figsize=(2, 2))
    sns.barplot(data=table, x=cat_var, y='Count', color='lightseagreen')
    plt.title(cat_var)
    plt.show()

    print(table)
def explore_univariate_quant(train, quant_var):
    '''
    Plot the distribution of ``quant_var`` and compute its summary stats.

    Draws a histogram (left) and a boxplot (right) of the column on a new
    8x2 figure. Missing values are dropped before plotting: ``describe()``
    already ignores them, and a boxplot cannot compute quartiles from data
    that contains NaN.

    Parameters
    ----------
    train : pandas.DataFrame
    quant_var : str
        Name of a numeric column.

    Returns
    -------
    tuple
        ``(p, descriptive_stats)``. ``p`` is the title artist from the last
        ``plt.title`` call, kept for backward compatibility.
        ``descriptive_stats`` is ``train[quant_var].describe()``. The
        figure is left open for the caller to show.
    '''
    descriptive_stats = train[quant_var].describe()
    values = train[quant_var].dropna()

    plt.figure(figsize=(8, 2))

    # first plot: histogram
    plt.subplot(1, 2, 1)
    plt.hist(values, color='lightseagreen')
    plt.title(quant_var)

    # second plot: box plot
    plt.subplot(1, 2, 2)
    plt.boxplot(values)
    p = plt.title(quant_var)

    return p, descriptive_stats
def freq_table(train, cat_var):
    '''
    Return the frequency count and percent split of a categorical column.

    Parameters
    ----------
    train : pandas.DataFrame
    cat_var : str
        Name of the categorical column.

    Returns
    -------
    pandas.DataFrame
        One row per observed class, most frequent first, with missing
        values excluded. Columns are ``cat_var`` (the class label),
        ``Count``, and ``Percent`` (rounded to 2 decimals). The index
        holds the same class labels.
    '''
    counts = train[cat_var].value_counts()
    percents = (
        train[cat_var].value_counts(normalize=True)
        .mul(100)
        .round(2)
        .reindex(counts.index)
    )
    # Derive every column from the value_counts index so each label stays
    # on the same row as its count. Pairing ``unique()`` (order of first
    # appearance, includes NaN) positionally with the count-sorted
    # value_counts mislabels rows and fails on a length mismatch with NaN.
    frequency_table = pd.DataFrame(
        {
            cat_var: list(counts.index),
            'Count': counts.to_numpy(),
            'Percent': percents.to_numpy(),
        },
        index=counts.index,
    )
    return frequency_table
#### Bivariate
def explore_bivariate_categorical(train, target, cat_var):
    '''
    Compare a categorical feature against a binary target.

    Runs a chi-square test of independence via ``run_chi2`` and draws
    ``plot_cat_by_target`` (the mean target per category, with a dashed
    line at the overall mean). It then prints the test summary, the
    observed crosstab with row/column totals, and the expected
    frequencies, and shows the plot. Nothing is returned.

    Parameters
    ----------
    train : pandas.DataFrame
    target : str
        Name of the binary target column.
    cat_var : str
        Name of the categorical feature column.
    '''
    print(cat_var, "\n_____________________\n")
    # margins=True adds the "All" totals row/column for display purposes.
    ct = pd.crosstab(train[cat_var], train[target], margins=True)
    chi2_summary, _observed, expected = run_chi2(train, cat_var, target)
    plot_cat_by_target(train, target, cat_var)
    print(chi2_summary)
    print("\nobserved:\n", ct)
    print("\nexpected:\n", expected)
    # plt.show() takes no positional artist argument; show the open figure.
    plt.show()
    print("\n_____________________\n")
def explore_bivariate_quant(train, target, quant_var):
    '''
    Compare a quantitative feature across the classes of a binary target.

    Prints ``describe()`` statistics of ``quant_var`` for each target
    class. Also prints a Mann-Whitney U test between the ``target == 0``
    and ``target == 1`` groups (see ``compare_means``); this is a
    rank-based test, not a comparison of means. Displays a boxen plot with
    a swarm plot overlaid on one 4x4 figure. Nothing is returned.

    Parameters
    ----------
    train : pandas.DataFrame
    target : str
        Name of the binary target column.
    quant_var : str
        Name of the quantitative feature column.
    '''
    print(quant_var, "\n____________________\n")
    descriptive_stats = train.groupby(target)[quant_var].describe()
    mann_whitney = compare_means(train, target, quant_var)
    plt.figure(figsize=(4, 4))
    # Both helpers draw on the current axes, so the swarm overlays the boxen.
    plot_boxen(train, target, quant_var)
    plot_swarm(train, target, quant_var)
    plt.show()
    print(descriptive_stats, "\n")
    print("\nMann-Whitney Test:\n", mann_whitney)
    print("\n____________________\n")
## Bivariate Categorical
def run_chi2(train, cat_var, target):
    '''
    Run a chi-square test of independence between ``cat_var`` and ``target``.

    Parameters
    ----------
    train : pandas.DataFrame
    cat_var : str
        Categorical feature column.
    target : str
        Target column.

    Returns
    -------
    chi2_summary : pandas.DataFrame
        One row with columns ``chi2``, ``p-value`` and
        ``degrees of freedom``.
    observed : pandas.DataFrame
        ``pd.crosstab`` of the two columns (no margins).
    expected : pandas.DataFrame
        Expected frequencies under independence, with the same index and
        columns as ``observed``.
    '''
    observed = pd.crosstab(train[cat_var], train[target])
    chi2, p, degf, expected = stats.chi2_contingency(observed)
    chi2_summary = pd.DataFrame({'chi2': [chi2], 'p-value': [p],
                                 'degrees of freedom': [degf]})
    # chi2_contingency returns a bare ndarray. Reattach the crosstab labels
    # so the expected table lines up with the observed one when printed.
    expected = pd.DataFrame(expected, index=observed.index,
                            columns=observed.columns)
    return chi2_summary, observed, expected
def plot_cat_by_target(train, target, cat_var):
    '''
    Bar-plot the mean of ``target`` for each level of ``cat_var``.

    Draws on a new 2x2 figure and adds a dashed gray horizontal line at
    the overall mean of ``target``. The figure is left open for the caller
    to show.

    Returns
    -------
    matplotlib.lines.Line2D
        The line created by ``plt.axhline``, which is the same object the
        original implementation returned.
    '''
    plt.figure(figsize=(2, 2))
    # x and y must be passed by keyword: seaborn >= 0.12 made them
    # keyword-only, so positional x/y raise a TypeError there.
    sns.barplot(x=cat_var, y=target, data=train, alpha=.8,
                color='lightseagreen')
    overall_rate = train[target].mean()
    return plt.axhline(overall_rate, ls='--', color='gray')
## Bivariate Quant
def plot_swarm(train, target, quant_var):
    '''
    Swarm-plot ``quant_var`` by ``target`` on the current axes.

    Titles the axes with ``quant_var`` and draws a dashed black horizontal
    line at the overall mean of ``quant_var``. Returns the ``Line2D`` from
    ``plt.axhline``.
    '''
    overall_mean = train[quant_var].mean()
    sns.swarmplot(data=train, x=target, y=quant_var, color='lightgray')
    plt.title(quant_var)
    mean_line = plt.axhline(overall_mean, ls='--', color='black')
    return mean_line
def plot_boxen(train, target, quant_var):
    '''
    Boxen-plot ``quant_var`` by ``target`` on the current axes.

    Titles the axes with ``quant_var`` and draws a dashed black horizontal
    line at the overall mean of ``quant_var``. Returns the ``Line2D`` from
    ``plt.axhline``.
    '''
    overall_mean = train[quant_var].mean()
    sns.boxenplot(data=train, x=target, y=quant_var, color='lightseagreen')
    plt.title(quant_var)
    mean_line = plt.axhline(overall_mean, ls='--', color='black')
    return mean_line
# alt_hyp: one of 'two-sided', 'less', 'greater'
def compare_means(train, target, quant_var, alt_hyp='two-sided'):
    '''
    Mann-Whitney U test of ``quant_var`` between the two target classes.

    Despite the name, this is a rank-based test that compares the values
    where ``target == 0`` against those where ``target == 1``. It is not a
    test of means. Rows whose target is neither 0 nor 1 are ignored.

    Parameters
    ----------
    train : pandas.DataFrame
    target : str
        Binary (0/1) target column.
    quant_var : str
        Quantitative column to compare.
    alt_hyp : {'two-sided', 'less', 'greater'}
        Passed as ``alternative`` to ``scipy.stats.mannwhitneyu``. The
        ``target == 0`` group is the first sample.

    Returns
    -------
    The result object of ``scipy.stats.mannwhitneyu`` (statistic, pvalue).
    '''
    values = train[quant_var]
    labels = train[target]
    group_zero = values[labels == 0]
    group_one = values[labels == 1]
    return stats.mannwhitneyu(group_zero, group_one,
                              use_continuity=True, alternative=alt_hyp)
### Multivariate
def plot_all_continuous_vars(train, target, quant_vars):
    '''
    Boxen-plot every quantitative variable side by side, split by target.

    The selected columns are melted to long form, with ``target`` kept as
    an id column and the remaining values placed in ``measurement`` /
    ``value`` columns. They are drawn on a log-scaled y axis with one hue
    per target class, and the figure is shown.

    Note: this also sets seaborn's global style to ``whitegrid`` with the
    ``muted`` palette.
    '''
    columns = [*quant_vars, target]
    sns.set(style="whitegrid", palette="muted")
    long_form = train[columns].melt(id_vars=target, var_name="measurement")

    plt.figure(figsize=(8, 6))
    ax = sns.boxenplot(x="measurement", y="value", hue=target, data=long_form)
    ax.set(yscale="log", xlabel='')
    plt.show()
def plot_violin_grid_with_color(train, target, cat_vars, quant_vars):
    '''
    For each quantitative variable, show a row of split violin plots.

    One figure is drawn and shown per entry of ``quant_vars``. Each figure
    has 1 row and ``len(cat_vars)`` columns, 16x4, with a shared y axis.
    Subplot ``i`` plots the quantitative variable against ``cat_vars[i]``,
    split and colored by ``target``. Does nothing when ``cat_vars`` is
    empty.

    Returns
    -------
    None.
    '''
    cat_vars = list(cat_vars)
    if not cat_vars:
        # plt.subplots(ncols=0) raises, and there is nothing to draw anyway.
        return
    for quant in quant_vars:
        # squeeze=False always returns a 2-D array of axes. With a single
        # categorical variable, plt.subplots would otherwise return a bare
        # Axes, which cannot be indexed.
        _, axes = plt.subplots(nrows=1, ncols=len(cat_vars), figsize=(16, 4),
                               sharey=True, squeeze=False)
        for ax, cat in zip(axes[0], cat_vars):
            sns.violinplot(x=cat, y=quant, data=train, split=True,
                           ax=ax, hue=target, palette="Set2")
            ax.set_xlabel('')
            ax.set_ylabel(quant)
            ax.set_title(cat)
        plt.show()
def plot_swarm_grid_with_color(train, target, cat_vars, quant_vars):
    '''
    For each quantitative variable, show a row of swarm plots.

    One figure is drawn and shown per entry of ``quant_vars``. Each figure
    has 1 row and ``len(cat_vars)`` columns, 16x4, with a shared y axis.
    Subplot ``i`` plots the quantitative variable against ``cat_vars[i]``,
    colored by ``target``. Does nothing when ``cat_vars`` is empty.

    Returns
    -------
    None.
    '''
    cat_vars = list(cat_vars)
    if not cat_vars:
        # plt.subplots(ncols=0) raises, and there is nothing to draw anyway.
        return
    for quant in quant_vars:
        # squeeze=False always returns a 2-D array of axes. With a single
        # categorical variable, plt.subplots would otherwise return a bare
        # Axes, which cannot be indexed.
        _, axes = plt.subplots(nrows=1, ncols=len(cat_vars), figsize=(16, 4),
                               sharey=True, squeeze=False)
        for ax, cat in zip(axes[0], cat_vars):
            sns.swarmplot(x=cat, y=quant, data=train, ax=ax, hue=target,
                          palette="Set2")
            ax.set_xlabel('')
            ax.set_ylabel(quant)
            ax.set_title(cat)
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment