Skip to content

Instantly share code, notes, and snippets.

@2bbb
Last active January 16, 2018 09:01
Show Gist options
  • Save 2bbb/37152a5fa323308965c2223bf696bc1b to your computer and use it in GitHub Desktop.
Save 2bbb/37152a5fa323308965c2223bf696bc1b to your computer and use it in GitHub Desktop.
autodiff
#include <cmath>
#include <type_traits>
#include <utility>
namespace bbb {
namespace autodiff {

/// Forward-mode automatic-differentiation number.
///
/// `x` holds the value and `dx` the derivative carried alongside it; the
/// arithmetic operators propagate `dx` with the sum, difference, product and
/// quotient rules.
///
/// A plain `type` operand mixed into arithmetic (`d * 2.0`, `1.0 / d`,
/// `d += 3.0`) is treated as a constant, i.e. with derivative zero. The
/// converting constructor, `make` and `operator=(type)` instead seed `dx` with
/// `type(1.0)`, creating an independent variable.
///
/// Nesting (`dual<dual<T>>`) yields second derivatives when the inner parts are
/// seeded explicitly, e.g.
/// `dual<dual<double>>(dual<double>(a, 1.0), dual<double>::scalar(1.0))`.
/// Note that `make` on a nested type seeds `dx` with `type(1.0)`, which for
/// `dual<double>` is `(1.0, 1.0)`, not `(1.0, 0.0)`.
template <typename type>
struct dual {
    /// Independent variable: value `x`, derivative `type(1.0)`.
    static constexpr dual make(type x) { return dual(x); }
    /// Constant: value `x`, derivative `type()` (zero, also for nested duals).
    static constexpr dual scalar(type x) { return dual(x, type()); }

    constexpr dual(type x, type dx = type(1.0)) noexcept
        : x(x), dx(dx) {}
    constexpr dual() noexcept = default;
    constexpr dual(const dual &) noexcept = default;
    constexpr dual(dual &&) noexcept = default;
    dual &operator=(const dual &) noexcept = default;
    dual &operator=(dual &&) noexcept = default;
    // Assigning a plain value re-seeds *this as an independent variable,
    // exactly like the converting constructor.
    dual &operator=(const type &r) noexcept { return operator=(dual(r)); }
    dual &operator=(type &&r) noexcept { return operator=(dual(std::move(r))); }

    /// Negation: (-u)' = -u'.
    constexpr dual operator-() const noexcept { return dual(-x, -dx); }

    // ---- addition: (u + v)' = u' + v'; a plain value adds nothing to dx.
    dual &operator+=(const dual &r) noexcept {
        x += r.x;
        dx += r.dx;
        return *this;
    }
    dual &operator+=(const type &r) noexcept {
        x += r;
        return *this;
    }
    constexpr dual operator+(const dual &r) const noexcept { return dual(x + r.x, dx + r.dx); }
    constexpr dual operator+(const type &r) const noexcept { return dual(x + r, dx); }
    friend constexpr dual operator+(const type &v, const dual &r) noexcept { return dual(v + r.x, r.dx); }
    friend constexpr dual operator+(type &&v, const dual &r) noexcept { return dual(v + r.x, r.dx); }

    // ---- subtraction: (u - v)' = u' - v'.
    dual &operator-=(const dual &r) noexcept {
        x -= r.x;
        dx -= r.dx;
        return *this;
    }
    dual &operator-=(const type &r) noexcept {
        x -= r;
        return *this;
    }
    constexpr dual operator-(const dual &r) const noexcept { return dual(x - r.x, dx - r.dx); }
    constexpr dual operator-(const type &r) const noexcept { return dual(x - r, dx); }
    friend constexpr dual operator-(const type &v, const dual &r) noexcept { return dual(v - r.x, -r.dx); }
    friend constexpr dual operator-(type &&v, const dual &r) noexcept { return dual(v - r.x, -r.dx); }

    // ---- multiplication: (u v)' = u' v + u v'; a constant c gives (c u)' = c u'.
    dual &operator*=(const dual &r) noexcept {
        // dx first: it needs the old x (also correct when r aliases *this).
        dx = dx * r.x + x * r.dx;
        x *= r.x;
        return *this;
    }
    dual &operator*=(const type &r) noexcept {
        const type k = r;  // copy first: r may alias x or dx
        x *= k;
        dx *= k;
        return *this;
    }
    constexpr dual operator*(const dual &r) const noexcept { return dual(x * r.x, dx * r.x + x * r.dx); }
    constexpr dual operator*(const type &r) const noexcept { return dual(x * r, dx * r); }
    friend constexpr dual operator*(const type &v, const dual &r) noexcept { return dual(v * r.x, v * r.dx); }
    friend constexpr dual operator*(type &&v, const dual &r) noexcept { return dual(v * r.x, v * r.dx); }

    // ---- division: (u / v)' = (u' v - u v') / v^2; a constant c gives
    //      (u / c)' = u' / c and (c / v)' = -c v' / v^2.
    dual &operator/=(const dual &r) noexcept {
        // dx first: it needs the old x (also correct when r aliases *this).
        dx = (dx * r.x - x * r.dx) / (r.x * r.x);
        x /= r.x;
        return *this;
    }
    dual &operator/=(const type &r) noexcept {
        const type k = r;  // copy first: r may alias x or dx
        x /= k;
        dx /= k;
        return *this;
    }
    constexpr dual operator/(const dual &r) const noexcept { return dual(x / r.x, (dx * r.x - x * r.dx) / (r.x * r.x)); }
    constexpr dual operator/(const type &r) const noexcept { return dual(x / r, dx / r); }
    friend constexpr dual operator/(const type &v, const dual &r) noexcept { return dual(v / r.x, -(v * r.dx) / (r.x * r.x)); }
    friend constexpr dual operator/(type &&v, const dual &r) noexcept { return dual(v / r.x, -(v * r.dx) / (r.x * r.x)); }

    type x{};   ///< value
    type dx{};  ///< derivative
};

} // namespace autodiff

namespace math {
namespace detail {
/// Result type of the plain-value overloads below: integral arguments yield
/// double (what the <cmath> functions return for them) instead of being
/// truncated back to the integral type; any other type is returned as itself.
template <typename type>
using scalar_result = typename std::conditional<std::is_integral<type>::value, double, type>::type;

/// exp is its own derivative, so exp(u) = (e, e * u') with e = exp(u.x).
/// Taking e as a parameter evaluates exp only once.
template <typename type>
constexpr autodiff::dual<type> exp_result(const type &e, const type &du) {
    return autodiff::dual<type>(e, e * du);
}
} // namespace detail

// ---- plain-value overloads: forward to <cmath> (or to an overload found by ADL).
template <typename type>
constexpr detail::scalar_result<type> sin(const type &v) {
    using std::sin;
    return sin(v);
}
template <typename type>
constexpr detail::scalar_result<type> cos(const type &v) {
    using std::cos;
    return cos(v);
}
template <typename type>
constexpr detail::scalar_result<type> tan(const type &v) {
    using std::tan;
    return tan(v);
}
template <typename type>
constexpr detail::scalar_result<type> exp(const type &v) {
    using std::exp;
    return exp(v);
}
template <typename type>
constexpr detail::scalar_result<type> log(const type &v) {
    using std::log;
    return log(v);
}

// ---- dual-number overloads. They are declared up front so that the unqualified
// calls in each body see every math:: overload; when `type` is itself a dual
// number this resolves to these overloads, otherwise to the plain-value ones.
template <typename type> constexpr autodiff::dual<type> sin(const autodiff::dual<type> &v);
template <typename type> constexpr autodiff::dual<type> cos(const autodiff::dual<type> &v);
template <typename type> constexpr autodiff::dual<type> tan(const autodiff::dual<type> &v);
template <typename type> constexpr autodiff::dual<type> exp(const autodiff::dual<type> &v);
template <typename type> constexpr autodiff::dual<type> log(const autodiff::dual<type> &v);

/// sin(u) = (sin u.x, cos u.x * u').
template <typename type>
constexpr autodiff::dual<type> sin(const autodiff::dual<type> &v) {
    return autodiff::dual<type>(sin(v.x), cos(v.x) * v.dx);
}
/// cos(u) = (cos u.x, -sin u.x * u').
template <typename type>
constexpr autodiff::dual<type> cos(const autodiff::dual<type> &v) {
    return autodiff::dual<type>(cos(v.x), -sin(v.x) * v.dx);
}
/// tan(u) = (tan u.x, u' / cos^2 u.x).
template <typename type>
constexpr autodiff::dual<type> tan(const autodiff::dual<type> &v) {
    return autodiff::dual<type>(tan(v.x), v.dx / (cos(v.x) * cos(v.x)));
}
/// exp(u) = (exp u.x, exp u.x * u').
template <typename type>
constexpr autodiff::dual<type> exp(const autodiff::dual<type> &v) {
    return detail::exp_result(exp(v.x), v.dx);
}
/// log(u) = (log u.x, u' / u.x).
template <typename type>
constexpr autodiff::dual<type> log(const autodiff::dual<type> &v) {
    return autodiff::dual<type>(log(v.x), v.dx / v.x);
}
} // namespace math

// Re-export the math overloads so callers can write bbb::sin(...), bbb::exp(...).
using namespace math;
} // namespace bbb
#include <iostream>
// Small demo: differentiates x^2 and exp(x) at x = 4.
int main() {
    using ddual = bbb::autodiff::dual<double>;

    // x is an independent variable at 4 (derivative seeded with 1).
    const auto x = ddual::make(4.0);

    // y = x * x: prints the value 16 and the derivative 2 * x = 8.
    const auto y = x * x;
    std::cout << y.x << ", " << y.dx << '\n';

    // exp is its own derivative, so z.x and z.dx are both e^4.
    const auto z = bbb::exp(x);
    std::cout << x.x << ", " << x.dx << '\n';
    std::cout << z.x << ", " << z.dx << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment