Last active
May 30, 2019 17:53
-
-
Save spikar/cbaab00560761f9cef27d7cdfe9ddf28 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from scipy import stats | |
from sklearn import preprocessing | |
from sklearn.ensemble import ExtraTreesClassifier | |
def variable_selection(df, target, variance_thres, pbs_thres, chi_sqr_thres,
                       feat_imp_thres, random_state=None):
    """Select feature columns of ``df`` by applying three successive filters.

    1. Variance filter: the variance of every non-target column is min-max
       scaled to [0, 1]; columns whose scaled variance is not greater than
       ``variance_thres`` are dropped.
    2. Association filter (only when the target has exactly two distinct
       values): columns with more than 8 distinct values are treated as
       continuous and scored with the point-biserial correlation against the
       target; the remaining columns are treated as categorical and scored
       with the chi-square statistic of their contingency table with the
       target. Each group of scores is scaled to [0, 1] (absolute values)
       and compared with ``pbs_thres`` / ``chi_sqr_thres`` respectively.
    3. Importance filter: an ``ExtraTreesClassifier`` is fitted on the
       surviving columns and only those whose feature importance exceeds
       ``feat_imp_thres`` are kept.

    Because scores are min-max scaled within each group, the weakest column
    of a group always maps to 0 and is dropped for any non-negative
    threshold; a group with a single column also maps to 0.

    Args:
        df: DataFrame holding the feature columns and the target column.
        target: Label of the target column in ``df``.
        variance_thres: Threshold on the scaled variance (step 1).
        pbs_thres: Threshold on the scaled point-biserial coefficient.
        chi_sqr_thres: Threshold on the scaled chi-square statistic.
        feat_imp_thres: Threshold on the tree feature importance (step 3).
        random_state: Optional seed forwarded to ``ExtraTreesClassifier`` to
            make step 3 reproducible. Defaults to ``None`` (unseeded).

    Returns:
        A DataFrame with the selected feature columns followed by the target
        column. If a filter leaves at most one feature, a message is printed
        and the frame as filtered so far is returned early.
    """
    # Columns with at most this many distinct values are treated as categorical.
    max_categories = 8

    def normalize(scores):
        """Min-max scale the absolute values of a Series, keeping its index."""
        scores = scores.abs().astype(float)
        if scores.empty:
            return scores
        low, high = scores.min(), scores.max()
        if not high > low:
            # Zero range (or all NaN): map to 0, as MinMaxScaler does for a
            # constant feature, instead of dividing by zero.
            return scores * 0.0
        return (scores - low) / (high - low)

    def chi_square_stat(frame, col):
        """Return the chi-square statistic of ``col`` against the target."""
        # crosstab tabulates only the categories actually observed, so the
        # table never has an all-zero row/column (which chi2_contingency
        # rejects) and categories need not be consecutive integers.
        table = pd.crosstab(frame[col], frame[target])
        return stats.chi2_contingency(table.values)[0]

    # Step 1: keep columns whose scaled variance exceeds the threshold.
    scaled_var = normalize(df.var().drop(target))
    high_var_cols = list(scaled_var[scaled_var > variance_thres].index)
    df = df[high_var_cols + [target]]
    if len(df.columns) <= 2:
        print('One or less independent variables remaining after variance filter step')
        return df

    # Step 2: for a binary target, filter on association with the target.
    if df[target].nunique(dropna=False) == 2:
        cat_var = []
        cont_var = []
        for col in df.columns:
            if col == target:
                continue
            if df[col].nunique(dropna=False) <= max_categories:
                cat_var.append(col)
            else:
                cont_var.append(col)

        # pointbiserialr takes the dichotomous variable first; the
        # coefficient itself is symmetric in its arguments.
        pbs_scores = normalize(pd.Series(
            [stats.pointbiserialr(df[target], df[col])[0] for col in cont_var],
            index=cont_var, dtype=float))
        chi_scores = normalize(pd.Series(
            [chi_square_stat(df, col) for col in cat_var],
            index=cat_var, dtype=float))

        final_var = (list(pbs_scores[pbs_scores > pbs_thres].index)
                     + list(chi_scores[chi_scores > chi_sqr_thres].index))
        df = df[final_var + [target]]
        if len(df.columns) <= 2:
            print('One or less independent variables remaining after correlation filter step')
            return df

    # Step 3: filter on the feature importances of an extra-trees model.
    # Any estimator exposing ``feature_importances_`` (random forest,
    # decision tree, ...) could be substituted here.
    features = [col for col in df.columns if col != target]
    model = ExtraTreesClassifier(random_state=random_state)
    model.fit(df[features].values, df[target].values)
    importances = pd.Series(model.feature_importances_, index=features)
    imp_feat = list(importances[importances > feat_imp_thres].index)
    return df[imp_feat + [target]]
# Example call: returns a new DataFrame containing only the selected feature
# columns together with the target column.
# The variance and association scores are rescaled to [0, 1] before being
# compared with their thresholds, and tree feature importances also lie in
# [0, 1], so every threshold should be chosen between 0 and 1.
df_filtered = variable_selection(df, target, variance_thres, pbs_thres, chi_sqr_thres, feat_imp_thres)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment