Skip to content

Instantly share code, notes, and snippets.

@ytsaig
Created February 21, 2017 12:54
Show Gist options
  • Save ytsaig/f9a7963e62b3451004a8c1b1a3297950 to your computer and use it in GitHub Desktop.
Save ytsaig/f9a7963e62b3451004a8c1b1a3297950 to your computer and use it in GitHub Desktop.
Reproducible example for custom subsampling with LightGBM
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
def custom_subsample(n, frac, rng=None):
    """Draw ``int(n * frac)`` distinct indices from ``range(n)``.

    Args:
        n: Size of the population; indices are drawn from ``0 .. n - 1``.
        frac: Fraction of the population to keep. The sample size is
            ``int(n * frac)``, i.e. truncated towards zero.
        rng: Optional ``np.random.RandomState`` or ``np.random.Generator``
            to draw from. When omitted, NumPy's global random state is
            used, so results depend on ``np.random.seed``.

    Returns:
        A 1-D array of unique indices, in random (unsorted) order.

    Raises:
        ValueError: Propagated from NumPy when the requested sample size
            is negative or larger than ``n``.
    """
    size = int(n * frac)
    # Fall back to the module-level function so existing callers that rely
    # on ``np.random.seed`` keep getting exactly the same draws.
    choice = np.random.choice if rng is None else rng.choice
    return choice(n, size, replace=False)
# Load the tab-separated, header-less regression train/test files.
print('Load data...')
data_dir = '..'
train_path = '{}/regression/regression.train'.format(data_dir)
test_path = '{}/regression/regression.test'.format(data_dir)
df_train = pd.read_csv(train_path, header=None, sep='\t')
df_test = pd.read_csv(test_path, header=None, sep='\t')

# Column 0 is used as the label; every other column is a feature.
y_train, y_test = df_train[0], df_test[0]
X_train = df_train.drop(columns=[0])
X_test = df_test.drop(columns=[0])
"""
Built-in subsampling
"""
# Let LightGBM do the row subsampling itself (subsample=0.25, subsample_freq=1).
params = {
    "boosting_type": "gbdt",
    "num_leaves": 31,
    "max_depth": -1,
    "learning_rate": 0.01,
    "max_bin": 255,
    "objective": "regression",
    "subsample": 0.25,
    "subsample_freq": 1,
}

# Build the training dataset and a Booster bound to it.
train_set = lgb.Dataset(X_train, label=y_train, params=params)
booster = lgb.Booster(params=params, train_set=train_set)

# Train by calling update() once per round on the full training set.
n_rounds = 100
for _ in range(n_rounds):
    booster.update()

# Evaluate on the held-out test set.
y_pred_builtin = booster.predict(X_test)
rmse_builtin = mean_squared_error(y_test, y_pred_builtin) ** 0.5
print('The rmse of built-in prediction is:', rmse_builtin)
"""
Custom subsampling
"""
# Same model settings, but with LightGBM's own subsampling disabled
# (subsample=1.0, subsample_freq=0); rows are picked by custom_subsample.
params = {
    "boosting_type": "gbdt",
    "num_leaves": 31,
    "max_depth": -1,
    "learning_rate": 0.01,
    "max_bin": 255,
    "objective": "regression",
    "subsample": 1.0,
    "subsample_freq": 0,
}

# Build the training dataset and a Booster bound to it.
train_set = lgb.Dataset(X_train, label=y_train, params=params)
booster = lgb.Booster(params=params, train_set=train_set)

# Train: each round draws a new 25% row sample and updates on that subset.
# NOTE(review): every subset is kept in a list -- presumably so the objects
# stay alive while the booster references them; confirm before removing.
subsets = []
n_rounds = 100
for _ in range(n_rounds):
    row_idx = custom_subsample(X_train.shape[0], frac=0.25)
    bag = train_set.subset(row_idx)
    subsets.append(bag)
    booster.update(bag)

# Evaluate on the held-out test set, then compare both prediction vectors.
y_pred_custom = booster.predict(X_test)
rmse_custom = mean_squared_error(y_test, y_pred_custom) ** 0.5
print('The rmse of custom prediction is:', rmse_custom)
rmse_between = mean_squared_error(y_pred_builtin, y_pred_custom) ** 0.5
print('The rmse between the two predictions is:', rmse_between)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment