Last active
June 1, 2022 10:36
-
-
Save scturtle/fec3a544d3fc0c1aa2c82a328698c713 to your computer and use it in GitHub Desktop.
Levenberg-Marquardt algorithm with Eigen.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <glog/logging.h> | |
#include <unsupported/Eigen/AutoDiff> | |
#include <unsupported/Eigen/LevenbergMarquardt> | |
template <typename T> | |
Eigen::Matrix<T, Eigen::Dynamic, 1> | |
func(const Eigen::Matrix<T, Eigen::Dynamic, 1> &xs, | |
const Eigen::Matrix<T, 3, 1> &x) | |
{ | |
auto xsa = xs.array(); | |
return x[0] * (-(xsa - x[1]).pow(2) / (2 * x[2] * x[2])).exp(); | |
} | |
/// Residual functor for fitting the three-parameter model `func` to the
/// samples (xs, ys); it is handed to Eigen::LevenbergMarquardt in `main`.
///
/// Only references to the sample vectors are stored, so both vectors must
/// outlive the functor.
template <typename T> struct Functor : public Eigen::DenseFunctor<T>
{
    const Eigen::VectorXf &xs;
    const Eigen::VectorXf &ys;

    /// Registers 3 inputs (the model parameters) and one residual per row of
    /// `xs` with the DenseFunctor base.
    Functor(const Eigen::VectorXf &xs, const Eigen::VectorXf &ys)
        : Eigen::DenseFunctor<T>(3, xs.rows()), xs(xs), ys(ys)
    {}

    /// Writes the residuals `func(xs, x) - ys` into `fvec`.
    ///
    /// Templated on the scalar type so that the same code can be evaluated
    /// with plain floats and with AutoDiff scalars (see `df`).
    /// @return always 0.
    template <typename T1>
    int operator()(const Eigen::Matrix<T1, Eigen::Dynamic, 1> &x,
                   Eigen::Matrix<T1, Eigen::Dynamic, 1> &fvec) const
    {
        fvec = func<T1>(xs, x) - ys;
        return 0;
    }

    /// Fills `jac` with the Jacobian of the residuals with respect to `x`,
    /// obtained by evaluating operator() on AutoDiff scalars.
    ///
    /// `jac` is written row by row and is not resized here, so it has to
    /// arrive with values() rows and inputs() columns.
    /// @return always 0.
    int df(const Eigen::Vector3f &x, Eigen::MatrixXf &jac) const
    {
        using Dual = Eigen::AutoDiffScalar<Eigen::VectorXf>;
        using DualVector = Eigen::Matrix<Dual, Eigen::Dynamic, 1>;

        const int n_params = this->inputs();
        const int n_residuals = this->values();

        // Output residuals: give each derivative vector its final size.
        DualVector residuals(n_residuals);
        for (int r = 0; r < n_residuals; ++r)
        {
            residuals[r].derivatives().resize(n_params);
        }

        // Input parameters: seed parameter p with the p-th unit vector so the
        // propagated derivatives are d(residual)/d(parameter p).
        DualVector params = x.template cast<Dual>();
        for (int p = 0; p < n_params; ++p)
        {
            params[p].derivatives() = Eigen::Vector3f::Unit(n_params, p);
        }

        operator()(params, residuals);

        // The derivative vector of residual r is row r of the Jacobian.
        for (int r = 0; r < n_residuals; ++r)
        {
            jac.row(r) = residuals[r].derivatives();
        }
        return 0;
    }
};
int main() | |
{ | |
Eigen::Vector3f x; | |
x << 1, 2, 3; | |
Eigen::VectorXf xs = Eigen::ArrayXf::LinSpaced(11, -3, 3); | |
Eigen::VectorXf ys = func<float>(xs, x); | |
Functor<float> f(xs, ys); | |
Eigen::LevenbergMarquardt<Functor<float>> lm(f); | |
Eigen::VectorXf p0(3); | |
p0 << 1.1, 2.2, 3.3; | |
LOG(INFO) << "status: " << lm.minimize(p0); | |
LOG(INFO) << "result: " << p0.transpose(); | |
LOG(INFO) << "diff: " << (p0 - x).transpose(); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <unsupported/Eigen/LevenbergMarquardt> | |
#include <unsupported/Eigen/NumericalDiff> | |
#include <glog/logging.h> | |
/// Evaluates the quadratic x[0]*s^2 + x[1]*s + x[2] at every sample s in `xs`.
///
/// @param xs  sample positions.
/// @param x   polynomial coefficients, highest power first.
/// @return    one polynomial value per entry of `xs`.
Eigen::VectorXf func(const Eigen::VectorXf &xs, const Eigen::Vector3f &x)
{
    const float quadratic = x[0];
    const float linear = x[1];
    const float constant = x[2];

    const auto s = xs.array();
    return quadratic * s * s + linear * s + constant;
}
/// Residual functor for fitting the quadratic `func` to the samples (xs, ys),
/// with a hand-written Jacobian.
///
/// Only references to the sample vectors are stored, so both vectors must
/// outlive the functor.
struct Functor : public Eigen::DenseFunctor<float>
{
    const Eigen::VectorXf &xs;
    const Eigen::VectorXf &ys;

    /// Registers 3 inputs (the coefficients) and one residual per row of `xs`
    /// with the DenseFunctor base.
    Functor(const Eigen::VectorXf &xs, const Eigen::VectorXf &ys)
        : Eigen::DenseFunctor<float>(3, xs.rows()), xs(xs), ys(ys)
    {}

    /// Writes the residuals `func(xs, x) - ys` into `fvec`.
    /// @return always 0.
    int operator()(const Eigen::Vector3f &x, Eigen::VectorXf &fvec) const
    {
        fvec = func(xs, x) - ys;
        return 0;
    }

    /// Fills `fjac` with the Jacobian of the residuals.
    ///
    /// The model is linear in its coefficients, so the columns are s^2, s and
    /// 1 and the current coefficient vector is not needed. `fjac` is written
    /// column by column and is not resized here.
    /// @return always 0.
    int df(const Eigen::Vector3f &, Eigen::MatrixXf &fjac) const
    {
        const auto samples = xs.array();
        fjac.col(0) = samples.pow(2);                   // d/dx[0]
        fjac.col(1) = xs;                               // d/dx[1]
        fjac.col(2) = Eigen::VectorXf::Ones(values());  // d/dx[2]
        return 0;
    }
};
int main() | |
{ | |
Eigen::Vector3f x; | |
x << 1, 2, 3; | |
Eigen::VectorXf xs = Eigen::ArrayXf::LinSpaced(11, -3, 3); | |
Eigen::VectorXf ys = func(xs, x); | |
Functor f(xs, ys); | |
#if 0 | |
Eigen::NumericalDiff<Functor> numdiff(f); | |
Eigen::LevenbergMarquardt<Eigen::NumericalDiff<Functor>> lm(numdiff); | |
#else | |
Eigen::LevenbergMarquardt<Functor> lm(f); | |
#endif | |
Eigen::VectorXf p0(3); | |
p0 << 1.1, 2.2, 3.3; | |
LOG(INFO) << "status: " << lm.minimize(p0); | |
LOG(INFO) << "result: " << p0.transpose(); | |
LOG(INFO) << "diff: " << (p0 - x).transpose(); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <glog/logging.h> | |
#include <unsupported/Eigen/LevenbergMarquardt> | |
#include <unsupported/Eigen/NumericalDiff> | |
/// Evaluates a Gaussian with a constant offset at every sample position:
///   x[0] * exp(-(s - x[1])^2 / (2 * x[2]^2)) + x[3].
///
/// @param xs  sample positions.
/// @param x   model parameters {amplitude, centre, width, offset}; entries
///            0..3 are read.
/// @return    one model value per entry of `xs`.
Eigen::VectorXf func(const Eigen::VectorXf &xs, const Eigen::VectorXf &x)
{
    const float amplitude = x[0];
    const float centre = x[1];
    const float width = x[2];
    const float offset = x[3];

    const Eigen::ArrayXf squared_offset = (xs.array() - centre).pow(2);
    const Eigen::ArrayXf exponent = -squared_offset / (2 * width * width);
    return amplitude * exponent.exp() + offset;
}
/// Residual functor for fitting the four-parameter model
///   x[0] * exp(-(s - x[1])^2 / (2 * x[2]^2)) + x[3]
/// to the samples (xs, ys), with a hand-written Jacobian.
///
/// Only references to the sample vectors are stored, so both vectors must
/// outlive the functor.
struct Functor : public Eigen::DenseFunctor<float>
{
    const Eigen::VectorXf &xs;
    const Eigen::VectorXf &ys;

    /// Registers 4 inputs (the model parameters) and one residual per row of
    /// `xs` with the DenseFunctor base.
    Functor(const Eigen::VectorXf &xs, const Eigen::VectorXf &ys)
        : DenseFunctor(/*inputs=*/4, /*values=*/xs.rows()), xs(xs), ys(ys)
    {}

    /// Writes the residuals `func(xs, x) - ys` into `fvec`.
    /// @return always 0.
    int operator()(const Eigen::VectorXf &x, Eigen::VectorXf &fvec) const
    {
        fvec = func(xs, x) - ys;
        return 0;
    }

    /// Fills `fjac` with the analytic Jacobian of the residuals at `x`.
    ///
    /// Columns are the partial derivatives with respect to amplitude, centre,
    /// width and offset, in that order. `fjac` is written column by column
    /// and is not resized here. A width of exactly 0 still divides by zero,
    /// as it does in the model itself.
    /// @return always 0.
    int df(const Eigen::VectorXf &x, Eigen::MatrixXf &fjac) const
    {
        const float amplitude = x[0];
        const float centre = x[1];
        const float width = x[2];
        const float width_sq = width * width;

        // Evaluate the exponential directly. Recovering it as
        // (model - offset) / amplitude would give 0/0 = NaN when the
        // amplitude is 0, and loses precision by adding and then subtracting
        // the offset in single precision.
        const Eigen::ArrayXf delta = xs.array() - centre;
        const Eigen::ArrayXf delta_sq = delta.square();
        const Eigen::ArrayXf bell = (-delta_sq / (2 * width_sq)).exp();
        const Eigen::ArrayXf scaled_bell = amplitude * bell;

        fjac.col(0) = bell.matrix();
        fjac.col(1) = (scaled_bell * delta / width_sq).matrix();
        fjac.col(2) = (scaled_bell * delta_sq / (width_sq * width)).matrix();
        fjac.col(3) = Eigen::VectorXf::Ones(values());
        return 0;
    }
};
int main() | |
{ | |
Eigen::Vector4f x; | |
x << 1, 0.1, 1, 0.5; | |
Eigen::VectorXf xs = Eigen::ArrayXf::LinSpaced(1001, -3, 3); | |
Eigen::VectorXf ys = func(xs, x); | |
Functor f(xs, ys); | |
#if 0 | |
Eigen::NumericalDiff<Functor> numdiff(f); | |
Eigen::LevenbergMarquardt<Eigen::NumericalDiff<Functor>> lm(numdiff); | |
#else | |
Eigen::LevenbergMarquardt<Functor> lm(f); | |
#endif | |
Eigen::VectorXf p0(4); | |
p0 << 1.1, 0.15, 1.3, 0.2; | |
LOG(INFO) << "status: " << lm.minimize(p0); | |
LOG(INFO) << "result: " << p0.transpose(); | |
LOG(INFO) << "diff: " << (p0 - x).transpose(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment