# Reference R implementation
# quite optimized, only 7-15x slower
# Build an FTRL-Proximal logistic-regression learner as a closure.
#
# The per-coordinate state vectors `z` and `n` live in the closure environment
# and are updated in place (via `<<-`) by `partial_fit`.
#
# Arguments:
#   alpha, beta  learning-rate parameters; the per-coordinate step divisor is
#                (beta + sqrt(n)) / alpha.
#   lambda1      soft-threshold level: coordinates with |z| <= lambda1 get
#                weight 0.
#   lambda2      constant added to the weight denominator.
#   nfeature     number of features, passed to init_ftrl_param() for `z`, `n`.
#   z, n         optional existing state to start from; handed to
#                init_ftrl_param(), which is defined elsewhere
#                (presumably it allocates a zero vector of length `nfeature`
#                when the argument is NULL -- TODO confirm).
#
# Returns a list with two functions:
#   predict(x, with_pb, check)     -> numeric vector, one probability per
#                                     column of `x`.
#   partial_fit(x, y, with_pb)     -> list(p, z, n): the predictions made
#                                     before each update and the state after
#                                     the pass.
# `x` is read through its `@p`, `@i`, `@x` slots with 0-based offsets (the
# column-compressed layout; presumably a Matrix::dgCMatrix -- confirm against
# callers). Each COLUMN of `x` is treated as one sample.
FTRL_R = function(alpha, beta, lambda1, lambda2, nfeature, z = NULL, n = NULL) {
  # The argument is called `nfeature`; the body used to reference an undefined
  # `n_features`, which failed unless a global of that name happened to exist.
  z = init_ftrl_param(z, nfeature)
  n = init_ftrl_param(n, nfeature)

  # Evaluate the lazy arguments now so the closures below capture fixed values.
  force(alpha)
  force(beta)
  force(lambda1)
  force(lambda2)

  sigmoid = function(x) {
    1 / (1 + exp(-x))
  }

  # Weights for the (1-based) feature indices `i`, computed from the current
  # `z` and `n`. Coordinates with |z| <= lambda1 keep weight 0.
  w_ftprl = function(i) {
    w = numeric(length(i))
    active = abs(z[i]) > lambda1
    j = i[active]
    z_j = z[j]
    w[active] = -(z_j - sign(z_j) * lambda1) / (lambda2 + (beta + sqrt(n[j])) / alpha)
    w
  }

  # Predicted probability for one sample with feature indices `j`, values `x`.
  predict_internal = function(j, x) {
    sigmoid(crossprod(x, w_ftprl(j))[[1L]])
  }

  # Non-zero entries of column `col` of `x`: 1-based feature indices `i` and
  # their values `x`. An empty column yields zero-length vectors.
  sample_entries = function(x, col) {
    first = x@p[[col]]
    last = x@p[[col + 1L]]
    # `@p` holds 0-based offsets, hence the + 1L to index R vectors.
    pos = if (first == last) integer(0) else seq.int(first, last - 1L, by = 1L) + 1L
    list(i = x@i[pos] + 1L, x = x@x[pos])
  }

  # One online pass over the columns of `x` with labels `y` (y[[col]] is the
  # label of column `col`). Updates the closure state `z` and `n`.
  partial_fit = function(x, y, with_pb = interactive()) {
    n_samples = ncol(x)
    p = numeric(n_samples)
    # txtProgressBar() rejects max = 0, so skip the bar for an empty input.
    show_pb = with_pb && n_samples > 0L
    if (show_pb) {
      pb = txtProgressBar(max = n_samples, style = 3)
      # Close the bar even if an error interrupts the loop.
      on.exit(close(pb), add = TRUE)
    }
    for (col in seq_len(n_samples)) {
      entry = sample_entries(x, col)
      i = entry$i
      xx = entry$x
      # The same weights serve the prediction and the z-update: the state does
      # not change in between, so compute them once.
      w = w_ftprl(i)
      p[[col]] = sigmoid(crossprod(xx, w)[[1L]])
      n_i = n[i]
      g = (p[[col]] - y[[col]]) * xx
      n_i_g2 = n_i + g * g
      s = (sqrt(n_i_g2) - sqrt(n_i)) / alpha
      z[i] <<- z[i] + g - s * w
      n[i] <<- n_i_g2
      if (show_pb) {
        setTxtProgressBar(pb, col)
      }
    }
    list(p = p, z = z, n = n)
  }

  # Predict one probability per column of `x` without touching the state.
  # `check` is unused; it is kept so existing calls keep working.
  predict = function(x, with_pb = interactive(), check = 10) {
    n_samples = ncol(x)
    p = numeric(n_samples)
    show_pb = with_pb && n_samples > 0L
    if (show_pb) {
      pb = txtProgressBar(max = n_samples, style = 3)
      on.exit(close(pb), add = TRUE)
    }
    for (col in seq_len(n_samples)) {
      entry = sample_entries(x, col)
      p[[col]] = predict_internal(entry$i, entry$x)
      if (show_pb) {
        setTxtProgressBar(pb, col)
      }
    }
    p
  }

  list(predict = predict, partial_fit = partial_fit)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment