I will use this gist as training grounds to play around and learn about iterators in Julia. This was motivated by reading this blog post by Lorenzo Stella.
Disclaimer: At this stage I have a tiny example that sort-of works 😃
I will use this gist as training grounds to play around and learn about iterators in Julia. This was motivated by reading this blog post by Lorenzo Stella.
Disclaimer: At this stage I have a tiny example that sort-of works 😃
"""
    f(x)

Example function whose root is sought. Returns the two-element vector
`[(x[1] + x[2] + 1)^3, (x[1] - x[2] - 1)^3]`.
"""
function f(x)
    s = x[1] + x[2] + 1
    d = x[1] - x[2] - 1
    return [s^3, d^3]
end
"""
    f_Jac!(J, x)

Write the Jacobian of the example function
`x -> [(x[1] + x[2] + 1)^3, (x[1] - x[2] - 1)^3]` at `x` into the entries
`J[1:2, 1:2]` and return `J`.
"""
function f_Jac!(J, x)
    sq_sum = (x[1] + x[2] + 1)^2
    sq_diff = (x[1] - x[2] - 1)^2
    # Row 1 differentiates (x₁ + x₂ + 1)³; row 2 differentiates (x₁ - x₂ - 1)³,
    # whose derivative with respect to x₂ picks up a minus sign.
    J[1, 1] = 3sq_sum
    J[1, 2] = 3sq_sum
    J[2, 1] = 3sq_diff
    J[2, 2] = -3sq_diff
    return J
end
# Sanity check: the hand-written Jacobian should agree with automatic differentiation
using ForwardDiff
Jbuf = zeros(2, 2)    # buffer that f_Jac! overwrites in place
xtest = rand(2)       # random test point
f_Jac!(Jbuf, xtest)
isapprox(ForwardDiff.jacobian(f, xtest), Jbuf)
"""
    NewtonIterable(x, f, f_Jac!)

The iterable. It is meant to contain the exact minimum of things that exactly
defines the whole iteration.

# Fields
- `x`: x₀, the starting point
- `f`: the function
- `f_Jac!`: the Jacobian function
"""
struct NewtonIterable{X,F,JacF}
    x::X
    f::F
    f_Jac!::JacF
end
"""
    NewtonState(x, fx, Jx, Jfx)

The iteration state: a mutable struct that contains everything that may be needed
for efficiency.

# Fields
- `x`: current x
- `fx`: current f(x); declared with the same type as `x`
- `Jx`: current J(x)
- `Jfx`: current factors of J(x)
"""
mutable struct NewtonState{X,M,Fact}
    x::X
    fx::X
    Jx::M
    Jfx::Fact
end
using LinearAlgebra | |
import Base: iterate | |
"""
    iterate(iter::NewtonIterable)

Start the Newton iteration described by `iter`.

Evaluates `iter.f` and `iter.f_Jac!` at the starting point, factorizes the Jacobian
with `lu`, and returns `(state, state)`: the freshly built `NewtonState` is both the
first item and the iteration state.

The starting point is copied, so the in-place updates applied to `state.x` on later
steps do not overwrite `iter.x` and the same iterable can be iterated again from x₀.
The Jacobian buffer is sized from `length(iter.f(x₀))` and `length(x₀)` instead of
being fixed to 2×2.

With its default `check = true`, `lu` throws a `SingularException` if the Jacobian
at the starting point is singular.
"""
function iterate(iter::NewtonIterable)
    # Work on a private copy: the two-argument method updates `state.x` in place.
    x = copy(iter.x)
    fx = iter.f(x)
    # Float64 buffer (as before), but with dimensions taken from the problem.
    Jx = iter.f_Jac!(zeros(length(fx), length(x)), x)
    state = NewtonState(x, fx, Jx, lu(Jx))
    return state, state
end
"""
    iterate(iter::NewtonIterable, state)

Advance `state` by one Newton step and return `(state, state)`.

`state.x` is updated in place by subtracting the solution of the linear system with
the stored factors `state.Jfx` and right-hand side `state.fx`; then `state.fx`,
`state.Jx` and `state.Jfx` are refreshed at the new point. This method never returns
`nothing`, so the caller decides when to stop.
"""
function iterate(iter::NewtonIterable, state)
    x = state.x
    dx = state.Jfx \ state.fx    # Newton correction from the stored factorization
    x .-= dx                     # in place: `x` is the same array as `state.x`
    state.fx = iter.f(x)
    state.Jx = iter.f_Jac!(state.Jx, x)
    state.Jfx = lu(state.Jx)
    return state, state
end
# Run the iteration from x₀ = zeros(2), printing every iterate and its residual,
# and stop once the residual norm drops below 1e-20.
for st in NewtonIterable(zeros(2), f, f_Jac!)
    println("x:\n", st.x)
    println("f(x):\n", st.fx)
    norm(st.fx) < 1e-20 && break
end