Comparing LightGBM models fit with Harmonic CE loss (via GBNet) vs standard softmax CE loss
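For reference, the "harmonic" cross-entropy below (implemented in `DistLayer`, copied from https://github.com/KindXiaoming/grow-crystals) replaces the softmax over logits with inverse-distance weights. Reading the formula off that code: for an embedding $x$ and class weight rows $w_v$,

$$ p_v(x) \;=\; \frac{d_v(x)^{-n}}{\sum_u d_u(x)^{-n}}, \qquad d_v(x) \;=\; \frac{\lVert x - w_v \rVert^2 + \epsilon}{\min_u \left( \lVert x - w_u \rVert^2 + \epsilon \right)}, $$

with $n = 2$ by default; training then minimizes $-\log p_y(x)$, exactly as with standard softmax cross-entropy.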
import copy

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

import lightgbm as lgb
from lightgbm import record_evaluation
import torch
from torchvision import datasets

from gbnet.xgbmodule import XGBModule
from gbnet.lgbmodule import LGBModule  # we're using LGBModule but could use XGBModule
##############################################
## LOAD MNIST DATA
##############################################
# Download the raw MNIST training and test datasets (without any transform)
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)

# Convert the tensors to NumPy arrays
x_train = train_dataset.data.numpy().reshape([-1, 784])
y_train = train_dataset.targets.numpy()
x_test = test_dataset.data.numpy().reshape([-1, 784])
y_test = test_dataset.targets.numpy()
##############################################
## TRAIN a LightGBM model using the DistLayer loss
##############################################
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# Copied directly from https://github.com/KindXiaoming/grow-crystals
class DistLayer(torch.nn.Linear):
    def __init__(self, in_features, out_features, n=2., eps=1e-6, bias=False):
        super(DistLayer, self).__init__(in_features, out_features, bias=bias)
        self.n = n
        self.eps = eps

    def forward(self, x, scale=False):
        # x: (B, N)
        # w: (V, N)
        # dist_sq: (B, V)
        n_embd = x.size(-1)
        w = self.weight
        # w.data *= 0.
        wx = torch.einsum('bn,vn->bv', x, w)  # (B, V)
        ww = torch.norm(w, dim=-1)**2  # (V,)
        xx = torch.norm(x, dim=-1)**2  # (B,)

        dist_sq = ww[None, :] + xx[:, None] - 2 * wx + self.eps
        dist_sq = dist_sq / torch.min(dist_sq, dim=-1, keepdim=True)[0]
        prob = dist_sq**(-self.n)
        return prob / torch.sum(prob, dim=1, keepdim=True)
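
# Quick illustrative sanity check of DistLayer (hypothetical example, safe to
# delete): every output row is a valid probability vector, and the class whose
# weight row is nearest to the embedding gets the highest probability.
_check_layer = DistLayer(in_features=4, out_features=3)
_x = torch.randn(5, 4)
_p = _check_layer(_x)
assert torch.allclose(_p.sum(dim=1), torch.ones(5))
_dists = torch.cdist(_x, _check_layer.weight)  # (5, 3) embedding-to-class distances
assert torch.equal(_p.argmax(dim=1), _dists.argmin(dim=1))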
# Rough structure:
#   1. LightGBM produces an embedding
#   2. That embedding is plugged into DistLayer
#   3. After a basic transform, the whole thing is trained via -log(prob)
emb_dim = 20
xgb = LGBModule(60_000, 784, emb_dim)  # 60k training rows, 784 pixels, emb_dim outputs
distlayer = DistLayer(emb_dim, 10)

lr = 1e-2
n_steps = 100
optimizer = torch.optim.Adam(
    list(distlayer.parameters()) + list(xgb.parameters()),
    lr=lr
)

# The training loop below refers to the training data as `inputs`/`labels`
inputs = x_train
labels = torch.tensor(y_train, dtype=torch.long)

ls = []
err = []
test_err = []
for i in tqdm(range(n_steps)):
    xgb.train()
    optimizer.zero_grad()

    # LGBModule initially outputs all zeros, so the shrinking random addition
    # is a small hack to have the embeddings start randomly
    preds = (
        xgb(inputs)
        + (1 / (i + 1)) * (torch.rand([60_000, emb_dim]) * 2 - 1)  # this line is the hack
    )
    probs = distlayer(preds)
    nloglik = -torch.log(probs)
    loss = nloglik[range(labels.size(0)), labels].mean()
    loss.backward(create_graph=True)

    xgb.gb_step()     # update the GBM
    optimizer.step()  # update the DistLayer parameters

    with torch.no_grad():
        ls.append(loss.item())
        # misclassification rate on the training set
        err.append(
            (probs.detach().numpy().argmax(1) != y_train).mean()
        )
        # misclassification rate on the test set
        xgb.eval()
        preds = xgb(x_test)
        probs = distlayer(preds).detach().numpy()
        test_err.append(
            (probs.argmax(1) != y_test).mean()
        )
######################################
## Train LightGBM model the normal way
######################################
# Create LightGBM datasets
dtrain = lgb.Dataset(x_train, label=y_train)
dtest = lgb.Dataset(x_test, label=y_test, reference=dtrain)

# Define parameters
params = {
    'objective': 'multiclass',
    'num_class': 10,
    'metric': 'multi_error'
}
results = {}

# Train the model with evaluation. 200 boosting rounds (rather than 100) keeps
# the total tree count comparable: each GBNet step above grows emb_dim = 20
# trees (100 steps = 2,000 trees), while each multiclass round grows
# num_class = 10 trees (200 rounds = 2,000 trees).
bst = lgb.train(
    params=params,
    train_set=dtrain,
    num_boost_round=200,
    valid_sets=[dtrain, dtest],
    valid_names=['train', 'eval'],
    callbacks=[record_evaluation(results)]
)
lgbm_error = results['eval']['multi_error']
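
# Optional cross-check (illustrative; `final_probs` and `manual_test_error` are
# just local names for this check): for a multiclass objective, bst.predict
# returns per-class probabilities of shape (n_samples, num_class), so the final
# recorded multi_error value can be recomputed by hand.
final_probs = bst.predict(x_test)
manual_test_error = (final_probs.argmax(axis=1) != y_test).mean()
print(f'Manual final test error: {manual_test_error}, recorded: {lgbm_error[-1]}')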
######################################
## Compare
######################################
plt.plot(test_err, label='LGBM with Harmonic Cross-Entropy')
# Halving the round index puts both curves on the same tree-count scale:
# one x-axis unit = 20 trees for either model.
plt.plot(np.arange(len(lgbm_error)) / 2, lgbm_error, label='LGBM with Standard Cross-Entropy')
print(f'Min test error LightGBM with Harmonic Cross-Entropy: {min(test_err)}')
print(f'Min test error LightGBM with Standard Cross-Entropy: {min(lgbm_error)}')
plt.xlabel('Number of Trees Fit (in units of 20 trees)')
plt.ylabel('Test Error')
plt.legend()
plt.title('LightGBM trained with Harmonic vs Standard Cross-Entropy')
plt.show()