Skip to content

Instantly share code, notes, and snippets.

@JulienPascal
Last active February 28, 2022 14:07
Show Gist options
  • Save JulienPascal/77d749459dcdb58981df7a87d479506c to your computer and use it in GitHub Desktop.
Save JulienPascal/77d749459dcdb58981df7a87d479506c to your computer and use it in GitHub Desktop.
OLS_ML_2.jl
"""
    grad_OLS!(G, beta_hat, X, y)

Overwrite `G` with the gradient of the least-squares objective
`0.5 * ||X * beta_hat - y||^2` with respect to `beta_hat`, i.e.
`transpose(X) * (X * beta_hat - y)`, and return `G`.

`G` must have one entry per column of `X`. A length mismatch raises a
`DimensionMismatch`, as in a plain indexed assignment.
"""
function grad_OLS!(G, beta_hat, X, y)
    # Residuals X*beta_hat - y first, then project them back onto the columns of X.
    G[:] = transpose(X) * (X * beta_hat - y)
    # Return the mutated buffer (Julia `!` convention), not the temporary RHS.
    return G
end
"""
    OLS_gd(X, y; epochs=1000, r=1e-5, beta_hat=zeros(size(X, 2)), verbose=false)

Estimate OLS coefficients by batch gradient descent.

Each epoch computes the gradient with `grad_OLS!` into a reusable buffer and takes a
fixed step `beta_hat = beta_hat - r * gradient`. The function returns the estimate
after `epochs` steps.

# Keywords
- `epochs`: number of gradient steps. With `epochs <= 0` the start value is returned.
- `r`: step size (learning rate).
- `beta_hat`: starting point. It is not mutated; the update rebinds a new vector.
- `verbose`: when `true`, print the MSE after epochs `1, 1 + k, 1 + 2k, ...`, where
  `k ≈ epochs / 10` (at least 1). This requires a function `mse(beta_hat, X, y)`
  defined elsewhere; it is not part of this block.
"""
function OLS_gd(X::AbstractArray, y::AbstractVector; epochs::Integer=1000, r::Real=1e-5,
                beta_hat = zeros(size(X, 2)), verbose::Bool=false)
    grad_n = zeros(size(X, 2))
    # Report roughly 10 times over the run. Clamp to 1: for epochs <= 5 the rounded
    # interval is 0, and `mod(epoch, 0)` would throw a DivideError.
    report_every = max(1, round(Int, epochs / 10))
    for epoch in 1:epochs
        grad_OLS!(grad_n, beta_hat, X, y)
        # Non-mutating update: rebinds `beta_hat`, leaving the caller's vector untouched.
        beta_hat -= r * grad_n
        # Fires at epochs 1, 1 + k, 1 + 2k, ... This also works for k == 1, where
        # `mod(epoch, k) == 1` could never hold.
        if verbose && mod(epoch - 1, report_every) == 0
            println("MSE: $(mse(beta_hat, X, y))")
        end
    end
    return beta_hat
end
# Assumes the design matrix `X`, response `y` and true coefficients `beta` are defined
# earlier in the file (not visible here) - TODO confirm.
# One epoch of gradient descent from OLS_gd's default all-zeros start.
# NOTE(review): `beta_hat_gd_1` is not used in the lines visible here; presumably it is
# consumed elsewhere - confirm before removing.
beta_hat_gd_1 = OLS_gd(X,y, epochs=1);
# New figure: true coefficients vs. the all-zeros GD starting point, i.e. the estimate
# after 0 iterations (matches OLS_gd's default `beta_hat = zeros(size(X,2))`).
plot(beta, zeros(size(X,2)), seriestype=:scatter, label="GD (0 iter.)")
# 45-degree reference line: a point on it is a perfectly estimated coefficient.
plot!(beta, beta, seriestype=:line, label="45° line", legend = :outertopleft)
xlabel!(L"True value $\beta$")
ylabel!(L"Estimated value $\hat{\beta}$ (GD)")
# refinement loop
# Animate GD convergence: at i = 5, 10, ..., 50, run OLS_gd for `i` epochs from its
# default all-zeros start and overlay the estimates (via plot!) on the figure above.
# NOTE(review): each call restarts from scratch, so total work is 5+10+...+50 = 275
# epochs. Warm-starting through OLS_gd's `beta_hat` keyword would need only 50.
# NOTE(review): nothing is drawn when mod(i, 5) != 0. @animate presumably still
# records a frame on every loop iteration, so each drawn state is likely held for
# several frames - confirm this GIF pacing is intended.
anim = @animate for i=1:50
if mod(i, 5) == 0
beta_hat_gd = OLS_gd(X,y, epochs=i);
plot!(beta, beta_hat_gd, seriestype=:scatter, label="GD ($(i) iter.)")
# Re-draws the 45-degree reference line (without a legend entry) on every drawn step.
# NOTE(review): this adds another series each time and looks redundant with the line
# drawn before the loop; removing it may shift the colours of later series - verify
# visually before changing.
plot!(beta, beta, seriestype=:line, label=:none, legend = :outertopleft)
xlabel!(L"True value $\beta$")
ylabel!(L"Estimated value $\hat{\beta}$ (GD)")
end
end
# Write the animation as a GIF at 5 frames per second, in the directory of this file.
# NOTE(review): if @__FILE__ carries no directory component (e.g. REPL/notebook), the
# path is relative and the GIF lands in the working directory. @__DIR__ is the more
# direct idiom.
gif(anim,joinpath(dirname(@__FILE__),"convergence_GD_OLS.gif"),fps=5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment