Skip to content

Instantly share code, notes, and snippets.

@redpony
Created April 16, 2012 18:17
Show Gist options
  • Save redpony/2400470 to your computer and use it in GitHub Desktop.
Save redpony/2400470 to your computer and use it in GitHub Desktop.
C++ class to represent real numbers in the log domain
#ifndef LOGVAL_H_
#define LOGVAL_H_
// Represents values internally in the log domain (a much larger effective range than float,
// double, or long double). Useful for very small probabilities or very large unnormalized
// probabilities, while avoiding underflow/overflow.
//
// Use it as if it were a double or float; i.e., for doubles a and b:
//   LogVal(a) * LogVal(b) represents a * b
//   LogVal(a) + LogVal(b) represents a + b
//
// Placed by Chris Dyer <cdyer@cs.cmu.edu> in the public domain on April 16, 2012.
//
#include <iostream>
#include <cstdlib>
#include <cmath>
#include <limits>
#include <cassert>
// A real number stored as a sign flag plus the natural log of its magnitude.
//
// The represented value is (s_ ? -1 : 1) * exp(v_). Zero is stored with
// v_ == log(0) == -infinity. Multiplication and division become addition and
// subtraction of log magnitudes; addition uses the log1p trick so the larger
// magnitude is never exponentiated.
//
// T is the floating-point type used to hold the log magnitude.
template <class T>
class LogVal {
public:
  typedef LogVal<T> Self;

  // Zero: positive sign, log magnitude log(T()) == -infinity.
  LogVal() : s_(), v_(std::log(T())) {}

  // Converts an ordinary value x into the log domain.
  // (s_ is declared before v_, so it is initialized first and may be read here.)
  LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {}

  // Assigns an ordinary value x, converting it into the log domain.
  const Self& operator=(double x) {
    s_ = std::signbit(x);
    v_ = s_ ? std::log(-x) : std::log(x);
    return *this;
  }

  // Builds a value directly from its log magnitude lnx and sign (true = negative).
  LogVal(double lnx, bool sign) : s_(sign), v_(lnx) {}

  // The positive value exp(lnx).
  static Self exp(T lnx) { return Self(lnx, false); }
  static Self One() { return Self(1); }
  static Self Zero() { return Self(); }
  // Euler's number: positive, log magnitude 1.
  static Self e() { return Self(1, false); }

  // Sets *this to the positive value exp(v).
  void logeq(const T& v) {
    s_ = false;
    v_ = v;
  }

  // true is negative, false is positive.
  bool signbit() const { return s_; }

  // In-place addition in the log domain.
  // Same signs:      log(|x|+|y|) = max + log1p(exp(min - max))
  // Opposite signs:  log(|x|-|y|) = max + log1p(-exp(min - max)), sign of the larger.
  Self& operator+=(const Self& a) {
    const T neg_inf = -std::numeric_limits<T>::infinity();
    // Adding a zero of either sign is a no-op. Checking v_ (not just == Zero())
    // also covers a negative zero, which would otherwise make the formulas
    // below evaluate (-inf) - (-inf) == NaN.
    if (a.v_ == neg_inf) return *this;
    if (v_ == neg_inf) {
      // 0 + a == a exactly.
      s_ = a.s_;
      v_ = a.v_;
      return *this;
    }
    if (a.s_ == s_) {
      if (a.v_ < v_) {
        v_ = v_ + std::log1p(std::exp(a.v_ - v_));
      } else {
        v_ = a.v_ + std::log1p(std::exp(v_ - a.v_));
      }
    } else {
      if (a.v_ < v_) {
        v_ = v_ + std::log1p(-std::exp(a.v_ - v_));
      } else {
        v_ = a.v_ + std::log1p(-std::exp(v_ - a.v_));
        s_ = !s_;
      }
    }
    // Exact cancellation gives log1p(-1) == -inf. Store it as canonical
    // (positive) zero so that x - x compares equal to Zero().
    if (v_ == neg_inf) s_ = false;
    return *this;
  }

  // In-place multiplication: signs combine by XOR, log magnitudes add.
  Self& operator*=(const Self& a) {
    s_ = (s_ != a.s_);
    v_ += a.v_;
    return *this;
  }

  // In-place division: signs combine by XOR, log magnitudes subtract.
  Self& operator/=(const Self& a) {
    s_ = (s_ != a.s_);
    v_ -= a.v_;
    return *this;
  }

  // In-place subtraction, implemented as addition of the negated operand.
  Self& operator-=(const Self& a) {
    Self b = a;
    b.negate();
    return *this += b;
  }

  // Returns x with its log magnitude replaced by |v_|; the sign flag is unchanged.
  friend Self abslog(Self x) {
    if (x.v_ < 0) x.v_ = -x.v_;
    return x;
  }

  // Raises *this to `power` in place (log magnitude scaled by power).
  // Negative values are not supported: the program prints a message and aborts.
  Self& poweq(const T& power) {
    if (s_) {
      std::cerr << "poweq(T) not implemented when s_ is true\n";
      std::abort();
    } else
      v_ *= power;
    return *this;
  }

  // Strict less-than. Remember, s_ means negative.
  inline bool lt(Self const& o) const {
    if (s_ != o.s_) return s_;  // mixed signs: the negative one is smaller
    // Same sign: for negatives a larger log magnitude is a *smaller* value.
    return s_ ? o.v_ < v_ : v_ < o.v_;
  }

  // Strict greater-than; mirror image of lt().
  inline bool gt(Self const& o) const {
    if (s_ != o.s_) return o.s_;  // mixed signs: the positive one is larger
    return s_ ? v_ < o.v_ : o.v_ < v_;
  }

  // Returns the negation (same magnitude, flipped sign).
  Self operator-() const { return Self(v_, !s_); }

  // Flips the sign in place.
  void negate() { s_ = !s_; }

  // Returns 1/x: negated log magnitude, same sign.
  Self inverse() const { return Self(-v_, s_); }

  // Returns a copy raised to `power` (see poweq for the negative-value restriction).
  Self pow(const T& power) const {
    Self res = *this;
    res.poweq(power);
    return res;
  }

  // Returns the root-th root, computed as pow(1/root).
  Self root(const T& root) const { return pow(1 / root); }

  // Converts back to an ordinary floating-point value (may over/underflow).
  T as_float() const {
    if (s_) return -std::exp(v_); else return std::exp(v_);
  }

  bool s_;  // sign flag: true means negative
  T v_;     // natural log of the magnitude
};
// Writes v in its internal form: the stored log magnitude, prefixed by "(-)"
// when the sign flag is set. It does not print the exponentiated value.
template <class T>
inline std::ostream& operator<<(std::ostream& os, const LogVal<T>& v) {
  std::ostream& out = v.s_ ? (os << "(-)") : os;
  out << v.v_;
  return out;
}
// Sum of two log-domain values. o1 arrives by value, so it is modified in
// place and its updated copy is returned.
template<class T>
LogVal<T> operator+(LogVal<T> o1, const LogVal<T>& o2) {
  return o1 += o2;
}
// Product of two log-domain values. o1 arrives by value, so it is modified in
// place and its updated copy is returned.
template<class T>
LogVal<T> operator*(LogVal<T> o1, const LogVal<T>& o2) {
  return o1 *= o2;
}
// Quotient o1 / o2 of two log-domain values. o1 arrives by value, so it is
// modified in place and its updated copy is returned.
template<class T>
LogVal<T> operator/(LogVal<T> o1, const LogVal<T>& o2) {
  return o1 /= o2;
}
// Difference o1 - o2 of two log-domain values. o1 arrives by value, so it is
// modified in place and its updated copy is returned.
template<class T>
LogVal<T> operator-(LogVal<T> o1, const LogVal<T>& o2) {
  return o1 -= o2;
}
// Natural log of o as a plain T.
//
// For a positive value (or positive zero) this is simply the stored log
// magnitude. The log of a strictly negative number is undefined and yields a
// quiet NaN. A negative zero (sign set, magnitude -inf) yields -inf, matching
// std::log(-0.0).
//
// The NaN is produced via numeric_limits rather than the previous unqualified
// log(-1.0). That call relied on <cmath> also declaring ::log in the global
// namespace, which is not guaranteed, and it raised a domain error as a side
// effect.
template<class T>
T log(const LogVal<T>& o) {
  if (o.s_ && o.v_ != -std::numeric_limits<T>::infinity()) {
    return std::numeric_limits<T>::quiet_NaN();
  }
  return o.v_;
}
// Free-function counterpart of LogVal::signbit(): true when x is negative.
template<class T>
inline bool signbit(const LogVal<T>& x) {
  const bool negative = x.signbit();
  return negative;
}
// Absolute value: a copy of o with the sign flag cleared; the log magnitude is
// unchanged.
template<class T>
inline LogVal<T> abs(const LogVal<T>& o) {
  LogVal<T> magnitude(o);
  magnitude.s_ = false;
  return magnitude;
}
// Free-function form of LogVal::pow: returns b raised to e.
// Works on a copy; the restriction on negative bases comes from LogVal::poweq.
template <class T>
inline LogVal<T> pow(const LogVal<T>& b, const T& e) {
  LogVal<T> result(b);
  result.poweq(e);
  return result;
}
// Representation equality: same sign flag and same log magnitude.
// Consequently +0 and -0 (both magnitude -inf, different sign flags) compare
// unequal.
template <class T>
bool operator==(const LogVal<T>& lhs, const LogVal<T>& rhs) {
  return lhs.s_ == rhs.s_ && lhs.v_ == rhs.v_;
}
// Negation of operator==.
template <class T>
bool operator!=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}
// Strict ordering of represented values (s_ == true means negative).
//
// With mixed signs, the negative operand is the smaller one. With equal signs,
// positives compare by log magnitude directly. For negatives the ordering is
// reversed, because a larger magnitude is a more negative value (e.g. -3 < -2
// even though log 3 > log 2). The previous version compared log magnitudes
// directly for negatives too, which inverted the result.
template <class T>
bool operator<(const LogVal<T>& lhs, const LogVal<T>& rhs) {
  if (lhs.s_ != rhs.s_) {
    return lhs.s_;
  }
  return lhs.s_ ? (rhs.v_ < lhs.v_) : (lhs.v_ < rhs.v_);
}
// Less-than-or-equal, defined as operator== OR operator<.
template <class T>
bool operator<=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
  if (lhs == rhs) return true;
  return lhs < rhs;
}
// Greater-than: neither less than nor equal to rhs.
template <class T>
bool operator>(const LogVal<T>& lhs, const LogVal<T>& rhs) {
  return !(lhs < rhs) && !(lhs == rhs);
}
// Greater-than-or-equal: the complement of operator<.
template <class T>
bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
  const bool less = lhs < rhs;
  return !less;
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment