Skip to content

Instantly share code, notes, and snippets.

@DexGroves
Last active February 29, 2016 19:21
Show Gist options
  • Save DexGroves/4e6569eb286171dfd3da to your computer and use it in GitHub Desktop.
Save DexGroves/4e6569eb286171dfd3da to your computer and use it in GitHub Desktop.
library("data.table")
library("xgboost")
library("Matrix")
# Simulate a toy binary-classification data set.
#
# N: number of rows to draw.
# Returns a data.table with columns, in this order:
#   response   - 0/1 label stored as double; 1 with probability 0.2
#   int1..int3 - whole numbers (stored as double) drawn as
#                round(Normal(mean = 3, sd = 3))
# Random draws happen in a fixed order (the uniforms first, then int1, int2,
# int3), so a given seed always reproduces the same table.
generate_data <- function(N) {
  draw_whole_number <- function() {
    round(rnorm(N, mean = 3, sd = 3))
  }
  is_positive <- runif(N) > 0.8
  data.table(
    response = as.numeric(is_positive),
    int1 = draw_whole_number(),
    int2 = draw_whole_number(),
    int3 = draw_whole_number()
  )
}
# Problem size and seed. This particular seed produces a tree whose single
# split lands on int3, which is what exposes the zero-handling anomaly.
N <- 1000
set.seed(1235) # This seed splits at a point that reproduces bug
# A formula literal is already a formula; no as.formula() wrapper needed.
mform <- response ~ int1 + int2 + int3
train <- generate_data(N)
lbl_train <- train[, response]
# sparse.model.matrix() adds an "(Intercept)" column (feature f0 in the
# xgboost dump, so int1..int3 appear as f1..f3). Being sparse, it does not
# store entries whose value is 0.
# NOTE(review): xgboost appears to treat entries absent from a sparse input
# as *missing* rather than as 0. If so, rows with int3 == 0 follow the
# split's default (missing-value) branch instead of being compared with the
# threshold. Confirm against the xgb.DMatrix docs; a dense model.matrix()
# input would keep the zeros explicit.
smm_train <- sparse.model.matrix(mform, train)
dtrain <- xgb.DMatrix(data = smm_train, label = lbl_train)
# Fit one tree of depth 1 (a single split) so that the learned split is
# easy to read back from the dump.
# Booster settings, including nthread, go in `params`; that placement is
# accepted by both old and current xgboost releases.
# save_period / save_name are deliberately omitted. In current xgboost,
# save_period = 0 means "save the model at the end" and would write
# "xgboost.model" to disk. The default never saves, in old and new versions.
# print.every.n is omitted too: it is the deprecated spelling, and 1 is the
# default anyway.
model <- xgb.train(params = list(eta = 1,
                                 max_depth = 1,
                                 min_child_weight = 10,
                                 subsample = 1.0,
                                 objective = "binary:logistic",
                                 eval_metric = "logloss",
                                 nthread = 1),
                   data = dtrain,
                   nrounds = 1,
                   verbose = 1)
xgb.dump(model = model)
# -> Split f3 (int3) at 3.5
# Store each training row's predicted probability in a new `pred` column (by
# reference). Then average it within each distinct int3 value. Rows are
# sorted by int3 before grouping, so groups come out in ascending order.
set(train, j = "pred", value = predict(model, newdata = dtrain))
train[order(int3), mean(pred), by = int3]
# Recorded output. Only two distinct values appear: the two leaves of the
# single int3 split at 3.5.
# int3 V1
# 1: -6 0.2593045
# 2: -5 0.2593045
# 3: -4 0.2593045
# 4: -3 0.2593045
# 5: -2 0.2593045
# 6: -1 0.2593045
# 7: 0 0.2197411 <- anomaly: 0 is below 3.5 but gets the int3 >= 4 value
# 8: 1 0.2593045
# 9: 2 0.2593045
# 10: 3 0.2593045
# 11: 4 0.2197411
# 12: 5 0.2197411
# 13: 6 0.2197411
# 14: 7 0.2197411
# 15: 8 0.2197411
# 16: 9 0.2197411
# 17: 10 0.2197411
# 18: 11 0.2197411
# 19: 12 0.2197411
# 20: 15 0.2197411
# NOTE(review): likely cause is that zeros are not stored in the sparse
# training matrix, so xgboost treats them as missing values; confirm.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment