A (very) simple multivariate linear regression with Julia / Flux
Gist ornithos/6e5e5da1c2368c45f09777976cf647f0, last active December 12, 2021 15:24
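For reference, here are the model and loss that the script below fits, together with the gradient-descent update it performs. The notation is introduced here and is not from the original: X is the 5×N input matrix, Y is the 2×N target matrix, W is 2×5, b is 2×1, N = 100 and η = 10⁻². The factor 1/(2N) appears because `mean` averages over all 2N entries of the 2×N residual matrix.

```latex
\hat{Y} = W X + b\,\mathbf{1}_N^{\top}, \qquad
\mathcal{L}(W, b) = \frac{1}{2N} \sum_{i=1}^{2} \sum_{n=1}^{N} \bigl(Y_{in} - \hat{Y}_{in}\bigr)^2

\frac{\partial \mathcal{L}}{\partial W} = -\frac{1}{N}\,(Y - \hat{Y})\,X^{\top}, \qquad
\frac{\partial \mathcal{L}}{\partial b} = -\frac{1}{N}\,(Y - \hat{Y})\,\mathbf{1}_N

W \leftarrow W - \eta\,\frac{\partial \mathcal{L}}{\partial W}, \qquad
b \leftarrow b - \eta\,\frac{\partial \mathcal{L}}{\partial b}
```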
using Statistics
using Formatting
using Flux.Tracker
using Flux.Tracker: update!

# Define simple linear model: 5 inputs, 2 outputs.
# `param` wraps the arrays so that Tracker records gradients through them.
W = param(rand(2, 5))
b = param(rand(2, 1))

# Forward pass: predictions ŷ and the mean-squared-error loss.
function fwd(x, y, W, b)
    ŷ = W*x .+ b
    loss = mean((y .- ŷ).^2)
    return loss
end

# Create dummy data: 100 samples from a noisy linear model.
Wtrue, btrue = randn(2, 5), randn(2, 1)
x = randn(5, 100)
y = Wtrue * x .+ btrue + randn(2, 100);

# Perform simple gradient descent.
η = 1e-2
for i = 1:10000
    loss = fwd(x, y, W, b)
    Tracker.back!(loss)                 # accumulate ∂loss/∂W and ∂loss/∂b
    update!(W, -η*Tracker.grad(W))      # add the step to W and reset its gradient
    update!(b, -η*Tracker.grad(b))
end

# Compare. Tracker.data unwraps the tracked loss so sprintf1 receives a plain number.
println("Loss of true parameters: " * sprintf1("%.2f", fwd(x, y, Wtrue, btrue)))
println("Training loss of optimised parameters: " * sprintf1("%.2f", Tracker.data(fwd(x, y, W, b))))
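To make explicit what `Tracker.back!` and `Tracker.grad` compute in the script above, here is a minimal sketch of the same gradient descent with the gradients derived by hand. It uses only the Random, Statistics and LinearAlgebra standard libraries, so it does not depend on the deprecated Flux API. The helper names `mse_loss`, `grads` and `train`, and the fixed random seed, are hypothetical and introduced only for this sketch.

```julia
using Random, Statistics, LinearAlgebra

# Same model and loss as fwd above: ŷ = W*x .+ b, loss = mean((y .- ŷ).^2).
mse_loss(x, y, W, b) = mean((y .- (W * x .+ b)) .^ 2)

# Hand-derived gradients of mse_loss (what Tracker.back! accumulates).
function grads(x, y, W, b)
    ŷ = W * x .+ b
    dŷ = 2 .* (ŷ .- y) ./ length(y)   # ∂loss/∂ŷ: mean averages all length(y) entries
    dW = dŷ * x'                       # chain rule through W*x
    db = sum(dŷ; dims=2)               # b is broadcast across columns, so sum over them
    return dW, db
end

# Plain gradient descent with the same step size and iteration count as the gist.
function train(x, y; η=1e-2, iters=10_000)
    W, b = rand(2, 5), rand(2, 1)
    for _ in 1:iters
        dW, db = grads(x, y, W, b)
        W .-= η .* dW
        b .-= η .* db
    end
    return W, b
end

Random.seed!(1)   # hypothetical: fixed seed for reproducibility
Wtrue, btrue = randn(2, 5), randn(2, 1)
x = randn(5, 100)
y = Wtrue * x .+ btrue .+ randn(2, 100)

W, b = train(x, y)
println("Loss of true parameters: ", round(mse_loss(x, y, Wtrue, btrue); digits=2))
println("Training loss of optimised parameters: ", round(mse_loss(x, y, W, b); digits=2))
```

At convergence the gradients vanish, and the training loss ends up slightly below the loss of the true parameters because the fitted parameters also absorb part of the noise in this particular sample.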
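[DEPRECATED] N.b. this was written for teaching purposes. For posterity, I stress that it should never be used to perform linear regression in practice: ordinary least squares has a closed-form solution, so iterating gradient descent is unnecessary. It also uses an outdated API and implementation of Flux (the Tracker-based automatic differentiation, which Flux has since replaced with Zygote), hence the deprecation.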
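As a sketch of the direct approach, assuming ordinary least squares is the intended fit: append a row of ones to x so the bias becomes an extra weight column, then solve the least-squares problem with Julia's backslash operator. The helper `mse_loss` and the fixed seed are hypothetical additions for this example.

```julia
using Random, Statistics, LinearAlgebra

# Mean-squared error of the linear model ŷ = W*x .+ b, as in fwd above.
mse_loss(x, y, W, b) = mean((y .- (W * x .+ b)) .^ 2)

# Dummy data generated the same way as in the gist (seed added for reproducibility).
Random.seed!(1)
Wtrue, btrue = randn(2, 5), randn(2, 1)
x = randn(5, 100)
y = Wtrue * x .+ btrue .+ randn(2, 100)

# Append a row of ones so the bias becomes an extra weight column.
A = permutedims(vcat(x, ones(1, size(x, 2))))   # 100×6 design matrix
B = permutedims(y)                              # 100×2 targets
# For a non-square dense A, `\` returns the least-squares solution (pivoted QR).
Θt = A \ B                                      # 6×2: rows 1:5 hold W', row 6 holds b'
W = permutedims(Θt[1:5, :])                     # 2×5
b = permutedims(Θt[6:6, :])                     # 2×1

println("Loss of true parameters: ", round(mse_loss(x, y, Wtrue, btrue); digits=2))
println("Loss of least-squares parameters: ", round(mse_loss(x, y, W, b); digits=2))
```

This reaches the same minimiser that the gradient-descent loop approaches, in a single solve and without choosing a step size or an iteration count.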