@szilard
Last active August 1, 2023 00:59
GBM vs SV-DKL (Stochastic Variational Deep Kernel Learning) on the airline dataset
## Stochastic Variational Deep Kernel Learning
## paper: https://arxiv.org/abs/1611.00336
## code+data from the authors (thanks!!!): https://people.orie.cornell.edu/andrew/code/#SVDKL
## get the data and prepare the sample the authors used for evaluation
wget https://people.orie.cornell.edu/andrew/code/svdklcode.zip
unzip svdklcode.zip
cd caffe/examples/airline/data/
./get_airline.sh
python prep_airline.py
cd -
R:
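library(data.table)
d <- fread("caffe/examples/airline/data/2008.data.prep")
## binary target from column V8: "N" if V8 < 0, "Y" otherwise
d$y <- ifelse(d$V8 < 0, "N", "Y")
d$V8 <- NULL
## fix the seed so the shuffle (and thus the train/test split) is reproducible
set.seed(123)
d <- d[sample(nrow(d)),]
## first 100,000 shuffled rows -> test set, the remaining rows -> training set
write.csv(d[100001:nrow(d)], "train.csv", row.names=FALSE)
write.csv(d[1:100000], "test.csv", row.names=FALSE)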
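## The labeling and shuffle/split logic above can be tried on a toy data frame before running it on the
## full airline file. A minimal base-R sketch: `make_label` and `split_train_test` are hypothetical helper
## names, and the toy V8 values are made up for illustration (they are not airline data).

```r
## hypothetical helper: binary label from V8, same rule as above ("N" if V8 < 0, else "Y")
make_label <- function(v8) ifelse(v8 < 0, "N", "Y")

## hypothetical helper: shuffle rows with a fixed seed, then first n_test rows -> test, rest -> train
split_train_test <- function(df, n_test, seed = 123) {
  set.seed(seed)
  df <- df[sample(nrow(df)), , drop = FALSE]
  is_test <- seq_len(nrow(df)) <= n_test   # logical mask avoids the df[-integer(0), ] pitfall
  list(test = df[is_test, , drop = FALSE],
       train = df[!is_test, , drop = FALSE])
}

## toy data: 10 rows with made-up V8 values
toy <- data.frame(id = 1:10, V8 = c(-5, 3, 0, -1, 12, -8, 7, 2, -3, 4))
toy$y <- make_label(toy$V8)
toy$V8 <- NULL
parts <- split_train_test(toy, n_test = 3)
nrow(parts$test)   # 3
nrow(parts$train)  # 7
```

## Note that V8 == 0 maps to "Y", since the rule only labels strictly negative values as "N".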
## fit GBM and get accuracy (the evaluation metric used in the paper)
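## Accuracy is simply the fraction of test rows whose predicted label equals the true label. A minimal
## base-R sketch on made-up numbers (not model output); `accuracy`, `label_at` and `p_yes` are hypothetical
## names. It also shows that when labels come from thresholding a predicted probability, the accuracy
## depends on which threshold is used.

```r
## hypothetical helper: fraction of rows where the predicted label equals the true label
accuracy <- function(pred, actual) mean(pred == actual)

## hypothetical helper: turn predicted probabilities of "Y" into labels at a given threshold
label_at <- function(p, threshold) ifelse(p >= threshold, "Y", "N")

## made-up predicted probabilities of "Y" and true labels for 4 rows
p_yes  <- c(0.9, 0.2, 0.6, 0.4)
actual <- c("Y", "N", "N", "Y")

accuracy(label_at(p_yes, 0.5), actual)  # 0.5
accuracy(label_at(p_yes, 0.3), actual)  # 0.75
```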
R:
library(h2o)
h2o.init(max_mem_size="60g", nthreads=-1)
dx_train <- h2o.importFile("train.csv")
dx_test <- h2o.importFile("test.csv")
## hold out 2% of the training data as a validation set (used for early stopping)
dx_train_split <- h2o.splitFrame(dx_train, ratios = c(0.98), seed = 123)
system.time({
  ## all columns except the last are features; the last column (y) is the target
  md <- h2o.gbm(x = 1:(ncol(dx_train)-1), y = ncol(dx_train),
                training_frame = dx_train_split[[1]],
                ntrees = 100, max_depth = 20, learn_rate = 0.1, nbins = 100,
                validation_frame = dx_train_split[[2]],
                stopping_rounds = 5, stopping_metric = "AUC", stopping_tolerance = 1e-3,
                seed = 123)
})
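## With validation_frame and stopping_rounds set, H2O scores the model periodically on the validation set
## and stops adding trees once stopping_metric (AUC here) stops improving by more than stopping_tolerance.
## H2O's documented rule works on a simple moving average over stopping_rounds scoring events; the sketch
## below deliberately simplifies that to comparing window maxima, so it is an illustration of the idea and
## NOT H2O's exact algorithm. `should_stop` is a hypothetical name and the score sequences are made up.

```r
## simplified early-stopping check (illustration only, not H2O's implementation):
## stop when the best of the last k scores fails to beat the best earlier score
## by at least a relative tolerance tol (higher metric = better, e.g. AUC)
should_stop <- function(scores, k = 5, tol = 1e-3) {
  n <- length(scores)
  if (n <= k) return(FALSE)                    # not enough scoring history yet
  best_before <- max(scores[1:(n - k)])        # best score before the recent window
  best_recent <- max(scores[(n - k + 1):n])    # best score inside the recent window
  best_recent < best_before * (1 + tol)        # relative improvement below tolerance
}

should_stop(c(0.70, 0.72, 0.74, 0.76, 0.78, 0.80, 0.82))            # FALSE: still improving
should_stop(c(0.8000, 0.8001, 0.8002, 0.8001, 0.8000, 0.8002, 0.8001)) # TRUE: plateau
```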
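## test accuracy: fraction of rows where the predicted label (first column of the predictions) matches y
## (for binomial models the "predict" column uses H2O's max-F1 threshold, not a fixed 0.5)
sum(h2o.predict(md, dx_test[,1:(ncol(dx_train)-1)])[,1]==dx_test[,ncol(dx_train)])/nrow(dx_test)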
## GBM accuracy: 0.811 (no tuning, just the first settings that came to mind)
## SV-DKL accuracy (reported in the paper): 0.781
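## A rough sense of scale for the gap: a sketch that treats the 100,000 test rows as i.i.d. draws and uses
## the binomial standard error of the GBM accuracy. This is only indicative, since the paper's 0.781 comes
## from the authors' own evaluation and not necessarily from this exact random test sample.

```r
## accuracies from the lines above and the test-set size used in this gist
acc_gbm   <- 0.811
acc_svdkl <- 0.781
n_test    <- 100000

gap <- acc_gbm - acc_svdkl                     # 0.03, i.e. 3 percentage points
se  <- sqrt(acc_gbm * (1 - acc_gbm) / n_test)  # binomial standard error, about 0.00124
gap / se                                        # the gap in standard errors, roughly 24
```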