
# briochemc/NewtonIterator.jl

Last active February 13, 2019 05:20

# Learning how to use iterators in Julia

I will use this gist as a training ground to play around with, and learn about, iterators in Julia. It was motivated by reading a blog post by Lorenzo Stella.

Disclaimer: At this stage, I only have a tiny example that sort of works 😃
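As a refresher on the protocol itself: a Julia `for` loop over `iter` calls `iterate(iter)` once, then `iterate(iter, state)` repeatedly, and stops as soon as either call returns `nothing`. Here is a minimal sketch with a hypothetical `Countdown` type (not part of this gist) that implements both methods, plus the optional `length` and `eltype` methods that let generic functions such as `collect` preallocate. The helper `manual_loop` (also hypothetical) spells out roughly what the `for` loop does.

```julia
# Hypothetical example type (not from the gist): counts down from `start` to 1
struct Countdown
    start::Int
end

# First call: return (first item, initial state), or `nothing` if there are no items
Base.iterate(c::Countdown) = c.start < 1 ? nothing : (c.start, c.start - 1)

# Later calls: return (next item, next state), or `nothing` to end the loop
Base.iterate(c::Countdown, state) = state < 1 ? nothing : (state, state - 1)

# Optional: tell generic code like `collect` how many items to expect and of what type
Base.length(c::Countdown) = max(c.start, 0)
Base.eltype(::Type{Countdown}) = Int

# Roughly what `for item in iter ... end` does under the hood
function manual_loop(iter)
    items = Int[]
    next = iterate(iter)
    while next !== nothing
        item, state = next
        push!(items, item)
        next = iterate(iter, state)
    end
    return items
end

collect(Countdown(3))      # [3, 2, 1]
manual_loop(Countdown(3))  # [3, 2, 1]
```

The Newton iterator in this gist follows the same pattern, except that its second `iterate` method never returns `nothing`, so the loop has to `break` on its own.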

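For reference, here is the system the code solves, its Jacobian (as filled in by `f_Jac!`), and the Newton update performed by the second `iterate` method, where the linear solve $J(x_k)^{-1} f(x_k)$ is done with the LU factors stored in the state:

```math
f(x) = \begin{bmatrix} (x_1 + x_2 + 1)^3 \\ (x_1 - x_2 - 1)^3 \end{bmatrix},
\qquad
J(x) = \begin{bmatrix}
  3(x_1 + x_2 + 1)^2 & 3(x_1 + x_2 + 1)^2 \\
  3(x_1 - x_2 - 1)^2 & -3(x_1 - x_2 - 1)^2
\end{bmatrix},
\qquad
x_{k+1} = x_k - J(x_k)^{-1} f(x_k).
```

The root is $x^\star = (0, -1)$, where $J(x^\star)$ is the zero matrix, so Newton converges only linearly here: with $u = x_1 + x_2 + 1$ and $v = x_1 - x_2 - 1$, each step multiplies $u$ and $v$ by $2/3$ (and $f$ by $8/27$), which is why the loop below takes around 40 iterations to reach `norm(f(x)) < 1e-20`.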
```julia
# Example: find the root of f(x) = [(x[1] + x[2] + 1)^3, (x[1] - x[2] - 1)^3]
f(x) = [(x[1] + x[2] + 1)^3, (x[1] - x[2] - 1)^3]

# In-place Jacobian of f: fills J and returns it
function f_Jac!(J, x)
    J[1, 1] = 3(x[1] + x[2] + 1)^2
    J[2, 1] = 3(x[1] - x[2] - 1)^2
    J[1, 2] = 3(x[1] + x[2] + 1)^2
    J[2, 2] = -3(x[1] - x[2] - 1)^2
    return J
end

# Check that my Jacobian is correct
using ForwardDiff
xtest = rand(2)
Jbuf = zeros(2, 2)
f_Jac!(Jbuf, xtest) # updates Jbuf
ForwardDiff.jacobian(f, xtest) ≈ Jbuf # should be true

# Now define the iterable.
# I think it must contain the minimal set of things
# that fully define the iteration:
# - x₀, the starting point
# - f, the function
# - f_Jac!, the Jacobian function
struct NewtonIterable{Tx, Tf, TJac}
    x::Tx        # the starting point x₀
    f::Tf        # the function
    f_Jac!::TJac # the Jacobian function
end

# Now define the state,
# which is a mutable struct containing everything that may be needed for efficiency
mutable struct NewtonState{Tx, TJ, TJf}
    x::Tx    # current x
    fx::Tx   # current f(x)
    Jx::TJ   # current J(x)
    Jfx::TJf # current LU factors of J(x)
end

using LinearAlgebra
import Base: iterate

# First call: build the state at the starting point
function iterate(iter::NewtonIterable)
    x = copy(iter.x) # copy so that iterating does not overwrite the starting point
    fx = iter.f(x)
    Jx = iter.f_Jac!(zeros(length(x), length(x)), x)
    state = NewtonState(x, fx, Jx, lu(Jx))
    return state, state
end

# Later calls: take one Newton step, x ← x - J(x) \ f(x),
# then refresh f(x), J(x), and the LU factors of J(x)
function iterate(iter::NewtonIterable, state)
    state.x .-= state.Jfx \ state.fx
    state.fx = iter.f(state.x)
    state.Jx = iter.f_Jac!(state.Jx, state.x)
    state.Jfx = lu(state.Jx)
    return state, state
end

# The iterator never stops on its own, so break out once f(x) is small enough
for state in NewtonIterable(zeros(2), f, f_Jac!)
    println("x:\n", state.x)
    println("f(x):\n", state.fx)
    if norm(state.fx) < 1e-20
        break
    end
end
```
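Two things in this version could bite when reusing it: `NewtonIterable` inherits the default `Base.IteratorSize` trait, `HasLength()`, even though it has no `length` method and never ends on its own, and the `for` loop has no cap on the number of iterations if Newton fails to converge. Below is a self-contained sketch (using only `LinearAlgebra`, so the ForwardDiff check is dropped) that declares the iterator `Base.IsInfinite()` and wraps it in `Iterators.take`. The helper `newton_solve` and its `tol`/`maxiter` keywords are hypothetical names introduced for this sketch, not part of any package.

```julia
using LinearAlgebra

# Same example system and Jacobian as above
f(x) = [(x[1] + x[2] + 1)^3, (x[1] - x[2] - 1)^3]

function f_Jac!(J, x)
    J[1, 1] = 3(x[1] + x[2] + 1)^2
    J[2, 1] = 3(x[1] - x[2] - 1)^2
    J[1, 2] = 3(x[1] + x[2] + 1)^2
    J[2, 2] = -3(x[1] - x[2] - 1)^2
    return J
end

struct NewtonIterable{Tx, Tf, TJac}
    x::Tx
    f::Tf
    f_Jac!::TJac
end

mutable struct NewtonState{Tx, TJ, TJf}
    x::Tx
    fx::Tx
    Jx::TJ
    Jfx::TJf
end

# Newton never signals the end by itself, so declare the iterator infinite
Base.IteratorSize(::Type{<:NewtonIterable}) = Base.IsInfinite()

function Base.iterate(iter::NewtonIterable)
    x = copy(iter.x) # do not mutate the user's starting point
    n = length(x)
    Jx = iter.f_Jac!(zeros(eltype(x), n, n), x)
    state = NewtonState(x, iter.f(x), Jx, lu(Jx))
    return state, state
end

function Base.iterate(iter::NewtonIterable, state::NewtonState)
    state.x .-= state.Jfx \ state.fx # Newton step
    state.fx = iter.f(state.x)
    iter.f_Jac!(state.Jx, state.x)   # refresh J(x) in place
    state.Jfx = lu(state.Jx)
    return state, state
end

# Hypothetical helper: at most `maxiter` states, stop once norm(f(x)) < tol.
# Returns the final x and the number of Newton steps taken.
function newton_solve(f, f_Jac!, x0; tol = 1e-20, maxiter = 100)
    steps = 0
    for state in Iterators.take(NewtonIterable(x0, f, f_Jac!), maxiter)
        if norm(state.fx) < tol
            return (state.x, steps)
        end
        steps += 1
    end
    error("Newton did not converge within $maxiter iterations")
end

x, steps = newton_solve(f, f_Jac!, zeros(2))
```

Note that every iteration yields the same mutable `NewtonState` object, so if you want a history of iterates, store `copy(state.x)` inside the loop rather than `state` itself.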