Skip to content

Instantly share code, notes, and snippets.

@antoine-levitt
Created July 21, 2019 16:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antoine-levitt/1c08f3181df31b81e912541b17d5acf9 to your computer and use it in GitHub Desktop.
diff --git a/src/solvers/anderson.jl b/src/solvers/anderson.jl
index 0d47b6a..60dc122 100644
--- a/src/solvers/anderson.jl
+++ b/src/solvers/anderson.jl
@@ -42,8 +42,8 @@ AndersonCache(df, ::Anderson{0}) =
@views function anderson_(df::Union{NonDifferentiable, OnceDifferentiable},
initial_x::AbstractArray{T},
- xtol::T,
- ftol::T,
+ xtol::Real, # real scalar tolerance even when T (eltype of initial_x) is complex
+ ftol::Real, # likewise real; the `anderson` wrapper passes convert(real(T), ftol)
iterations::Integer,
store_trace::Bool,
show_trace::Bool,
@@ -80,7 +80,7 @@ AndersonCache(df, ::Anderson{0}) =
update!(tr,
iter,
maximum(abs, fx),
- iter > 1 ? sqeuclidean(cache.g, cache.x) : convert(T,NaN),
+ iter > 1 ? sqeuclidean(cache.g, cache.x) : convert(real(T),NaN), # first-iteration NaN placeholder typed real(T) so it stays real for complex T (presumably matching the real-valued sqeuclidean branch)
dt,
store_trace,
show_trace)
@@ -187,5 +187,5 @@ function anderson(df::Union{NonDifferentiable, OnceDifferentiable},
aa_start::Integer,
droptol::Real,
cache::AndersonCache) where T
- anderson_(df, initial_x, convert(T, xtol), convert(T, ftol), iterations, store_trace, show_trace, extended_trace, beta, aa_start, droptol, cache)
+ anderson_(df, initial_x, convert(real(T), xtol), convert(real(T), ftol), iterations, store_trace, show_trace, extended_trace, beta, aa_start, droptol, cache) # real(T) is the real scalar type underlying T (e.g. Float64 for ComplexF64)
end
diff --git a/src/solvers/broyden.jl b/src/solvers/broyden.jl
index 97173eb..08fc21e 100644
--- a/src/solvers/broyden.jl
+++ b/src/solvers/broyden.jl
@@ -23,8 +23,8 @@ end
function broyden_(df::Union{NonDifferentiable, OnceDifferentiable},
initial_x::AbstractArray{T},
- xtol::T,
- ftol::T,
+ xtol::Real, # real scalar tolerance even when T (eltype of initial_x) is complex
+ ftol::Real, # likewise real; the `broyden` wrapper passes convert(real(T), ftol)
iterations::Integer,
store_trace::Bool,
show_trace::Bool,
@@ -120,7 +120,7 @@ function broyden(df::Union{NonDifferentiable, OnceDifferentiable},
show_trace::Bool,
extended_trace::Bool,
linesearch) where T
- broyden_(df, initial_x, convert(T, xtol), convert(T, ftol), iterations, store_trace, show_trace, extended_trace, linesearch)
+ broyden_(df, initial_x, convert(real(T), xtol), convert(real(T), ftol), iterations, store_trace, show_trace, extended_trace, linesearch) # real(T) is the real scalar type underlying T (e.g. Float64 for ComplexF64)
end
# A derivative-free line search and global convergence
diff --git a/test/complex.jl b/test/complex.jl
index b4b88b6..610b34a 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -1,22 +1,24 @@
@testset "complex" begin
function f!(F, x)
- F[1] = x[1]*x[2] + 1
- F[2] = x[1]^2 + x[2]^2 - 2
+ F[1] = x[1]*x[2] + (1+im)
+ F[2] = x[1]^2 + x[2]^2 - (2-3im)
end
function f_real!(F::AbstractArray{T}, x::AbstractArray{T}) where {T<:Real}
f!(reinterpret(Complex{T}, F), reinterpret(Complex{T}, x))
end
-for alg in [:trust_region, :newton]
- sol = nlsolve(f!, [1.0+0.1im, 2+1im], method = alg, store_trace=true, extended_trace=true)
- sol_real = nlsolve(f_real!, reinterpret(Float64, [1.0+0.1im, 2+1im]), method = alg, store_trace=true, extended_trace=true)
+for alg in [:newton,:trust_region,:anderson] # TODO add broyden
+ sol = nlsolve(f!, [1.0+0.1im, 2+1im], method = alg, store_trace=true, extended_trace=true, iterations=100, m=10, beta=0.01)
+ sol_real = nlsolve(f_real!, reinterpret(Float64, [1.0+0.1im, 2+1im]), method = alg, store_trace=true, extended_trace=true, iterations=100, m=10, beta=0.01)
@test converged(sol) == converged(sol_real)
@test sol.zero ≈ reinterpret(ComplexF64, sol_real.zero)
- @test sol.iterations == sol_real.iterations
- @test sol.f_calls == sol_real.f_calls
- @test sol.g_calls == sol_real.g_calls
- @test all(sol_real.trace[i].stepnorm == sol_real.trace[i].stepnorm for i in 2:sol.iterations)
- @test all(norm(sol.trace[i].metadata["f(x)"]) ≈ norm(sol_real.trace[i].metadata["f(x)"]) for i in 1:5)
+ if alg in (:newton, :trust_region) #those are supposed to be exactly the same (in exact arithmetic)
+ @test sol.iterations == sol_real.iterations
+ @test sol.f_calls == sol_real.f_calls
+ @test sol.g_calls == sol_real.g_calls
+ @test all(sol.trace[i].stepnorm ≈ sol_real.trace[i].stepnorm for i in 2:sol.iterations) # complex run vs real run (≈: same math, different rounding); entry 1 is the NaN placeholder
+ @test all(norm(sol.trace[i].metadata["f(x)"]) ≈ norm(sol_real.trace[i].metadata["f(x)"]) for i in 1:min(5, sol.iterations)) # early iterations only; bounded so short runs cannot index past the trace
+ end
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment