Last active
January 16, 2018 09:01
-
-
Save 2bbb/37152a5fa323308965c2223bf696bc1b to your computer and use it in GitHub Desktop.
autodiff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cmath>
#include <utility>
namespace bbb {
    namespace autodiff {
        /**
         * Dual number for forward-mode automatic differentiation.
         *
         * Stores a value `x` together with a derivative `dx` and propagates both
         * through arithmetic using the sum, product and quotient rules.
         *
         * Conventions used by this type:
         *  - `make(v)` / `dual(v)` / assignment from a raw `type` build a *variable*
         *    (derivative seeded with `type(1.0)`).
         *  - `scalar(v)` and a raw `type` operand on either side of a binary or
         *    compound arithmetic operator are *constants* (derivative zero).
         *
         * Note for nested duals (`dual<dual<T>>`): the default seed `type(1.0)` is
         * itself built through this constructor and therefore carries its own
         * derivative of 1. Pass an explicit constant seed (e.g. `inner::scalar(1.0)`)
         * when composing higher-order derivatives.
         */
        template <typename type>
        struct dual {
            /// Builds a differentiation variable: value `x`, derivative `type(1.0)`.
            static constexpr dual make(type x) { return dual(x); }
            /// Builds a constant: value `x`, value-initialised (zero) derivative.
            static constexpr dual scalar(type x) { return dual(x, type{}); }

            /// Intentionally implicit so raw values convert to a seeded variable.
            constexpr dual(type x, type dx = type(1.0)) noexcept
            : x(x), dx(dx) {}
            /// Default construction yields value-initialised `x` and `dx`.
            constexpr dual() noexcept = default;
            // Copy/move construction and assignment are implicitly generated (Rule of Zero).

            /// Assigning a raw value re-seeds this object as a variable (see class notes).
            dual &operator=(const type &r) noexcept { return *this = dual(r); }
            dual &operator=(type &&r) noexcept { return *this = dual(std::move(r)); }

            // ---- addition: (u + v)' = u' + v' ----
            dual &operator+=(const dual &r) noexcept {
                x += r.x;
                dx += r.dx;
                return *this;
            }
            dual &operator+=(const type &v) noexcept {
                x += v;
                return *this;
            }
            constexpr dual operator+(const dual &r) const noexcept { return dual(x + r.x, dx + r.dx); }
            constexpr dual operator+(const type &v) const noexcept { return dual(x + v, dx); }
            friend constexpr dual operator+(const type &v, const dual &r) noexcept { return scalar(v) + r; }
            friend constexpr dual operator+(type &&v, const dual &r) noexcept { return scalar(std::move(v)) + r; }

            // ---- subtraction: (u - v)' = u' - v' ----
            dual &operator-=(const dual &r) noexcept {
                x -= r.x;
                dx -= r.dx;
                return *this;
            }
            dual &operator-=(const type &v) noexcept {
                x -= v;
                return *this;
            }
            constexpr dual operator-(const dual &r) const noexcept { return dual(x - r.x, dx - r.dx); }
            constexpr dual operator-(const type &v) const noexcept { return dual(x - v, dx); }
            friend constexpr dual operator-(const type &v, const dual &r) noexcept { return scalar(v) - r; }
            friend constexpr dual operator-(type &&v, const dual &r) noexcept { return scalar(std::move(v)) - r; }

            // ---- multiplication: (u * v)' = u' v + u v' ----
            dual &operator*=(const dual &r) noexcept {
                // dx must be updated first: it needs the value of x before scaling.
                dx = dx * r.x + x * r.dx;
                x *= r.x;
                return *this;
            }
            dual &operator*=(const type &v) noexcept {
                x *= v;
                dx *= v;
                return *this;
            }
            constexpr dual operator*(const dual &r) const noexcept { return dual(x * r.x, dx * r.x + x * r.dx); }
            constexpr dual operator*(const type &v) const noexcept { return dual(x * v, dx * v); }
            friend constexpr dual operator*(const type &v, const dual &r) noexcept { return scalar(v) * r; }
            friend constexpr dual operator*(type &&v, const dual &r) noexcept { return scalar(std::move(v)) * r; }

            // ---- division: (u / v)' = (u' v - u v') / v^2 ----
            dual &operator/=(const dual &r) noexcept {
                // dx must be updated first: it needs the value of x before dividing.
                dx = (dx * r.x - x * r.dx) / (r.x * r.x);
                x /= r.x;
                return *this;
            }
            dual &operator/=(const type &v) noexcept {
                x /= v;
                dx /= v;
                return *this;
            }
            constexpr dual operator/(const dual &r) const noexcept {
                return dual(x / r.x, (dx * r.x - x * r.dx) / (r.x * r.x));
            }
            constexpr dual operator/(const type &v) const noexcept { return dual(x / v, dx / v); }
            friend constexpr dual operator/(const type &v, const dual &r) noexcept { return scalar(v) / r; }
            friend constexpr dual operator/(type &&v, const dual &r) noexcept { return scalar(std::move(v)) / r; }

            type x{};   ///< value
            type dx{};  ///< derivative carried alongside the value
        };

        // Elementary functions on duals (chain rule). They live next to `dual` so
        // that the unqualified inner calls (`using std::f; f(v.x)`) can find them
        // through argument-dependent lookup when `type` is itself a dual.

        /// sin: (sin u)' = cos(u) u'
        template <typename type>
        constexpr dual<type> sin(const dual<type> &v) {
            using std::sin;
            using std::cos;
            return dual<type>(sin(v.x), cos(v.x) * v.dx);
        }

        /// cos: (cos u)' = -sin(u) u'
        template <typename type>
        constexpr dual<type> cos(const dual<type> &v) {
            using std::sin;
            using std::cos;
            return dual<type>(cos(v.x), -sin(v.x) * v.dx);
        }

        /// tan: (tan u)' = u' / cos(u)^2
        template <typename type>
        constexpr dual<type> tan(const dual<type> &v) {
            using std::tan;
            using std::cos;
            const type c = cos(v.x);
            return dual<type>(tan(v.x), v.dx / (c * c));
        }

        /// exp: (exp u)' = exp(u) u'
        template <typename type>
        constexpr dual<type> exp(const dual<type> &v) {
            using std::exp;
            const type e = exp(v.x);
            return dual<type>(e, e * v.dx);
        }

        /// log: (log u)' = u' / u
        template <typename type>
        constexpr dual<type> log(const dual<type> &v) {
            using std::log;
            return dual<type>(log(v.x), v.dx / v.x);
        }
    }

    namespace math {
        // Dual overloads remain reachable as bbb::math::f (and bbb::f).
        using autodiff::sin;
        using autodiff::cos;
        using autodiff::tan;
        using autodiff::exp;
        using autodiff::log;

        // Pass-through overloads for non-dual arguments: forward to the
        // std:: function (or an ADL-found overload) for `type`.
        template <typename type>
        constexpr type sin(const type &v) {
            using std::sin;
            return sin(v);
        }

        template <typename type>
        constexpr type cos(const type &v) {
            using std::cos;
            return cos(v);
        }

        template <typename type>
        constexpr type tan(const type &v) {
            using std::tan;
            return tan(v);
        }

        template <typename type>
        constexpr type exp(const type &v) {
            using std::exp;
            return exp(v);
        }

        template <typename type>
        constexpr type log(const type &v) {
            using std::log;
            return log(v);
        }
    }

    using namespace math;
}
#include <iostream> | |
int main(int argc, char *argv[]) { | |
using ddual = bbb::autodiff::dual<double>; | |
using dddual = bbb::autodiff::dual<ddual>; | |
auto x = ddual::make(4); | |
auto y = x * x; | |
std::cout << y.x << ", " << y.dx << std::endl; | |
auto z = bbb::exp(x); | |
std::cout << x.x << ", " << x.dx << std::endl; | |
std::cout << z.x << ", " << z.dx << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment