# Subset of the million-song dataset
# Is a song made before or after 2002, based on its vocal features?
# Note: gbm is significantly handicapped here:
# - no cross-validation (cv) is used to choose the stopping iteration
# - the classification threshold is forced to 0.5.
# These restrictions mirror what Spark has to do.
# Download the data (only if it is not already on disk) and read it into `raw`.
f <- "/tmp/YearPredictionMSD.txt"
if (!file.exists(f)) {
  # NOTE(review): the URL was an empty string in the original script, which
  # makes download.file() fail. The address below is the UCI Machine Learning
  # Repository location of this archive -- confirm it is still current.
  data_url <- paste0(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/",
    "00203/YearPredictionMSD.txt.zip"
  )
  z <- paste0(f, ".zip")
  # mode = "wb" keeps the zip archive intact on platforms that would
  # otherwise download in text mode.
  download.file(data_url, z, quiet = TRUE, mode = "wb")
  unzip(z, exdir = "/tmp")
}
# Read outside the guard so `raw` is also defined when the file was already
# present. The file has no header row.
raw <- read.csv(f, header = FALSE)
# The data needs no cleaning (there are no NAs), so simply split it the way
# the website prescribes: the first `cutoff` rows form the training set and
# all remaining rows form the test set.
raw[[1]] <- as.numeric(raw[[1]])
cutoff <- 463715
train_rows <- seq_len(cutoff)
train <- raw[train_rows, ]
test <- raw[-train_rows, ]
raw <- NULL  # drop the reference to the full table
response_col <- names(train)[1]
# Turn the task into a binary problem: the label becomes 1 when the value in
# the response column is below 2002 and 0 otherwise. This splits the dataset
# roughly in half.
train[[response_col]] <- as.numeric(train[[response_col]] < 2002)
test[[response_col]] <- as.numeric(test[[response_col]] < 2002)
# train
# gbm() lives in the gbm package, which the script never attached; load it
# here so this section runs on its own.
library(gbm)
ntrees <- 700
shrinkage <- 0.001
# Model the response column on every other column of `train`.
predictors <- colnames(train)[-1]
gbm_formula <- as.formula(
  paste0(response_col, " ~ ", paste(predictors, collapse = " + "))
)
# Time the fit with proc.time() before and after.
duration <- proc.time()
gbm_model <- gbm(
  gbm_formula,
  data = train,  # named explicitly instead of relying on positional matching
  distribution = "bernoulli",
  n.trees = ntrees,
  bag.fraction = 0.75,
  interaction.depth = 3,
  n.cores = 4,
  shrinkage = shrinkage
)
duration <- proc.time() - duration
# total time
print(duration)
# predict
# prediction() (and performance(), used during evaluation) come from the ROCR
# package, which the script never attached; load it here.
library(ROCR)
# Score the test rows using every column except the first (the response).
# type = "response" requests predictions on the response (probability)
# scale rather than the link scale.
predictions_gbm <- predict(
  gbm_model,
  newdata = test[, -1],
  n.trees = ntrees,
  type = "response"
)
# Pair the scores with the true 0/1 labels for ROCR.
pred <- prediction(
  predictions_gbm,
  test[[response_col]],
  label.ordering = NULL
)
# evaluate

# Return the value of a ROCR performance measure at a score cutoff.
#
# pred:      a ROCR prediction object.
# measure:   name of the ROCR measure, e.g. "acc", "prec" or "rec".
# threshold: score cutoff at which the measure is read (default 0.5).
#
# The value is taken at the last position whose cutoff is still >= threshold.
# NOTE(review): this assumes ROCR lists cutoffs in decreasing order, so that
# position is the smallest cutoff not below the threshold -- confirm.
metric_at_threshold <- function(pred, measure, threshold = 0.5) {
  perf <- performance(pred, measure = measure)
  cutoffs <- perf@x.values[[1]]
  values <- perf@y.values[[1]]
  values[max(which(cutoffs >= threshold))]
}

# Out-of-bag performance estimate from gbm.perf(); the result is kept for
# inspection but is not used by the metrics below.
gbm_perf <- gbm.perf(gbm_model, method = "OOB")

# ROC curve: true positive rate against false positive rate.
plot(performance(pred, measure = "tpr", x.measure = "fpr"))

# Print every metric explicitly so the output also appears when the script is
# run through source(), where top-level values are not auto-printed.
print("auc")
print(performance(pred, measure = "auc")@y.values[[1]])
print("acc (thresh=0.5)")
print(metric_at_threshold(pred, "acc"))
print("precision (thresh=0.5)")
print(metric_at_threshold(pred, "prec"))
print("recall (thresh=0.5)")
print(metric_at_threshold(pred, "rec"))
Loading required package: survival
Loading required package: lattice
Loading required package: splines
Loading required package: parallel
Loaded gbm 2.1.1
Loading required package: gplots
Attaching package: ‘gplots’
The following object is masked from ‘package:stats’:
    lowess
Loading required package: methods
user system elapsed
5709.492 2.464 5715.402
Warning message:
In gbm.perf(gbm_model, method = "OOB") :
OOB generally underestimates the optimal number of iterations although predictive performance is reasonably competitive. Using cv.folds>0 when calling gbm usually results in improved predictive performance.
[1] "auc"
[1] 0.6951813
[1] "acc (thresh=0.5)"
[1] 0.6398993
[1] "precision (thresh=0.5)"
[1] 0.6806732
[1] "recall (thresh=0.5)"
[1] 0.4725602
