@mthorrell
Created February 7, 2025 01:55
Comparing LightGBM models fit with harmonic CE loss (via GBNet) vs. standard softmax CE loss
import copy
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import lightgbm as lgb
from lightgbm import record_evaluation
import torch
from torchvision import datasets
from gbnet.xgbmodule import XGBModule
from gbnet.lgbmodule import LGBModule # we're using LGBModule but could use XGBModule
##############################################
## LOAD MNIST DATA
##############################################
# Download the raw MNIST training and test datasets (without any transform)
train_dataset = datasets.MNIST(root='./data', train=True)
test_dataset = datasets.MNIST(root='./data', train=False)
# Convert the tensors to NumPy arrays
x_train = train_dataset.data.numpy().reshape([-1, 784])
y_train = train_dataset.targets.numpy()
x_test = test_dataset.data.numpy().reshape([-1, 784])
y_test = test_dataset.targets.numpy()
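# Quick sanity check on the standard MNIST split (raw 0-255 pixel values are
# fine as-is for tree-based models, so no normalization is applied):
assert x_train.shape == (60_000, 784) and y_train.shape == (60_000,)
assert x_test.shape == (10_000, 784) and y_test.shape == (10_000,)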
##############################################
## TRAIN a LightGBM model using the DistLayer loss
##############################################
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# Copied directly from https://github.com/KindXiaoming/grow-crystals
class DistLayer(torch.nn.Linear):
    def __init__(self, in_features, out_features, n=2., eps=1e-6, bias=False):
        super(DistLayer, self).__init__(in_features, out_features, bias=bias)
        self.n = n
        self.eps = eps

    def forward(self, x, scale=False):
        # x: (B, N)
        # w: (V, N)
        # dist_sq: (B, V)
        n_embd = x.size(-1)
        w = self.weight
        # w.data *= 0.
        wx = torch.einsum('bn,vn->bv', x, w)  # (B, V)
        ww = torch.norm(w, dim=-1) ** 2       # (V,)
        xx = torch.norm(x, dim=-1) ** 2       # (B,)
        dist_sq = ww[None, :] + xx[:, None] - 2 * wx + self.eps
        dist_sq = dist_sq / torch.min(dist_sq, dim=-1, keepdim=True)[0]
        prob = dist_sq ** (-self.n)
        return prob / torch.sum(prob, dim=1, keepdim=True)
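# A minimal illustrative check of DistLayer (the names below are just for this
# sketch): each output row is a probability vector in which class v gets weight
# proportional to 1 / dist(x, w_v)^(2n), so classes whose weight vectors lie
# closest to the embedding receive the highest probability.
_demo_layer = DistLayer(4, 3)       # 4-dim embeddings, 3 classes
_demo_x = torch.randn(5, 4)         # batch of 5 random embeddings
_demo_probs = _demo_layer(_demo_x)  # shape (5, 3)
assert torch.allclose(_demo_probs.sum(dim=1), torch.ones(5))  # rows sum to 1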
# Rough structure:
# 1. LightGBM produces an embedding
# 2. That embedding is plugged into DistLayer
# 3. After a basic transform, the whole thing is trained via -log(prob)
emb_dim = 20
xgb = LGBModule(60_000, 784, emb_dim)  # 60k training rows, 784 pixel features, emb_dim outputs
distlayer = DistLayer(emb_dim, 10)

lr = 1e-2
n_steps = 100
optimizer = torch.optim.Adam(
    list(distlayer.parameters()) + list(xgb.parameters()),
    lr=lr,
)

# training inputs/labels for the GBNet loop
inputs = x_train
labels = torch.tensor(y_train, dtype=torch.long)

ls = []
err = []
test_err = []
for i in tqdm(range(n_steps)):
    xgb.train()
    optimizer.zero_grad()

    # LGBModule initially outputs all zeros, so the shrinking random
    # addition below is a small hack to have the embeddings start randomly
    preds = (
        xgb(inputs)
        + (1 / (i + 1)) * (torch.rand([60_000, emb_dim]) * 2 - 1)  # this line is the hack
    )
    probs = distlayer(preds)
    nloglik = -torch.log(probs)
    loss = nloglik[range(labels.size(0)), labels].mean()

    loss.backward(create_graph=True)
    xgb.gb_step()     # update the GBM
    optimizer.step()  # update the DistLayer parameters

    with torch.no_grad():
        # training error: fraction of misclassified training rows
        err.append(
            (probs.detach().numpy().argmax(1) != y_train).mean()
        )

        # test error at the current number of boosting steps
        xgb.eval()
        preds = xgb(x_test)
        probs = distlayer(preds).detach().numpy()
        test_err.append(
            (probs.argmax(1) != y_test).mean()
        )
######################################
## Train LIGHTGBM model the normal way
######################################
# Create LightGBM datasets
dtrain = lgb.Dataset(x_train, label=y_train)
dtest = lgb.Dataset(x_test, label=y_test, reference=dtrain)
# Define parameters
params = {
    'objective': 'multiclass',
    'num_class': 10,
    'metric': 'multi_error',
}
results = {}
# Train the model with evaluation
bst = lgb.train(
    params=params,
    train_set=dtrain,
    # Give it plenty of rounds: with emb_dim = 20 outputs above, each GBNet
    # step fits roughly twice as many trees as one 10-class boosting round,
    # so 200 rounds (rather than 100) is the comparable tree budget
    num_boost_round=200,
    valid_sets=[dtrain, dtest],
    valid_names=['train', 'eval'],
    callbacks=[record_evaluation(results)],
)
lgbm_error = results['eval']['multi_error']
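# Bookkeeping behind the comparison (illustrative arithmetic, assuming
# LGBModule grows one tree per output dimension per gb_step, as the
# num_boost_round comment above suggests):
#   GBNet run:    n_steps * emb_dim   = 100 * 20 = 2000 trees
#   LightGBM run: rounds  * num_class = 200 * 10 = 2000 trees
# One GBNet step therefore covers as many trees as two standard boosting
# rounds, which is why the standard curve is plotted at round/2 below.
gbnet_trees = n_steps * emb_dim          # 2000
lgbm_trees = 200 * params['num_class']   # 2000
assert gbnet_trees == lgbm_trees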
######################################
## Compare
######################################
plt.plot(test_err, label='LGBM with Harmonic Cross-Entropy')
# plot the standard run at round/2 so both curves share the same x scale (see tree-count note above)
plt.plot(np.arange(len(lgbm_error)) / 2, lgbm_error, label='LGBM with Standard Cross-Entropy')
print(f'Min test error LightGBM with Harmonic Cross-Entropy: {min(test_err)}')
print(f'Min test error LightGBM with Standard Cross-Entropy: {min(lgbm_error)}')
plt.xlabel('Boosting Steps (20 trees per step)')
plt.ylabel('Test Error')
plt.legend()
plt.title('LightGBM trained with Harmonic vs Standard Cross-Entropy')
plt.show()