Last active
February 12, 2023 23:02
-
-
Save liangyy/489d1519dd45246caf4756d7722bfa25 to your computer and use it in GitHub Desktop.
Fast linear regression function in R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simple and fast linear regression solver.
#
# For each column x[, i], fits the model  y ~ x[, i] + covariate  and returns
# the coefficient of x[, i] together with its standard error and a two-sided
# p-value from the t distribution. Covariates are partialled out once via a
# QR decomposition (Frisch-Waugh-Lovell), so all k single-predictor
# regressions share one projection.
#
# NOTE: no intercept is added automatically. Without `covariate` the model is
# a regression through the origin; to include an intercept, pass a column of
# ones in `covariate` (e.g. covariate = cbind(1, z)).
#
# Arguments:
#   y:         response, numeric vector of length n (an n x 1 matrix is also
#              accepted).
#   x:         n x k matrix (or anything as.matrix() accepts); each column is
#              tested separately.
#   covariate: optional n x m matrix of covariates adjusted for in every
#              regression. Collinear (rank-deficient) covariates are allowed;
#              only the column space they actually span is projected out.
#
# Returns a list of length-k vectors (named by colnames(x), if any):
#   bhat: estimated coefficient of x[, i]
#   pval: two-sided p-value of the t statistic bhat / se
#   se:   standard error of bhat
fast_linear_regression <- function(y, x, covariate = NULL) {
  x <- as.matrix(x)
  # Flatten y so that `y_ * x_` recycles it down each column of x_. An n x 1
  # matrix y would otherwise be non-conformable with an n x k x_.
  y <- as.numeric(y)
  n <- length(y)
  stopifnot(nrow(x) == n)
  if (is.null(covariate)) {
    y_ <- y
    x_ <- x
    dof <- n - 1
  } else {
    covariate <- as.matrix(covariate)
    stopifnot(nrow(covariate) == n)
    qr_cov <- qr(covariate)
    # Only the first `rank` columns of Q span the covariate column space; any
    # remaining columns are arbitrary orthonormal directions and must not be
    # projected out (nor counted against the degrees of freedom).
    rank_cov <- qr_cov$rank
    q <- qr.Q(qr_cov)[, seq_len(rank_cov), drop = FALSE]
    # Residualize x and y on the covariates: v - Q (Q' v).
    x_ <- x - q %*% crossprod(q, x)
    y_ <- as.numeric(y - q %*% crossprod(q, y))
    dof <- n - rank_cov - 1
  }
  ss_x <- colSums(x_^2)
  bhat <- colSums(y_ * x_) / ss_x
  # sweep() scales column i of x_ by bhat[i]; plain `x_ * bhat` would recycle
  # bhat down the rows instead.
  resid <- y_ - sweep(x_, 2, bhat, "*")
  sigma2 <- colSums(resid^2) / dof
  se <- sqrt(sigma2 / ss_x)
  tval <- bhat / se
  pval <- 2 * pt(abs(tval), df = dof, lower.tail = FALSE)
  list(bhat = bhat, pval = pval, se = se)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For the revision on 2023/02/12, I changed `x_ * bhat` to `sweep(x_, 2, bhat, "*")` to get the intended by-column multiplication (column i of `x_` scaled by `bhat[i]`). This was not an issue in the previous R version (3.x), but it appears in R 4.0.x.