# Copied from https://blog.amedama.jp/entry/lightgbm-custom-metric (written in Japanese)
import lightgbm as lgb
from lightgbm import callback
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split


def accuracy(preds, data):
    # Custom metric (feval): multiclass accuracy.
    y_true = data.get_label()
    N_LABELS = 3  # number of classes in the iris dataset
    # `preds` arrives flat; the reshape assumes the class-major layout illustrated
    # in _demo_flat_preds_layout below.
    reshaped_preds = preds.reshape(N_LABELS, len(preds) // N_LABELS)
    y_pred = np.argmax(reshaped_preds, axis=0)
    acc = np.mean(y_true == y_pred)
    # LightGBM expects feval to return (name, value, is_higher_better).
    return 'accuracy', acc, True
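

# Illustration, not part of the original gist: a tiny self-contained check of the flat,
# class-major prediction layout that `accuracy` assumes (all class-0 scores first, then
# class-1, then class-2). The scores below are made up for the demo.
def _demo_flat_preds_layout():
    preds = np.array([0.8, 0.1,   # class 0 scores for samples 0 and 1
                      0.1, 0.2,   # class 1 scores
                      0.1, 0.7])  # class 2 scores
    y_pred = preds.reshape(3, len(preds) // 3).argmax(axis=0)
    assert (y_pred == np.array([0, 2])).all()  # sample 0 -> class 0, sample 1 -> class 2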


def log_evaluation(period=1, show_stdv=True):
    # Custom callback that prints every entry in evaluation_result_list
    # once every `period` iterations.
    def _callback(env):
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            # `_format_eval_result` is a private LightGBM helper that renders one
            # (dataset, metric, value, ...) tuple as text.
            result = '\t'.join(
                [callback._format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            print(result)
    _callback.order = 10  # callbacks are executed in ascending `order`
    return _callback


def main():
    # Three-class classification on the iris dataset.
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        shuffle=True,
                                                        random_state=42)
    lgb_train = lgb.Dataset(X_train, y_train)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
    lgbm_params = {
        'objective': 'multiclass',
        'num_class': 3,
    }
    # Per-iteration metric values are recorded here by lgb.train.
    evals_result = {}
    lgb.train(lgbm_params,
              lgb_train,
              valid_sets=[lgb_eval, lgb_train],
              valid_names=['eval', 'train'],
              num_boost_round=1000,
              evals_result=evals_result,
              feval=accuracy,
              callbacks=[log_evaluation()])
    # Metric histories for the built-in logloss and the custom accuracy metric.
    eval_metric_logloss = evals_result['eval']['multi_logloss']
    train_metric_logloss = evals_result['train']['multi_logloss']
    eval_metric_acc = evals_result['eval']['accuracy']
    train_metric_acc = evals_result['train']['accuracy']
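

# Optional follow-up, not part of the original gist: a minimal sketch of one way to use the
# recorded histories, assuming the evals_result layout filled in by lgb.train above
# (valid_name -> metric_name -> one value per boosting round). The name
# `report_final_metrics` is hypothetical.
def report_final_metrics(evals_result):
    for valid_name, metrics in evals_result.items():
        for metric_name, values in metrics.items():
            print('%s %s: %.4f (%d rounds)' % (valid_name, metric_name, values[-1], len(values)))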


if __name__ == '__main__':
    main()