Dual numbers and automatic differentiation in C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
/// A dual number a + b*eps, stored as a real part and a dual part (the
/// coefficient of eps). The arithmetic operators defined for this type drop
/// the eps^2 term, so seeding the dual part with 1 and evaluating a function
/// on the value carries the function's derivative in the dual part.
template<typename Scalar>
class DualNumber
{
    Scalar value_;       // real part
    Scalar derivative_;  // dual part (coefficient of eps)

public:
    /// Builds realPart + dualPart*eps. The dual part defaults to a
    /// value-initialized Scalar (zero for arithmetic types).
    inline DualNumber(const Scalar& realPart, const Scalar& dualPart = Scalar())
        : value_(realPart), derivative_(dualPart)
    {
    }

    /// Returns a copy of the real part.
    inline Scalar getReal() const { return value_; }

    /// Returns a copy of the dual part.
    inline Scalar getDual() const { return derivative_; }
};
template<typename Scalar> | |
inline DualNumber<Scalar> makeDualNumber(const Scalar& realPart, const Scalar& dualPart = Scalar()) | |
{ | |
return DualNumber<Scalar>(realPart, dualPart); | |
} | |
template<typename Scalar> | |
inline DualNumber<Scalar> operator+(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs) | |
{ | |
Scalar r1 = lhs.getReal(); | |
Scalar d1 = lhs.getDual(); | |
Scalar r2 = rhs.getReal(); | |
Scalar d2 = rhs.getDual(); | |
return makeDualNumber(r1 + r2, d1 + d2); | |
} | |
template<typename Scalar> | |
inline DualNumber<Scalar> operator*(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs) | |
{ | |
Scalar r1 = lhs.getReal(); | |
Scalar d1 = lhs.getDual(); | |
Scalar r2 = rhs.getReal(); | |
Scalar d2 = rhs.getDual(); | |
return makeDualNumber(r1 * r2, r1 * d2 + r2 * d1); | |
} | |
template<typename Scalar> | |
inline DualNumber<Scalar> operator-(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs) | |
{ | |
Scalar r1 = lhs.getReal(); | |
Scalar d1 = lhs.getDual(); | |
Scalar r2 = rhs.getReal(); | |
Scalar d2 = rhs.getDual(); | |
return makeDualNumber(r1 * r2, r1 * d2 + r2 * d1); | |
} | |
template<typename Scalar> | |
inline DualNumber<Scalar> operator/(const DualNumber<Scalar>& lhs, const DualNumber<Scalar>& rhs) | |
{ | |
Scalar r1 = lhs.getReal(); | |
Scalar d1 = lhs.getDual(); | |
Scalar r2 = rhs.getReal(); | |
Scalar d2 = rhs.getDual(); | |
return makeDualNumber(r1 / r2, (d1 * r2 - r1 * d2) / (r2 * r2)); | |
} | |
template<typename Scalar> | |
inline DualNumber<Scalar> operator-(const DualNumber<Scalar>& operand) | |
{ | |
return makeDualNumber(- operand.getReal(), - operand.getDual()) | |
} | |
/// Example function f(x) = x^2 / (x + 1), written generically so it can be
/// evaluated on plain arithmetic types or on DualNumber (to obtain f'(x)
/// in the dual part of the result).
template<typename Scalar>
Scalar testFunc(Scalar x)
{
    const Scalar numerator = x * x;
    const Scalar denominator = x + Scalar(1);
    return numerator / denominator;
}
int main() | |
{ | |
std::cout<< testFunc<DualNumber<float>>(makeDualNumber(2.f, 1.f)).getDual() << std::endl; | |
std::cin.get(); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
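Hi, I found this useful. I found a missing semicolon here, and here there is a `>>` after `float` that should be `> >` (this is only required by pre-C++11 compilers, which parse `>>` in a nested template argument list as the right-shift operator).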