Skip to content

Instantly share code, notes, and snippets.

@bbennett36
Last active February 17, 2021 17:28
Show Gist options
  • Save bbennett36/77f0ebca623b3c6751d41a658217c65e to your computer and use it in GitHub Desktop.
Save bbennett36/77f0ebca623b3c6751d41a658217c65e to your computer and use it in GitHub Desktop.
LightGBM model quickstart
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import lightgbm as lgb
import shap
from sklearn.model_selection import train_test_split
# Initialise SHAP's JavaScript support for notebook output.
shap.initjs()

# enter input file info here #
input_file = ''
dataset = pd.read_csv(input_file)

# Features are every column whose name does not contain 'target'; the label is
# the 'target' column itself. 33% of the rows are held out, with a fixed seed
# so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    dataset[[col for col in dataset if 'target' not in col]],
    dataset['target'],
    random_state=42,
    test_size=0.33,
)

# Wrap both splits for LightGBM; the hold-out set is tied to the training set
# through `reference`.
lgb_train = lgb.Dataset(data=X_train, label=y_train)
lgb_test = lgb.Dataset(data=X_test, label=y_test, reference=lgb_train)
# LightGBM training configuration for a binary objective.
params = dict(
    objective='binary',
    boost_from_average=False,
    learning_rate=0.05,
    num_leaves=31,
    lambda_l2=0.1,
    lambda_l1=0.1,
    bagging_fraction=0.8,
    bagging_freq=1,
    min_child_weight=2,
    min_split_gain=0.5,
    feature_fraction=0.25,
    metric=['binary_logloss'],
    max_bin=63,
)
# The metric listed above is replaced here: evaluation uses 'binary_error'.
params.update(metric='binary_error')
# Upper bound on boosting iterations; early stopping may end training sooner.
num_round = 1000

# The `early_stopping_rounds=` and `verbose_eval=` keyword arguments of
# lgb.train() were deprecated in LightGBM 3.3 and removed in 4.0, where passing
# them raises TypeError. The same behaviour is requested through callbacks
# (lgb.log_evaluation requires LightGBM >= 3.3):
#   * stop when the validation metric has not improved for 250 rounds,
#   * log the validation metric every 250 rounds.
lgb_model = lgb.train(
    params,
    lgb_train,
    num_boost_round=num_round,
    valid_sets=[lgb_test],
    callbacks=[
        lgb.early_stopping(stopping_rounds=250),
        lgb.log_evaluation(period=250),
    ],
)

# Plot the 20 most important features twice: once ranked by total gain and
# once ranked by the number of splits that use the feature.
lgb.plot_importance(lgb_model, importance_type='gain', max_num_features=20)
lgb.plot_importance(lgb_model, importance_type='split', max_num_features=20)
# Column names of the training features.
features = list(X_train)

# Coerce every training column to a numeric dtype and take the raw array.
X = X_train.apply(pd.to_numeric).values

# With pred_contrib=True LightGBM returns one contribution per feature plus a
# trailing column holding the expected value, so for this binary objective the
# result has n_features + 1 columns.
shap_values = lgb_model.predict(X, pred_contrib=True)

# summary_plot expects exactly one SHAP column per feature column, so the
# trailing expected-value column is dropped for plotting only. `shap_values`
# itself keeps that column because get_top_feats() strips it on its own.
shap.summary_plot(shap_values[:, :-1], X_train.values, X_train.columns, plot_type='violin')
def get_top_feats(shap_values, num_feats, feature_names=None):
    """Return the names of the features with the largest total |SHAP| value.

    Args:
        shap_values: 2-D array-like of per-sample contributions whose last
            column is the expected-value (bias) term, i.e. the layout
            returned by ``predict(..., pred_contrib=True)``.
        num_feats: Maximum number of feature names to return. Values below 1
            yield an empty list.
        feature_names: Names of the feature columns of ``shap_values``, in
            column order. Defaults to the columns of the module-level
            ``X_train``, from which ``shap_values`` is computed in this file.

    Returns:
        A list of feature names ordered from most to least important.

    Raises:
        ValueError: If the number of names differs from the number of
            feature columns in ``shap_values``.
    """
    if feature_names is None:
        # The previous body looked the names up in ``mds['X_test']``, a name
        # that is never defined, so every call raised NameError.
        feature_names = X_train.columns
    names = list(feature_names)

    # Total absolute contribution per column; the trailing bias column is not
    # a feature and is dropped before ranking.
    importance = np.abs(np.asarray(shap_values)).sum(axis=0)[:-1]
    if len(names) != importance.shape[0]:
        raise ValueError(
            'expected %d feature names, got %d' % (importance.shape[0], len(names))
        )

    # Clamp so that num_feats <= 0 selects nothing (a plain ``[-0:]`` slice
    # would select every feature instead).
    keep = max(int(num_feats), 0)
    ranked = np.argsort(importance)[::-1][:keep]
    return [names[idx] for idx in ranked]
# Collect up to 40 feature names via get_top_feats() and display them.
top_feats = get_top_feats(shap_values, num_feats=40)
print(top_feats)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment