Skip to content

Instantly share code, notes, and snippets.

@bullno1
Created November 10, 2011 17:11
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bullno1/1355442 to your computer and use it in GitHub Desktop.
Save bullno1/1355442 to your computer and use it in GitHub Desktop.
Dual numbers and automatic differentiation in C++
#include <iostream>
/// A dual number  real + dual * eps, where eps is a formal infinitesimal
/// with eps * eps == 0 (the arithmetic operators in this file drop the
/// dual * dual term accordingly).
///
/// Seeding an input with dual part 1 and evaluating a function written
/// generically over Scalar yields the function value in the real part and
/// its derivative in the dual part (forward-mode automatic differentiation).
///
/// @tparam Scalar underlying numeric type; must be copyable and
///         value-initializable (Scalar()) for the default dual part.
template<typename Scalar>
class DualNumber
{
    Scalar mReal;   // coefficient of 1
    Scalar mDual;   // coefficient of eps

public:
    /// Builds realPart + dualPart * eps.
    /// Deliberately not explicit: a plain Scalar converts implicitly into a
    /// constant dual number whose dual part is Scalar().
    DualNumber(const Scalar& realPart, const Scalar& dualPart = Scalar())
        : mReal(realPart), mDual(dualPart)
    {}

    /// @return the real (value) component.
    Scalar getReal() const { return mReal; }

    /// @return the dual (infinitesimal / derivative) component.
    Scalar getDual() const { return mDual; }
};
/// Factory that deduces Scalar from its arguments.
/// @param realPart value of the real component.
/// @param dualPart coefficient of eps; defaults to Scalar() (a constant).
/// @return the dual number realPart + dualPart * eps.
template<typename Scalar>
inline DualNumber<Scalar> makeDualNumber(const Scalar& realPart, const Scalar& dualPart = Scalar())
{
    return { realPart, dualPart };
}
/// Component-wise sum:
///   (a + b eps) + (c + d eps) = (a + c) + (b + d) eps
/// i.e. the sum rule (f + g)' = f' + g'.
template<typename Scalar>
inline DualNumber<Scalar> operator+(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs)
{
    return makeDualNumber(lhs.getReal() + rhs.getReal(),
                          lhs.getDual() + rhs.getDual());
}
/// Product of dual numbers:
///   (a + b eps)(c + d eps) = a c + (a d + c b) eps
/// The b d eps^2 term vanishes because eps * eps == 0; the dual part is the
/// product rule (f g)' = f g' + g f'.
template<typename Scalar>
inline DualNumber<Scalar> operator*(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs)
{
    const Scalar a = lhs.getReal();
    const Scalar b = lhs.getDual();
    const Scalar c = rhs.getReal();
    const Scalar d = rhs.getDual();
    return makeDualNumber(a * c, a * d + c * b);
}
/// Component-wise difference:
///   (a + b eps) - (c + d eps) = (a - c) + (b - d) eps
/// i.e. the difference rule (f - g)' = f' - g'.
template<typename Scalar>
inline DualNumber<Scalar> operator-(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs)
{
    Scalar r1 = lhs.getReal();
    Scalar d1 = lhs.getDual();
    Scalar r2 = rhs.getReal();
    Scalar d2 = rhs.getDual();
    return makeDualNumber(r1 - r2, d1 - d2);
}
/// Quotient of dual numbers (quotient rule):
///   (a + b eps) / (c + d eps) = a / c + ((b c - a d) / c^2) eps
/// No guard is made against a zero real part in the divisor; the result is
/// whatever Scalar's own division produces in that case.
template<typename Scalar>
inline DualNumber<Scalar> operator/(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs)
{
    const Scalar a = lhs.getReal();
    const Scalar b = lhs.getDual();
    const Scalar c = rhs.getReal();
    const Scalar d = rhs.getDual();
    return makeDualNumber(a / c, (b * c - a * d) / (c * c));
}
/// Unary negation: -(a + b eps) = (-a) + (-b) eps.
template<typename Scalar>
inline DualNumber<Scalar> operator-(const DualNumber<Scalar>& operand)
{
    return makeDualNumber(- operand.getReal(), - operand.getDual());
}
/// Sample function f(x) = x^2 / (x + 1), written generically over Scalar.
/// With a plain number it returns f(x). With a DualNumber seeded as
/// (x, 1), the dual part of the result is f'(x) = (x^2 + 2x) / (x + 1)^2.
template<typename Scalar>
Scalar testFunc(Scalar x)
{
    const Scalar one(1);
    const Scalar numerator = x * x;
    const Scalar denominator = x + one;
    return numerator / denominator;
}
int main()
{
std::cout<< testFunc<DualNumber<float>>(makeDualNumber(2.f, 1.f)).getDual() << std::endl;
std::cin.get();
}
@carlosjoserg
Copy link

Hi, I found this useful.

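I found a missing semicolon at the end of the unary `operator-` return statement. Also, in `main` the `>>` after `float` should be written `> >` for compilers older than C++11, which parse `>>` as the right-shift operator.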

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment