Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save klauszhang/91de6face2e354224ab096a6b0ce9168 to your computer and use it in GitHub Desktop.
Save klauszhang/91de6face2e354224ab096a6b0ce9168 to your computer and use it in GitHub Desktop.
Linear regression by gradient descent
##
## Linear regression by gradient descent
##
## A learning exercise to help build intuition about gradient descent.
## J. Christopher Bare, 2012
##
## Simulated data ----
# Fix the RNG so every run draws exactly the same sample.
set.seed(12345)
# 1000 predictor values drawn uniformly from [-5, 5]; the response follows
# the line y = x + 3 with standard-normal noise added.
x <- runif(n = 1000, min = -5, max = 5)
y <- x + rnorm(n = 1000) + 3

# Reference solution: ordinary least squares via lm(), used for comparison
# with the gradient-descent estimate computed below.
res <- lm(formula = y ~ x)

# Scatter plot of the sample with the least-squares line drawn on top.
plot(
  x, y,
  col = rgb(0.2, 0.4, 0.6, alpha = 0.4),
  main = "Linear regression"
)
abline(res, col = "blue")
# Squared-error cost, i.e. half the mean squared residual:
#   J(theta) = sum((X %*% theta - y)^2) / (2 * m),  with m = length(y).
cost <- function(X, y, theta) {
  m <- length(y)
  err <- X %*% theta - y
  sum(err^2) / (2 * m)
}
## Gradient-descent settings ----
# Step size (learning rate) and the fixed number of iterations to run.
alpha <- 0.01
num_iters <- 1000
# Per-iteration history: the cost and the coefficient vector after each
# update. vector("list", n) preallocates n empty slots; the previous
# list(num_iters) instead built a length-1 list holding the number 1000.
cost_history <- double(num_iters)
theta_history <- vector("list", num_iters)
# Start both coefficients (intercept, slope) at zero, stored as a 2 x 1
# column matrix so it conforms with the design-matrix product X %*% theta.
theta <- matrix(c(0, 0), nrow = 2)
# Prepend a column of 1's so the intercept is estimated together with the
# slope and every prediction is a single matrix product X %*% theta.
# X is an m x (n + 1) design matrix (m = length(x) observations, n = 1
# predictor, so m x 2); theta must therefore be an (n + 1) x 1 vector.
X <- cbind(1, matrix(x))
# Hypothesis function: the linear predictions for design matrix X under
# coefficients theta, computed as the matrix product X %*% theta.
h <- function(X, theta) X %*% theta
## Gradient descent ----
# Batch gradient descent on the squared-error cost. Each iteration applies
#   theta <- theta - alpha * t(X) %*% (X %*% theta - y) / m
# and records the resulting cost and coefficients. seq_len() yields an empty
# loop when num_iters is 0, whereas 1:num_iters would iterate over c(1, 0).
for (i in seq_len(num_iters)) {
  # Residuals of the current fit.
  error <- h(X, theta) - y
  # Gradient of the cost w.r.t. theta; crossprod(X, error) computes
  # t(X) %*% error without materialising the transpose.
  delta <- crossprod(X, error) / length(y)
  # Update all coefficients simultaneously.
  theta <- theta - alpha * delta
  cost_history[i] <- cost(X, y, theta)
  theta_history[[i]] <- theta
}
## Data with the fit at selected iterations ----
plot(
  x, y,
  col = rgb(0.2, 0.4, 0.6, 0.4),
  main = "Linear regression by gradient descent"
)
# Draw a few early iterations individually, then every 10th one. Indices
# beyond num_iters are dropped so this still works for short runs
# (seq(20, num_iters, by = 10) alone errors when num_iters < 20).
snapshot_iters <- c(1, 3, 6, 10, 14, seq(20, max(20, num_iters), by = 10))
snapshot_iters <- snapshot_iters[snapshot_iters <= num_iters]
for (i in snapshot_iters) {
  abline(coef = theta_history[[i]], col = rgb(0.8, 0, 0, 0.3))
}
# Final gradient-descent estimate.
abline(coef = theta, col = "blue")
## Cost trajectory ----
# Cost at every 100th iteration. Wrapped in print() so it is displayed even
# when the script is run with source(), which by default does not auto-print
# top-level values.
print(cost_history[seq(1, num_iters, by = 100)])
# Full cost curve over all iterations.
plot(
  cost_history,
  type = "l",
  col = "blue",
  lwd = 2,
  main = "Cost function",
  ylab = "cost",
  xlab = "Iterations"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment