Skip to content

Instantly share code, notes, and snippets.

@goretkin
Created April 5, 2019 22:46
Show Gist options
  • Save goretkin/b70b108844189c3b3ba35a7620e7be48 to your computer and use it in GitHub Desktop.
Save goretkin/b70b108844189c3b3ba35a7620e7be48 to your computer and use it in GitHub Desktop.
using LinearAlgebra: norm
using Zygote
using Plots
# Problem data for the residual ‖A θ - b‖ minimized below.
# These values were originally drawn with `rand(2, 2)` and `rand(2)`; they are
# fixed here so every run reproduces the same trajectories (the random draws were
# immediately overwritten, so they only cost time and RNG state).
# Declared `const` because `error1`/`error2` read them as globals inside the
# gradient-descent loops; untyped non-const globals defeat type inference there.
const A = [0.511125 0.493421; 0.987115 0.260349]
const b = [0.205398, 0.282329]
"""
    error1(θ)

Return the residual norm `‖A θ - b‖` of the linear model at parameter `θ`,
reading the problem data `A` and `b` from global scope.
"""
function error1(θ)
    residual = A * θ - b
    return norm(residual)
end
"""
    error2(θ)

Return the residual norm `‖A θ₂ θ₁ - b‖` of the factored model, reading the
problem data `A` and `b` from global scope.

`θ` is destructured as `(θ₁, θ₂)`; `θ₂` is a matrix and `θ₁` is what it multiplies.
"""
function error2(θ)
    θ₁, θ₂ = θ
    residual = A * θ₂ * θ₁ - b
    return norm(residual)
end
"""
    map_parameter(θ_model2)

Map a factored-model parameter `(θ₁, θ₂)` (as consumed by `error2`) to the
corresponding linear-model parameter `θ₂ * θ₁` (as consumed by `error1`).
"""
function map_parameter(θ_model2)
    θ₁, θ₂ = θ_model2
    return θ₂ * θ₁
end
"""
    gradient_descent(f, x0; α=0.001, iters=1000, grad=f')

Run `iters` steps of fixed-step gradient descent on `f` starting from `x0` and
return the whole trajectory `[x0, x1, …, x_iters]` (length `iters + 1`).

# Arguments
- `f`: objective to minimize.
- `x0`: initial parameter; either an array or a tuple of arrays such as
  `(θ₁, θ₂)`. Each step computes `x .- α .* grad(x)` component-wise.

# Keywords
- `α=0.001`: step size.
- `iters=1000`: number of steps; must be non-negative.
- `grad=f'`: function returning the gradient of `f` at a point. The default uses
  the `f'` syntax, which relies on an AD package defining `adjoint` for functions
  (Zygote, as loaded by this script).

# Throws
- `ArgumentError` if `iters < 0`.
"""
function gradient_descent(f, x0; α=0.001, iters::Integer=1000, grad=f')
    iters >= 0 || throw(ArgumentError("iters must be non-negative, got $iters"))
    xs = [x0]
    sizehint!(xs, iters + 1)
    for _ in 1:iters
        x_n = xs[end]
        δfδx = grad(x_n)
        # `Ref(α)` marks α as a scalar for broadcasting. When the gradient is a
        # tuple `(g₁, g₂)` this yields `(α*g₁, α*g₂)`, and `x_n .- …` then updates
        # each component of a tuple-valued parameter separately.
        x_n_plus_1 = x_n .- (Ref(α) .* δfδx)
        push!(xs, x_n_plus_1)
    end
    return xs
end
# Initial parameter of the factored model: (θ₁ vector, θ₂ matrix). Originally
# sampled with `(rand(2), rand(2, 2))`; fixed here so the run is reproducible
# (the random draw was immediately overwritten).
θ₀_model2 = ([0.0462594, 0.384705], [0.854173 0.693254; 0.0489713 0.838112])
# Start the linear model at the corresponding point θ₂ * θ₁.
θ₀_model1 = map_parameter(θ₀_model2)

# Optimize both models with gradient_descent's default step size and step count.
θs_model1 = gradient_descent(error1, θ₀_model1)
θs_model2 = gradient_descent(error2, θ₀_model2)
# Express the factored-model trajectory in linear-model coordinates for comparison.
θs_model2_mapped = map_parameter.(θs_model2)
# Parameter-space view: both trajectories in the linear model's coordinates
# (the factored-model trajectory was mapped through `map_parameter`).
parameter_trajectory_plot = plot(reuse=false)
title!(parameter_trajectory_plot, "parameter trajectory")
# Turn each parameter vector into a tuple so Plots draws it as a point.
xy1 = Tuple.(θs_model1)
xy2 = Tuple.(θs_model2_mapped)
plot!(parameter_trajectory_plot, xy1; color=:red)    # linear model
plot!(parameter_trajectory_plot, xy2; color=:blue)   # factored model
# Error-vs-iteration view: each model's objective evaluated along its own trajectory
# (red = linear model / error1, blue = factored model / error2).
error_trajectory_plot = plot(reuse=false)
title!(error_trajectory_plot, "error plot")
for (errors, c) in ((error1.(θs_model1), :red), (error2.(θs_model2), :blue))
    plot!(error_trajectory_plot, errors; color=c)
end

# Show both figures.
for p in (parameter_trajectory_plot, error_trajectory_plot)
    display(p)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment