@ornithos
Last active December 12, 2021 15:24
A (very) simple multivariate linear regression with Julia / Flux
using Statistics
using Formatting
# NB: Flux.Tracker was bundled with Flux only before v0.10; later releases use Zygote
using Flux.Tracker
using Flux.Tracker: update!
# Define simple linear model
W = param(rand(2, 5))
b = param(rand(2, 1))
# Mean squared error of the linear prediction W*x .+ b against the targets y
function fwd(x, y, W, b)
    ŷ = W*x .+ b
    loss = mean((y .- ŷ).^2)
    return loss
end
# create dummy data
Wtrue, btrue = randn(2,5), randn(2,1)
x = randn(5, 100)
y = Wtrue * x .+ btrue .+ randn(2, 100)  # targets with unit-variance Gaussian noise
# perform simple gradient descent
η = 1e-2
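for i in 1:10_000
    loss = fwd(x, y, W, b)
    Tracker.back!(loss)                  # backpropagate to fill in the gradients of W and b
    update!(W, -η*Tracker.grad(W))       # in-place step; also resets the stored gradient
    update!(b, -η*Tracker.grad(b))
end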
# compare
println("Loss of true parameters: " * sprintf1("%.2f", fwd(x, y, Wtrue, btrue)))
println("Training loss of optimised parameters: " * sprintf1("%.2f", fwd(x, y, W, b)))
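As a reference point for the comparison above, the same model can be fitted in closed form by ordinary least squares, with no Flux dependency. The following is a minimal sketch and not part of the original gist: the names `X`, `θ`, `Wls`, `bls` and `mse` are introduced here for illustration, and the data is regenerated with a fixed seed so the block stands alone.

```julia
using LinearAlgebra, Statistics, Random

Random.seed!(0)

# Same dummy data layout as the gist: 5 inputs, 2 outputs, 100 samples
Wtrue, btrue = randn(2, 5), randn(2, 1)
x = randn(5, 100)
y = Wtrue * x .+ btrue .+ randn(2, 100)

# Append a row of ones so the bias is estimated as an extra column of θ
X = vcat(x, ones(1, size(x, 2)))        # 6 × 100

# Right division gives the least-squares solution of θ * X ≈ y
θ = y / X                               # 2 × 6
Wls, bls = θ[:, 1:5], θ[:, 6:6]

mse(W, b) = mean((y .- (W * x .+ b)).^2)

println("Loss of true parameters: ", round(mse(Wtrue, btrue), digits=2))
println("Loss of least-squares parameters: ", round(mse(Wls, bls), digits=2))
```

Because least squares minimises the training loss exactly, its loss is a lower bound for what gradient descent can reach on the same data, and it is never larger than the loss of the true parameters.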
@ornithos (Author)

[DEPRECATED] N.B. This was written for teaching purposes. For posterity, I stress that it should never be used to perform linear regression in practice. It also uses an outdated API and implementation of Flux, hence the deprecation.
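For readers on a newer Flux, here is a rough sketch of the same teaching example without the Tracker API. It is not the author's code. It assumes a Flux release from v0.10 onward, where `Flux.gradient` (provided via Zygote) differentiates an ordinary function with respect to plain arrays. The four-argument `loss` function is a hypothetical stand-in for `fwd`.

```julia
using Flux, Statistics

# Mean squared error of the linear model (stand-in for the gist's `fwd`)
loss(W, b, x, y) = mean((y .- (W * x .+ b)).^2)

# Dummy data, as in the gist
Wtrue, btrue = randn(2, 5), randn(2, 1)
x = randn(5, 100)
y = Wtrue * x .+ btrue .+ randn(2, 100)

# Plain arrays replace the tracked `param(...)` values
W, b = rand(2, 5), rand(2, 1)

# Simple gradient descent
η = 1e-2
for i in 1:10_000
    # Gradients with respect to the two explicit arguments
    gW, gb = Flux.gradient((W, b) -> loss(W, b, x, y), W, b)
    W .-= η .* gW      # in-place update, so no tracker bookkeeping is needed
    b .-= η .* gb
end

println("Loss of true parameters: ", round(loss(Wtrue, btrue, x, y), digits=2))
println("Training loss of optimised parameters: ", round(loss(W, b, x, y), digits=2))
```

The gradients are returned as ordinary arrays, so the update is a plain broadcasted subtraction and nothing needs to be zeroed between iterations.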
