Skip to content

Instantly share code, notes, and snippets.

@JulienPascal
Created February 28, 2022 14:09
Show Gist options
  • Save JulienPascal/7f1c3ef9dcc544e93b2cc449c8343967 to your computer and use it in GitHub Desktop.
Save JulienPascal/7f1c3ef9dcc544e93b2cc449c8343967 to your computer and use it in GitHub Desktop.
OLS_ML_3
# Simulate data from the linear model y = X*beta + e. Coefficients, noise and
# regressors are all drawn from `d`.
# NOTE(review): `n_points` is not defined in this snippet and must already be
# in scope when this runs.
dim_input = 2   # number of regressors, not counting the intercept
dim_output = 1  # NOTE(review): not used in the visible code
d = Normal()    # Distributions.jl normal with default parameters (standard normal)

# The draw order (coefficients, then noise, then regressors) is what fixes the
# values under a given RNG state, so it is kept exactly as before.
beta = rand(d, dim_input + 1);                           # true coefficients, intercept first
e = rand(d, n_points);                                   # one noise draw per observation
X = hcat(ones(n_points), rand(d, (n_points, dim_input))); # leading column of ones = intercept
y = X * beta .+ e;                                       # outcome of the linear model
"""
    obj_function(X, y, beta)

Return the mean squared residual of the linear model, i.e. the average over
observations of `(y - X*beta)^2`.
"""
function obj_function(X, y, beta)
    residuals = y .- X * beta
    return mean(abs2, residuals)
end
# Contour plot of the OLS objective over a grid of the two slope coefficients,
# with the intercept fixed at its true value beta[1]. The star marks the true
# slopes (beta[2], beta[3]).
beta_1_grid = collect(range(-10.0, 10.0, length=100))
beta_2_grid = copy(beta_1_grid)
# The closure arguments are named b1/b2 rather than x/y on purpose: naming the
# second one `y` shadowed the outcome vector `y`. The scalar grid value was then
# passed to obj_function as the outcome, and the contour showed the wrong surface.
plot(beta_1_grid, beta_2_grid, (b1, b2) -> obj_function(X, y, [beta[1]; b1; b2]),
     st=:contour, colorbar_title=L"\frac{1}{n}|y-X\hat{\beta}|^2")
scatter!([beta[2]], [beta[3]], markershape = :star5)
xlabel!(L"\beta_1")
ylabel!(L"\beta_2")
# Gradient-descent loop. Start from a deliberately poor guess for the two
# slopes, repeatedly step against the gradient, and add each iterate to the
# current (contour) plot, one animation frame per step.
# The intercept starts at its true value; beta_1 and beta_2 start at an
# arbitrary guess.
beta_hat = [beta[1]; -9.0; 9.0]
# Gradient buffer sized from beta_hat (not a hard-coded length), so it always
# matches the number of coefficients.
grad_n = zero(beta_hat)
# Learning rate (fixed step size). NOTE(review): its sensible scale depends on
# whether grad_OLS! sums or averages over observations, which is not visible here.
r = 1e-5
anim = @animate for i in 1:50
    # NOTE(review): grad_OLS! is defined outside this snippet; presumably it
    # writes the OLS gradient at beta_hat into grad_n. Confirm this. If it does
    # not zero the intercept component, beta_hat[1] is updated too and drifts
    # away from the slice beta[1] used for the contour plot.
    grad_OLS!(grad_n, beta_hat, X, y)
    beta_hat .-= r .* grad_n  # in-place update; no temporary arrays
    scatter!([beta_hat[2]], [beta_hat[3]], legend=:none)
end
# @__DIR__ is the directory of this script (or pwd() in the REPL / `julia -e`),
# so the GIF lands next to the script without calling dirname on @__FILE__.
gif(anim, joinpath(@__DIR__, "convergence_GD_OLS_2d.gif"), fps=5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment