Created
April 5, 2019 22:46
-
-
Save goretkin/b70b108844189c3b3ba35a7620e7be48 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra: norm | |
using Zygote | |
using Plots | |
# Fixed problem data for the linear system A * θ ≈ b. These literals replace the
# `rand(2, 2)` / `rand(2)` draws that used to precede them and were immediately
# overwritten, so every run works on the same problem.
A = [0.511125 0.493421; 0.987115 0.260349]
b = [0.205398, 0.282329]
"""
    error1(θ)

Return the norm of the residual `A * θ - b` for the direct model, where `A` and `b`
are the file-level problem data.
"""
function error1(θ)
    residual = A * θ - b
    return norm(residual)
end
"""
    error2(θ)

Return the norm of the residual `A * θ₂ * θ₁ - b` for the factored model.

`θ` is destructured into two parts `(θ₁, θ₂)`; `θ₂` is a matrix. `A` and `b` are the
file-level problem data.
"""
function error2(θ)
    θ₁, θ₂ = θ
    residual = A * θ₂ * θ₁ - b
    return norm(residual)
end
"""
    map_parameter(θ_model2)

Convert a factored-model parameter `(θ₁, θ₂)` into the equivalent direct-model
parameter, i.e. the product `θ₂ * θ₁`.
"""
function map_parameter(θ_model2)
    θ₁, θ₂ = θ_model2
    return θ₂ * θ₁
end
"""
    gradient_descent(f, x0; α=0.001, iterations=1000)

Run fixed-step gradient descent on `f` starting from `x0` and return every iterate.

# Arguments
- `f`: objective function. Its gradient is obtained as `f'(x)`, so the adjoint syntax
  must yield a gradient function for `f` (this file loads Zygote for that purpose).
- `x0`: starting point; either a single array/number or a tuple of them.

# Keywords
- `α`: step size applied to the gradient at every step.
- `iterations`: number of descent steps to take (defaults to 1000, the previously
  hard-coded value).

# Returns
A vector of length `iterations + 1` whose first element is `x0` and whose last element
is the final iterate.

# Throws
- `ArgumentError` if `iterations` is negative.
"""
function gradient_descent(f, x0; α=0.001, iterations::Integer=1000)
    iterations >= 0 ||
        throw(ArgumentError("iterations must be non-negative, got $iterations"))
    xs = [x0]
    sizehint!(xs, iterations + 1)
    for _ in 1:iterations
        x_n = xs[end]
        δfδx = f'(x_n)
        # The update is written with broadcasting, and α is wrapped in `Ref` so it acts
        # as a scalar, so that the same line also works when the parameter is a tuple
        # of arrays: each component is then updated as `component - α * gradient`.
        x_n_plus_1 = x_n .- (Ref(α) .* δfδx)
        push!(xs, x_n_plus_1)
    end
    return xs
end
# Fixed starting point for the factored model: a 2-vector and a 2×2 matrix. This literal
# replaces a `(rand(2), rand(2, 2))` draw that was immediately overwritten, so every run
# starts from the same place.
θ₀_model2 = ([0.0462594, 0.384705], [0.854173 0.693254; 0.0489713 0.838112])
# Start the direct model from the mapped image of the same point so that both descents
# begin at the same direct-model parameter.
θ₀_model1 = map_parameter(θ₀_model2)
θs_model1 = gradient_descent(error1, θ₀_model1)
θs_model2 = gradient_descent(error2, θ₀_model2)
# Express every factored-model iterate in direct-model coordinates for comparison.
θs_model2_mapped = map_parameter.(θs_model2)
# Parameter-space trajectories, both in direct-model coordinates:
# red = direct model, blue = factored model mapped through `map_parameter`.
parameter_trajectory_plot = plot(reuse=false)
title!(parameter_trajectory_plot, "parameter trajectory")
# Each iterate is turned into a tuple so that it is plotted as a single point.
xy1 = [tuple(θ...) for θ in θs_model1]
xy2 = [tuple(θ...) for θ in θs_model2_mapped]
plot!(parameter_trajectory_plot, xy1; color=:red)
plot!(parameter_trajectory_plot, xy2; color=:blue)

# Objective value at every iterate, using the same colour coding.
error_trajectory_plot = plot(reuse=false)
title!(error_trajectory_plot, "error plot")
plot!(error_trajectory_plot, map(error1, θs_model1); color=:red)
plot!(error_trajectory_plot, map(error2, θs_model2); color=:blue)

foreach(display, (parameter_trajectory_plot, error_trajectory_plot))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment