remove redundant columns in pandas dataframe
import pandas as pd
import numpy as np


def find_correlation(data, threshold=0.9, remove_negative=False):
    """
    Given a numeric pd.DataFrame, find highly correlated features
    and return a list of features to remove.

    Parameters
    ----------
    data : pandas DataFrame
        Numeric DataFrame with one column per feature.
    threshold : float
        Correlation threshold. One feature from each pair with a
        correlation greater than this value is marked for removal.
    remove_negative : bool
        If True, features which are highly negatively correlated are
        also returned for removal.

    Returns
    -------
    select_flat : list
        List of column names to be removed.
    """
    corr_mat = data.corr()
    if remove_negative:
        corr_mat = np.abs(corr_mat)
    # Keep only the strict lower triangle so each pair is considered once.
    corr_mat.loc[:, :] = np.tril(corr_mat, k=-1)
    already_in = set()
    result = []
    for col in corr_mat:
        perfect_corr = corr_mat[col][corr_mat[col] > threshold].index.tolist()
        if perfect_corr and col not in already_in:
            already_in.update(set(perfect_corr))
            perfect_corr.append(col)
            result.append(perfect_corr)
    # Keep the first feature of each group and mark the rest for removal.
    select_nested = [f[1:] for f in result]
    select_flat = [i for j in select_nested for i in j]
    # A feature can fall into several groups; drop repeats, keeping order.
    return list(dict.fromkeys(select_flat))
ryancheunggit commented Jun 23, 2018

Hi, this is pretty handy. Maybe you should also consider removing the items that are perfectly negatively correlated.

Swarchal (Author) commented Jun 30, 2018

@ryancheunggit great point. I've added that feature, though I haven't tested it.
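A minimal sketch of why taking absolute values picks up negatively correlated features, assuming a hypothetical two-column frame where `z` is `x` reversed. The column names and the 0.9 threshold are illustrative only.

```python
import numpy as np
import pandas as pd

# Hypothetical data: "z" falls exactly as "x" rises.
df = pd.DataFrame({
    "x": [1.0, 2.0, 3.0, 4.0, 5.0],
    "z": [5.0, 4.0, 3.0, 2.0, 1.0],
})
threshold = 0.9

raw_corr = df.corr().loc["x", "z"]          # close to -1
abs_corr = np.abs(df.corr()).loc["x", "z"]  # close to +1

# A plain "greater than threshold" test misses the pair...
flagged_raw = bool(raw_corr > threshold)
# ...while the same test on absolute correlations catches it.
flagged_abs = bool(abs_corr > threshold)

print(flagged_raw, flagged_abs)
```

This is the effect `remove_negative=True` relies on: the comparison against `threshold` stays the same, only the matrix it is applied to changes.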

DipakDA commented Nov 17, 2018

Just a small thing I noticed: this function returns a list of attributes to remove, but the list contains duplicates (on the dataset I am working with). Why are there duplicates? Also, what is the best way to deal with them?
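One way duplicates can arise is when the same feature lands in more than one correlated group. The sketch below mirrors the gist's grouping loop on a hypothetical four-column frame (the helper name `grouped_removals` and the data are made up for illustration) and then de-duplicates the result while preserving order.

```python
import numpy as np
import pandas as pd

# Hypothetical data: "a" and "b" correlate at only 0.8 with each other,
# but both correlate at about 0.95 with "c" and with "d" (d = 2*c + 1).
df = pd.DataFrame({
    "a": [1.0, -1.0, 1.0, -1.0],
    "b": [1.4, -0.2, 0.2, -1.4],
    "c": [2.4, -1.2, 1.2, -2.4],
    "d": [5.8, -1.4, 3.4, -3.8],
})


def grouped_removals(data, threshold=0.9):
    """Mirror the gist's grouping loop and return the raw removal list."""
    corr = data.corr()
    # Strict lower triangle, rebuilt as a DataFrame to keep the labels.
    lower = pd.DataFrame(
        np.tril(corr.to_numpy(), k=-1), index=corr.index, columns=corr.columns
    )
    already_in, result = set(), []
    for col in lower:
        hits = lower[col][lower[col] > threshold].index.tolist()
        if hits and col not in already_in:
            already_in.update(hits)
            result.append(hits + [col])
    # The first member of each group is kept, the rest are marked.
    return [name for group in result for name in group[1:]]


raw = grouped_removals(df)
deduped = list(dict.fromkeys(raw))  # order-preserving de-duplication

print(raw)      # ['d', 'a', 'd', 'b']
print(deduped)  # ['d', 'a', 'b']
```

Here `a` and `b` each form a group containing `c` and `d`, so `d` is marked twice. If order does not matter, `set(raw)` would work as well.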

elvinaqa commented Sep 1, 2020

What is the flaw in this?


upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))

to_drop = [column for column in upper.columns if any(upper[column] > 0.95)]
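For reference, a self-contained sketch of the upper-triangle approach. The comment does not show how `corr_matrix` is built, so `df.corr()` is an assumption here, as are the wrapper name `upper_triangle_drops` and the sample data. The built-in `bool` is used for the mask dtype because the `np.bool` alias is unavailable in some NumPy releases.

```python
import numpy as np
import pandas as pd


def upper_triangle_drops(df, threshold=0.95):
    """Return columns correlated above `threshold` with an earlier column."""
    corr_matrix = df.corr()  # assumption: plain Pearson correlation
    # True strictly above the diagonal, False elsewhere.
    mask = np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
    # Everything outside the upper triangle becomes NaN.
    upper = corr_matrix.where(mask)
    return [column for column in upper.columns if (upper[column] > threshold).any()]


# Hypothetical data: "y" is a linear function of "x"; "w" is unrelated.
df = pd.DataFrame({
    "x": [1.0, 2.0, 3.0, 4.0, 5.0],
    "y": [3.0, 5.0, 7.0, 9.0, 11.0],
    "w": [2.0, -1.0, 0.0, -1.0, 2.0],
})

to_drop = upper_triangle_drops(df)
reduced = df.drop(columns=to_drop)

print(to_drop)
print(list(reduced.columns))
```

Because the list comprehension visits each column once, a column name can appear in `to_drop` at most once. Note that this variant does not flag strongly negative correlations unless the matrix is passed through `.abs()` first.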

ziqueiros commented Jan 30, 2022

I think this code has a severe bug. The sample from elvinaqa in the comments above works better.
