Skip to content

Instantly share code, notes, and snippets.

@jilmun
Last active March 9, 2016 21:02
Show Gist options
  • Save jilmun/5370110f3761e9b3e697 to your computer and use it in GitHub Desktop.
Save jilmun/5370110f3761e9b3e697 to your computer and use it in GitHub Desktop.
# create dummy data for testing
require(caret)
require(dplyr)
# Dummy data set: 500 rows with a binary target, a row ID, two character
# (categorical) columns v1 / v3 and one integer column v2.
# TRUE is spelled out because T is an ordinary variable that can be
# reassigned. The sample() calls stay in their original order so that a given
# RNG seed still yields the same data.
full <- data.frame(
  target = sample(c(0, 1), 500, replace = TRUE),
  ID = seq_len(500),
  v1 = sample(LETTERS[1:2], 500, replace = TRUE),
  v2 = sample(1:100, 500, replace = TRUE),
  v3 = sample(LETTERS[1:10], 500, replace = TRUE),
  stringsAsFactors = FALSE  # keep v1 / v3 as character, not factor
)
# The whole dummy data set doubles as the training table.
train <- full
# Partition the rows into 5 folds based on the target column. With
# returnTrain = FALSE each list element holds the row positions that are
# held out for that fold (they are removed with negative indexing below).
folds <- createFolds(
  y = full$target,
  k = 5,
  list = TRUE,
  returnTrain = FALSE
)
#--------------------------------------
## original code from happy
# Names of the character columns, i.e. the categorical variables.
# The names are picked with a logical mask instead of subsetting the data
# frame first: `full[, mask]` collapses to a bare vector when exactly one
# column matches, and names() of that vector is NULL. vapply() also keeps the
# mask a logical vector when `full` has no columns at all.
ohe.list <- names(full)[vapply(full, is.character, logical(1))]
# > ohe.list
# [1] "v1" "v3"

# For every fold: drop the held-out rows, then attach to each remaining row
# the mean target of its v1 level, computed on the remaining rows only.
list.x <- vector("list", length(folds))
for (i in seq_along(folds)) {
  lofo.ids <- folds[[i]]  # row positions held out for fold i
  x <- train[-lofo.ids, ] %>%
    group_by(v1) %>%
    mutate(v1_avg = mean(target, na.rm = TRUE)) %>%
    ungroup() %>%
    select(ID, v1_avg)
  list.x[[i]] <- x
}
# Stack the per-fold tables. An ID appears once for every fold in which it
# was not held out, hence more rows than `train` has.
list.x <- bind_rows(list.x)
# Source: local data frame [2,000 x 2]
#
# ID v1_avg
# (int) (dbl)
# 1 1 0.5242718
# 2 2 0.4639175
# 3 3 0.4639175
# 4 6 0.4639175
# 5 7 0.5242718
# 6 8 0.4639175
# 7 9 0.5242718
# 8 10 0.4639175
# 9 11 0.5242718
# 10 12 0.4639175
# .. ... ...
#--------------------------------------
## revised code
# Mean-encode one categorical column.
#
# Args:
#   mydat: data frame with a column named `target` (by convention the first
#          column) and the categorical variable in the second column. The
#          second column may have any name; it is renamed internally.
#
# Returns:
#   A one-column table `v_avg` with one row per row of `mydat`: the mean of
#   `target` (NAs ignored) over all rows sharing that row's category level.
#   mutate() keeps the input row order, so the result can be column-bound
#   back onto the rows it was computed from.
encode_cat <- function(mydat) { # col1=target, col2=categorical var
  stopifnot(is.data.frame(mydat), ncol(mydat) >= 2)
  names(mydat)[2] <- "myvar"
  # Fail early if `target` is missing: otherwise mean(target) inside mutate()
  # would silently pick up any object called `target` from the calling
  # environment instead of a column.
  stopifnot("target" %in% names(mydat))
  mydat %>%
    group_by(myvar) %>%
    mutate(v_avg = mean(target, na.rm = TRUE)) %>%
    ungroup() %>%
    select(v_avg)
}
# Names of the character columns to encode. Indexing names() with a logical
# mask stays correct when only one column matches, where `full[, mask]`
# would drop to a bare vector with NULL names.
ohe.list <- names(full)[vapply(full, is.character, logical(1))] # v1, v3

# For every fold: drop the held-out rows, then add one `<column>_avg` column
# per categorical variable, computed on the remaining rows only.
list.x <- vector("list", length(folds))
for (i in seq_along(folds)) {
  lofo.ids <- folds[[i]]  # row positions held out for fold i
  train.foldi <- train[-lofo.ids, ]
  x <- train.foldi %>% select(ID)
  # seq_along() runs zero times when there is nothing to encode, whereas
  # 1:length() would iterate over 1 and 0.
  for (j in seq_along(ohe.list)) {
    encoded <- encode_cat(train.foldi[c("target", ohe.list[j])])
    # Name the column before binding so the result does not depend on the
    # position at which bind_cols() places it.
    names(encoded) <- paste0(ohe.list[j], "_avg")
    x <- bind_cols(x, encoded)
  }
  list.x[[i]] <- x
}
# Stack the per-fold tables. An ID appears once for every fold in which it
# was not held out.
list.x <- bind_rows(list.x)
# Source: local data frame [2,000 x 3]
#
# ID v1_avg v3_avg
# (int) (dbl) (dbl)
# 1 1 0.5242718 0.5151515
# 2 2 0.4639175 0.4615385
# 3 3 0.4639175 0.5151515
# 4 6 0.4639175 0.5000000
# 5 7 0.5242718 0.4705882
# 6 8 0.4639175 0.4047619
# 7 9 0.5242718 0.4705882
# 8 10 0.4639175 0.5833333
# 9 11 0.5242718 0.4615385
# 10 12 0.4639175 0.5625000
# .. ... ... ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment