@mateuszbaran
Created March 16, 2023 08:30
Using Optim.jl with Manifolds.jl
using Manifolds, Optim, ManifoldsBase

# this is a generic part

"""
    ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold

Adapts Manifolds.jl manifolds for use in Optim.jl.
"""
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
    M::TM
end

# retraction: pull the ambient-space point back onto the manifold
function Optim.retract!(M::ManifoldWrapper, x)
    ManifoldsBase.embed_project!(M.M, x, x)
    return x
end

# project the gradient onto the tangent space at x
function Optim.project_tangent!(M::ManifoldWrapper, g, x)
    ManifoldsBase.embed_project!(M.M, g, x, g)
    return g
end

# example usage of Manifolds.jl manifolds in Optim.jl
M = Manifolds.Sphere(2)
x0 = [1.0, 0.0, 0.0]
q = [0.0, 1.0, 0.0]

f(p) = 0.5 * distance(M, p, q)^2

function g!(X, p)
    # Riemannian gradient of f at p is -log_p(q)
    log!(M, X, p, q)
    X .*= -1
    println(p, X)  # debug output of the current point and gradient
end

sol = optimize(f, g!, x0, ConjugateGradient(; manifold=ManifoldWrapper(M)))
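A quick sanity check on the result (not part of the original gist; the 1e-6 tolerance is an arbitrary choice): the minimum of f is attained at q itself, so the minimizer should land on q numerically.

p_star = sol.minimizer
distance(M, p_star, q) < 1e-6  # expected to hold if ConjugateGradient converged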

Affie commented Mar 16, 2023

Thanks! I'll give it a try again. I did try at one stage but couldn't quite get it working nicely.
It looks like I just had to use embed_project; I also didn't know about the power manifold back then.
This is what I had:

struct ManifoldsVector <: Optim.Manifold
  manis::Vector{AbstractManifold}
end

# delegate indexing and iteration (needed by `enumerate` below) to the wrapped vector
Base.getindex(mv::ManifoldsVector, inds...) = getindex(mv.manis, inds...)
Base.setindex!(mv::ManifoldsVector, X, inds...) = setindex!(mv.manis, X, inds...)
Base.length(mv::ManifoldsVector) = length(mv.manis)
Base.iterate(mv::ManifoldsVector, args...) = iterate(mv.manis, args...)

function ManifoldsVector(fg::AbstractDFG, varIds::Vector{Symbol})
  manis = AbstractManifold[]
  for k = varIds
    push!(manis, getVariableType(fg, k) |> getManifold)
  end
  ManifoldsVector(manis)
end

function Optim.retract!(manis::ManifoldsVector, x)
  for (i, M) = enumerate(manis)
    x[i] = project(M, x[i])
  end
  return x
end

function Optim.project_tangent!(manis::ManifoldsVector, G, x)
  for (i, M) = enumerate(manis)
    G[i] = project(M, x[i], G[i])
  end
  return G
end

https://github.com/JuliaRobotics/IncrementalInference.jl/blob/78d5f548ad1b9539327493035830a68eda1ea438/src/Deprecated.jl#L5-L31

@mateuszbaran (Author) commented

> Thanks! I'll give it a try again.

Cool, let me know if you have any problems with it. embed_project is a relatively new thing. I've added it primarily for compatibility with Optim.jl.
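For reference, a small illustration of what embed_project does (a sketch, not from the gist; the off-manifold vector y is made up for illustration): it embeds the given representation and projects it back to the nearest point on the manifold, which is exactly the projection-based retraction used in the wrapper above.

using Manifolds, ManifoldsBase
M = Sphere(2)
y = [2.0, 0.0, 0.0]                    # a vector in the embedding ℝ³, not on the sphere
p = similar(y)
ManifoldsBase.embed_project!(M, p, y)  # p ≈ [1.0, 0.0, 0.0], the closest point on S²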

> This is what I had:

Using my wrapper with product and power manifolds would be a better solution 🙂.
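As a sketch of that suggestion (not from the gist; the sizes and names below are illustrative, and it assumes embed_project! is available for the chosen power-manifold representation), several sphere-valued variables can be handled jointly through a single PowerManifold instead of a hand-rolled vector of manifolds:

using Manifolds, Optim, ManifoldsBase, LinearAlgebra

N = PowerManifold(Sphere(2), 5)  # five points on S², stored together as a 3×5 matrix
x0 = reduce(hcat, [normalize(randn(3)) for _ in 1:5])
# define f and g! on N, then reuse the wrapper exactly as above:
# optimize(f, g!, x0, ConjugateGradient(; manifold=ManifoldWrapper(N)))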


dehann commented Mar 17, 2023

Okay thanks. I worked a CI unit test into IncrementalInference.jl that runs your example above (for future consistency checks), as well as this SpecialEuclidean(2) version with ManifoldDiff. We'll also do a SpecialEuclidean(3) and "partial measurement" version of this test:
(note: this test fails if the third coordinate of q, the rotation angle, is larger than about 1.5 radians)

using Manifolds, Optim, ManifoldsBase
using ManifoldDiff
# this is a generic part

using Test

##

"""
    ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
    
Adapts Manifolds.jl manifolds for use in Optim.jl
"""
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
    M::TM
end

function Optim.retract!(M::ManifoldWrapper, x)
    ManifoldsBase.embed_project!(M.M, x, x)
    return x
end

function Optim.project_tangent!(M::ManifoldWrapper, g, x)
    ManifoldsBase.embed_project!(M.M, g, x, g)
    return g
end

r_backend = ManifoldDiff.TangentDiffBackend(
    ManifoldDiff.FiniteDifferencesBackend()
)

##

M = Manifolds.SpecialEuclidean(2)
e0 = ArrayPartition([0,0.], [1 0; 0 1.])

x0 = deepcopy(e0)
q  = exp(M,e0,hat(M,e0,randn(3)))

f(p) = distance(M, p, q)^2



## finitediff gradient (non-manual)
function g_FD!(X,p)
  X .= ManifoldDiff.gradient(M, f, p, r_backend)
  X
end

## sanity check gradients

X = hat(M, e0, zeros(3))
g_FD!(X, q)
# gradient at the optimal point should be zero
@test isapprox(0, sum(abs.(X[:])); atol=1e-8 )

# gradient away from the optimal point should be non-zero
g_FD!(X, e0)
@test 0.01 < sum(abs.(X[:]))

## do optimization
x0 = deepcopy(e0)
sol = optimize(f, g_FD!, x0, ConjugateGradient(; manifold=ManifoldWrapper(M)))

sol.minimizer
@test isapprox( f(sol.minimizer), 0; atol=1e-8 )
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)
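A hedged sketch of how the random q could be kept inside the region where the note above applies (the clamp bound 1.4 is an arbitrary choice below the mentioned 1.5 rad limit; this is not part of the CI test):

c = randn(3)
c[3] = clamp(c[3], -1.4, 1.4)  # keep the rotation coordinate below the ~1.5 rad limit noted above
q = exp(M, e0, hat(M, e0, c))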
