Skip to content

Instantly share code, notes, and snippets.

@jkeefe
Last active June 24, 2019 02:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jkeefe/487b242722d7700d4ffffb4bb62f5979 to your computer and use it in GitHub Desktop.
Save jkeefe/487b242722d7700d4ffffb4bb62f5979 to your computer and use it in GitHub Desktop.
Fast.ai feature importance function for neural nets
# Originally shared by Zachary Mueller here:
# https://forums.fast.ai/t/feature-importance-in-deep-learning/42026/16
# ... which he adapted from Miguel Mota Pinto's post here:
# https://medium.com/@mp.music93/neural-networks-feature-importance-with-fastai-5c393cf65815
# Assumes all necessary fast.ai v1.0 libraries are loaded (np, pd and torch in scope).


def _shuffle_column(x, group, col, generator=None):
    """Return a copy of batch inputs `x` with column `col` of `x[group]` permuted across rows.

    The tensor for `x[group]` is cloned first, so the tensors yielded by the data
    loader are never modified in place.
    """
    x = list(x)
    block = x[group].clone()
    perm = torch.randperm(block.size(0), generator=generator).to(block.device)
    block[:, col] = block[perm, col]
    x[group] = block
    return x


def _mean_valid_loss(learner, shuffle=None, generator=None):
    """Return the mean per-batch loss over `learner.data.valid_dl`.

    If `shuffle` is a `(group, col)` pair, that input column is permuted within
    each batch before predicting. Every batch has equal weight in the mean,
    regardless of its size.
    """
    losses = []
    for x, y in iter(learner.data.valid_dl):
        if shuffle is not None:
            x = _shuffle_column(x, shuffle[0], shuffle[1], generator)
        y = y.to('cpu')
        # NOTE(review): pred_batch may apply the model's final activation before the
        # loss is computed. This is the same for baseline and permuted runs — confirm
        # that is acceptable for your loss.
        losses.append(float(learner.loss_func(learner.pred_batch(batch=(x, y)), y)))
    return np.mean(losses)


def feature_importance(learner, seed=None):
    """Estimate permutation feature importance for a fast.ai v1 tabular learner.

    Each categorical and continuous input column is shuffled within every
    validation batch. The resulting rise in mean validation loss over the
    unshuffled baseline is that feature's importance.

    Args:
        learner: fast.ai v1 ``Learner``. ``learner.data.train_ds.x`` must expose
            ``cat_names`` and ``cont_names``. ``learner.data.valid_dl`` must yield
            ``(x, y)`` batches where ``x[0]`` holds the categorical columns and
            ``x[1]`` the continuous ones, in that name order.
        seed: optional int. When given, the shuffles are reproducible.

    Returns:
        pandas.DataFrame with columns ``cols`` (feature name) and ``imp``, sorted
        from most to least important. ``imp`` is ``sign(v) * log1p(|v|)`` of the
        loss increase ``v``: this equals ``log1p(v)`` for ``v >= 0`` and stays
        finite for negative importances, where plain ``log1p`` returns NaN below -1.
    """
    data = learner.data.train_ds.x
    # Index j in this list matches x[j] in each batch: 0 = categorical, 1 = continuous.
    groups = [data.cat_names, data.cont_names]
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    baseline = _mean_valid_loss(learner)
    fi = {}
    for group, names in enumerate(groups):
        for col, name in enumerate(names):
            shuffled = _mean_valid_loss(learner, shuffle=(group, col), generator=generator)
            fi[name] = shuffled - baseline

    ranked = sorted(fi.items(), key=lambda kv: kv[1], reverse=True)
    values = np.array([v for _, v in ranked], dtype=float)
    return pd.DataFrame({
        'cols': [name for name, _ in ranked],
        'imp': np.sign(values) * np.log1p(np.abs(values)),
    })
# Score every input feature of the trained learner `learn` (created elsewhere).
features = feature_importance(learn)
# Draw the scores as a horizontal bar chart, one bar per feature.
features.plot(
    x='cols',
    y='imp',
    kind='barh',
    figsize=(12, 15),
    legend=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment