Skip to content

Instantly share code, notes, and snippets.

@jefferythewind
Last active February 9, 2024 16:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jefferythewind/4fb31496e2445c22fcdfc36b3b0feb04 to your computer and use it in GitHub Desktop.
Save jefferythewind/4fb31496e2445c22fcdfc36b3b0feb04 to your computer and use it in GitHub Desktop.
Multi-Class Softmax Log-Loss Custom Objective for LightGBM Classifier
import numpy as np
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
import time
# Load the Iris dataset and split it into the feature matrix (X) and the
# target vector (y).
iris = load_iris()
X, y = iris.data, iris.target

# Report the array shapes so the input dimensions are visible at a glance.
print(f"Shape of features (X): {X.shape}")
print(f"Shape of target (y): {y.shape}")
# Define softmax function
def softmax(x):
e_x = np.exp(x.T - np.max(x, axis=1)).T
return ( e_x.T / e_x.sum(axis=1) ).T
# Custom softmax cross-entropy loss
def softmax_cross_entropy(preds, train_data):
labels = train_data.get_label()
num_class = len(np.unique(labels))
labels = np.eye(num_class)[labels.astype(int)]
preds = softmax(preds)
grad = preds - labels
hess = preds*(1 - preds)
return grad, hess
# Baseline: train with LightGBM's built-in multi-class softmax objective.
lgb_train = lgb.Dataset(X, y)
params = {
    'max_depth': 5,
    'num_leaves': 16,
    'n_estimators': 50,  # LightGBM alias of num_iterations (boosting rounds)
    'objective': 'multiclass',
    'num_class': len(np.unique(y)),
    'verbosity': -1,
}
# perf_counter is monotonic and high-resolution. time.time() is wall-clock
# time and can jump if the system clock is adjusted during the measurement.
tic = time.perf_counter()
model = lgb.train(params, lgb_train)
toc = time.perf_counter()
print('Default')
# In-sample confusion matrix: the model is evaluated on its own training data.
print(confusion_matrix(y, np.argmax(model.predict(X), axis=1)))
print("Elapsed Time: ", (toc - tic))
# Same configuration, but with the custom softmax cross-entropy objective.
params = {
    'max_depth': 5,
    'num_leaves': 16,
    'n_estimators': 50,  # LightGBM alias of num_iterations (boosting rounds)
    'objective': softmax_cross_entropy,
    # A callable objective still needs num_class so LightGBM knows how many
    # score columns to produce per sample.
    'num_class': len(np.unique(y)),
    'verbosity': -1,
}
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
tic = time.perf_counter()
model = lgb.train(params, lgb_train)
toc = time.perf_counter()
print('Custom')
# NOTE(review): with a custom objective, predict() is documented to return
# raw scores rather than probabilities. argmax is unaffected because softmax
# preserves the ordering within each row.
print(confusion_matrix(y, np.argmax(model.predict(X), axis=1)))
print("Elapsed Time: ", (toc - tic))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment