Simpler Flux demos
"""
Learn a linear model for a linear relationship of three variables
"""
module MultipleLinearRegressionDemo
using Flux: gradient, params
using Statistics: mean
function linreg(X, Y)
# Add the intercept column to X
X = hcat(ones(size(X)[1]), X)
# The weight for each independent variable and the intercept
W = randn(size(X)[2])
# Prediction function (a linear model)
predict(X) = X * W
# Ordinary Least Squares loss function: mean squared error
# Use mean rather than sum so that learning_rate is independent of input size.
loss() = mean((Y - predict(X)).^2)
# Learn
# Compute the current loss, and differentiate with respect to W
# Update W in the opposite direction of the gradient (i.e. descend)
learning_rate = 0.01
for i = 1:3000
gs = gradient(loss, params(W))
W .-= learning_rate .* gs[W]
end
return W
end
# Training data
X = randn(10, 2)
Y = X[:, 1] .* 3 .+ X[:, 2] .* .5 .+ 1
# We're expecting: 1, 3, .5
@show linreg(X, Y)
end
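
As a sanity check on the weights learned by gradient descent, one could compare against the closed-form least-squares solution. This is a sketch, not part of the original gist; it just reuses the same synthetic data:

# Closed-form ordinary least squares, for comparison with the gradient-descent result.
X = randn(10, 2)
Y = X[:, 1] .* 3 .+ X[:, 2] .* .5 .+ 1
Xi = hcat(ones(size(X, 1)), X)   # same intercept column as in linreg
@show Xi \ Y                     # should be approximately [1.0, 3.0, 0.5]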
"""
Learn a linear model for data that really does have a linear relationship
"""
module SimpleLinearRegressionDemo
using Flux: gradient, params
using Statistics: mean
function linreg(xs, ys)
# Parameters have to wrapped so that they can be updated in-place.
c = [randn()]
m = [randn()]
# Prediction function (a linear model y = mx + c)
predict(x) = m[1] * x + c[1]
# Loss function: squared error
loss(x, y) = (predict(x) - y)^2
# Learn
# Compute the current loss, and differentiate with respect to m and c.
# Update c and m in the opposite direction of the gradient (i.e. descend)
learning_rate = 0.01
for i = 1:3000
gs = gradient(() -> mean(loss.(xs, ys)), params(c, m))
c .-= learning_rate .* gs[c]
m .-= learning_rate .* gs[m]
end
return vcat(c, m)
end
# Training data
xs = 1:10
ys = xs .* 3 .+ 2;
# I expect c should approach 2 and m should approach 3.
@show linreg(xs, ys)
end

cmcaine commented Jan 18, 2020

# Add the intercept column to X
X = [1; X']'

I wonder if that's too magical and if I should do this instead?

# Add the intercept column to X
X = hcat(ones(size(X)[1]), X)
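
For reference, a tiny check of what the hcat version produces (the variable names here are just for illustration):

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3 observations, 2 features
Xi = hcat(ones(size(X, 1)), X)    # prepend a column of ones for the intercept
# Xi == [1.0 1.0 2.0; 1.0 3.0 4.0; 1.0 5.0 6.0]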


cmcaine commented Jan 18, 2020

Another oddity to examine before proposing in a PR is why the simple linear regression requires me to calculate the mean of the loss rather than the sum.

If I try the sum, the loss rapidly approaches infinity and I end up with [NaN, NaN].
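
A minimal sketch of why the sum blows up, assuming the same data as SimpleLinearRegressionDemo above: the summed loss (and hence its gradient) is N times the mean loss, so with sum the effective step size grows with the number of training points.

using Statistics: mean

xs = 1:10
ys = xs .* 3 .+ 2
predict(x) = 1.0 * x + 0.0   # arbitrary fixed parameters, just for illustration

loss_mean = mean((predict.(xs) .- ys).^2)
loss_sum  = sum((predict.(xs) .- ys).^2)
@show loss_sum ≈ length(xs) * loss_mean   # true: the sum (and its gradient) is 10x larger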


cmcaine commented Jan 18, 2020

Yup, it's just that the learning rate is too high for sum. If it is reduced to 0.002, it performs well.

Using a dynamic learning rate would probably fix it.
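
A sketch of what a dynamic learning rate could look like here, using Flux's built-in ADAM optimiser (ADAM and Flux.Optimise.update! are real Flux APIs; the surrounding glue is just an illustration, not the gist's code):

using Flux
using Flux: params

xs = 1:10
ys = xs .* 3 .+ 2

c = [randn()]
m = [randn()]
predict(x) = m[1] * x + c[1]
loss(x, y) = (predict(x) - y)^2

# ADAM adapts the step size per parameter, so the raw sum of squared errors
# no longer needs a hand-tuned learning rate.
opt = ADAM(0.1)
ps = params(c, m)
for i = 1:3000
    gs = gradient(() -> sum(loss.(xs, ys)), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
@show vcat(c, m)   # should end up close to [2.0, 3.0]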

@caseykneale

Well... you do want the mean squared error, because otherwise your learning rate depends on batch size :). It is also the BLUE for LS, so it has a definitional purpose.


cmcaine commented Jan 18, 2020

That is a very good point ;)
