
@JulienPascal
Created February 28, 2022 14:11
OLS_ML_4
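This snippet calls `obj_function` and `grad_OLS!` and uses `X`, `y`, `beta`, `beta_1_grid` and `beta_2_grid` without defining them. Below is a minimal sketch of what they could look like. It assumes `obj_function` returns the sum of squared residuals ‖y − Xβ‖² and `grad_OLS!` overwrites its first argument with the gradient −2Xᵀ(y − Xβ). Only the argument orders come from the calls in the script. The function bodies, the grids and the noise-free synthetic data are illustrative assumptions.

```julia
using LinearAlgebra, Random

# Hypothetical helpers: the gist calls these but does not define them.

# OLS objective: sum of squared residuals ‖y - Xβ‖².
obj_function(X, y, beta) = sum(abs2, y .- X * beta)

# Writes the gradient of the objective, -2 Xᵀ(y - Xβ), into `grad` in place.
function grad_OLS!(grad, beta, X, y)
    grad .= -2 .* (X' * (y .- X * beta))
    return grad
end

# Assumed synthetic setup: intercept column plus two standard-normal regressors.
rng = MersenneTwister(1234)
N = 1_000
X = hcat(ones(N), randn(rng, N, 2))
beta = [1.0; -2.0; 3.0]   # illustrative "true" coefficients (intercept, beta_1, beta_2)
y = X * beta              # noise-free, so beta is also the OLS solution
beta_1_grid = range(-10.0, 10.0; length = 100)  # assumed plotting grids
beta_2_grid = range(-10.0, 10.0; length = 100)

# Sanity checks: the objective and its gradient vanish at the true coefficients.
println(obj_function(X, y, beta))
println(norm(grad_OLS!(zeros(3), beta, X, y)))
```

The function is named with a trailing `!` because it mutates `grad`. That lets the loop reuse the preallocated `grad_n` and `grad_n_sgd` vectors instead of allocating a new gradient at every step.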
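using Plots, LaTeXStrings  # LaTeXStrings provides the L"..." strings used in the labels
# X, y, beta, beta_1_grid, beta_2_grid, obj_function and grad_OLS! are assumed to be defined earlier (not shown in this snippet).
# Contour plot of the OLS objective over (beta_1, beta_2), holding the intercept at its true value beta[1].
# The anonymous function's arguments are named b1, b2 so that they do not shadow the data vector y.
plot(beta_1_grid, beta_2_grid, (b1, b2) -> obj_function(X, y, [beta[1]; b1; b2]), st=:contour, colorbar_title=L"|y - X\hat{\beta}|^2")
scatter!([beta[2]], [beta[3]], markershape = :star5)  # true (beta_1, beta_2)
xlabel!(L"\beta_1")
ylabel!(L"\beta_2")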
# Compare full-sample gradient descent (GD) with stochastic gradient descent (SGD)
beta_hat_sgd = [beta[1]; -9.0; 9.0] # intercept starts at its true value (the updates below also move it); arbitrary starting guess for beta_1 and beta_2
beta_hat = [beta[1]; -9.0; 9.0]     # same starting point for GD
grad_n_sgd = zeros(3) # preallocated SGD gradient
grad_n = zeros(3)     # preallocated GD gradient
r_sgd = 1e-2 # learning rate for stochastic gradient descent
r = 1e-5     # learning rate for full-sample gradient descent
n_y = 5      # number of observations drawn per SGD step (minibatch size)
anim = @animate for i in 1:50
    y_index = rand(1:size(y, 1), n_y)  # draw n_y observations (with replacement) for this SGD step
    grad_OLS!(grad_n_sgd, beta_hat_sgd, X[y_index, :], y[y_index])  # minibatch gradient
    grad_OLS!(grad_n, beta_hat, X, y)                                 # full-sample gradient
    beta_hat_sgd .-= r_sgd .* grad_n_sgd
    beta_hat .-= r .* grad_n
    # add the current iterates to the contour plot (points accumulate across frames)
    scatter!([beta_hat_sgd[2]], [beta_hat_sgd[3]], markershape=:xcross, markersize=5, legend=:none)  # SGD
    scatter!([beta_hat[2]], [beta_hat[3]])  # GD
end
gif(anim, joinpath(dirname(@__FILE__), "convergence_GD_OLS_2d_2.gif"), fps=5)  # save next to this script
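The animation shows the two paths but not how close each ends up to the true coefficients. The same comparison can be run without plotting and checked numerically. This is a minimal sketch with several assumptions:

* `grad_OLS!` is a hypothetical gradient of the sum of squared residuals, redefined here so the block runs on its own.
* `run_descent` is a helper introduced only for illustration.
* The data are synthetic with a little noise.

The sketch also computes the step-size bound 1/λ_max(XᵀX) for full-sample gradient descent. Because this gradient is a sum over observations, its scale grows with the number of rows used. That is one plausible reason the script uses a much smaller learning rate for GD (`r = 1e-5`) than for SGD on 5 observations (`r_sgd = 1e-2`).

```julia
using LinearAlgebra, Random

# Hypothetical gradient of ‖y - Xβ‖² (not defined in the gist itself): -2 Xᵀ(y - Xβ).
function grad_OLS!(grad, beta, X, y)
    grad .= -2 .* (X' * (y .- X * beta))
    return grad
end

# Hypothetical helper: run n_iter steps of gradient descent from beta0.
# batch = nothing uses the full sample (GD); an integer draws that many rows
# with replacement at each step (SGD), as in the gist's loop.
function run_descent(X, y, beta0; r, n_iter, batch = nothing, rng = Random.default_rng())
    b = copy(beta0)
    g = zeros(length(b))
    for _ in 1:n_iter
        if batch === nothing
            grad_OLS!(g, b, X, y)
        else
            idx = rand(rng, 1:size(X, 1), batch)
            grad_OLS!(g, b, X[idx, :], y[idx])
        end
        b .-= r .* g
    end
    return b
end

# Assumed synthetic data: intercept plus two standard-normal regressors, small noise.
rng = MersenneTwister(42)
N = 1_000
X = hcat(ones(N), randn(rng, N, 2))
beta = [1.0, -2.0, 3.0]
y = X * beta .+ 0.1 .* randn(rng, N)

# GD on ‖y - Xβ‖² is stable only for r < 1/λ_max(XᵀX),
# since the gradient's Lipschitz constant is 2λ_max(XᵀX).
r_max_full = 1 / eigmax(Symmetric(X' * X))

beta0 = [beta[1], -9.0, 9.0]   # same starting point as the gist
b_gd  = run_descent(X, y, beta0; r = 1e-5, n_iter = 50)
b_sgd = run_descent(X, y, beta0; r = 1e-2, n_iter = 50, batch = 5, rng = rng)

println("stable GD step bound: ", r_max_full)
println("GD  distance to true beta: ", norm(b_gd - beta))
println("SGD distance to true beta: ", norm(b_sgd - beta))
```

With these assumed settings, both methods move from (−9, 9) toward the true coefficients in 50 iterations. SGD gets much closer because `r_sgd * n_y` (0.05) is larger than `r * N` (0.01), so each SGD step shrinks the error more on average. The exact distances depend on the seed.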