Skip to content

Instantly share code, notes, and snippets.

@jeetsukumaran
Forked from sftrabbit/beta_distribution.hpp
Created March 8, 2013 16:11
Show Gist options
  • Save jeetsukumaran/5117563 to your computer and use it in GitHub Desktop.
Save jeetsukumaran/5117563 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <sstream>
#include <string>
#include <random>
namespace sftrabbit {
/**
 * Beta(a, b) random number distribution with an interface modelled on the
 * distributions in <random>.
 *
 * A variate is built from two std::gamma_distribution draws: X with shape a
 * and Y with shape b (both constructed with the default scale), combined as
 * X / (X + Y).
 *
 * @tparam RealType  type of the shape parameters and of the generated values.
 */
template <typename RealType = double>
class beta_distribution
{
public:
    typedef RealType result_type;

    /** The pair of shape parameters (a, b) of a beta_distribution. */
    class param_type
    {
    public:
        typedef beta_distribution distribution_type;

        /** Stores the two shape parameters; both default to 2. */
        explicit param_type(RealType a = 2.0, RealType b = 2.0)
            : a_param(a), b_param(b) { }

        /** @return the first shape parameter. */
        RealType a() const { return a_param; }
        /** @return the second shape parameter. */
        RealType b() const { return b_param; }

        /** Two parameter sets are equal when both shapes compare equal. */
        bool operator==(const param_type& other) const
        {
            return (a_param == other.a_param &&
                    b_param == other.b_param);
        }
        bool operator!=(const param_type& other) const
        {
            return !(*this == other);
        }

    private:
        RealType a_param, b_param;
    };

    /** Constructs the distribution from its two shape parameters. */
    explicit beta_distribution(RealType a = 2.0, RealType b = 2.0)
        : a_gamma(a), b_gamma(b) { }

    /** Constructs the distribution from a parameter set. */
    explicit beta_distribution(const param_type& param)
        : a_gamma(param.a()), b_gamma(param.b()) { }

    /**
     * Resets both underlying gamma distributions so that values produced
     * afterwards do not depend on anything drawn before the call (a gamma
     * implementation may cache engine output between calls).
     */
    void reset()
    {
        a_gamma.reset();
        b_gamma.reset();
    }

    /** @return the current shape parameters as a param_type. */
    param_type param() const
    {
        return param_type(a(), b());
    }

    /** Replaces the shape parameters by rebuilding both gamma members. */
    void param(const param_type& param)
    {
        a_gamma = gamma_dist_type(param.a());
        b_gamma = gamma_dist_type(param.b());
    }

    /** Draws one variate using the stored parameters. */
    template <typename URNG>
    result_type operator()(URNG& engine)
    {
        return generate(engine, a_gamma, b_gamma);
    }

    /**
     * Draws one variate using @p param instead of the stored parameters.
     * Temporary gamma distributions are used, so the stored ones are not
     * touched.
     */
    template <typename URNG>
    result_type operator()(URNG& engine, const param_type& param)
    {
        gamma_dist_type a_param_gamma(param.a()),
                        b_param_gamma(param.b());
        return generate(engine, a_param_gamma, b_param_gamma);
    }

    /** @return the smallest value the distribution can produce (0). */
    result_type min() const { return 0.0; }
    /** @return the largest value the distribution can produce (1). */
    result_type max() const { return 1.0; }

    /** @return the first shape parameter (shape of the first gamma). */
    RealType a() const { return a_gamma.alpha(); }
    /** @return the second shape parameter (shape of the second gamma). */
    RealType b() const { return b_gamma.alpha(); }

    /**
     * Equal when the parameters match and both underlying gamma
     * distributions compare equal.
     */
    bool operator==(const beta_distribution<result_type>& other) const
    {
        return (param() == other.param() &&
                a_gamma == other.a_gamma &&
                b_gamma == other.b_gamma);
    }
    bool operator!=(const beta_distribution<result_type>& other) const
    {
        return !(*this == other);
    }

private:
    typedef std::gamma_distribution<result_type> gamma_dist_type;

    gamma_dist_type a_gamma, b_gamma;

    /**
     * Core sampler: draws x from @p x_gamma, then y from @p y_gamma, and
     * returns x / (x + y).
     *
     * For very small shape parameters both draws can underflow to exactly
     * zero, which would make the quotient 0/0 (NaN). In that regime the Beta
     * law is concentrated at the end points, so the result is taken from the
     * limiting two-point law instead: 1 with probability a / (a + b), else 0.
     * This fallback consumes extra engine output only in that degenerate case.
     */
    template <typename URNG>
    result_type generate(URNG& engine,
                         gamma_dist_type& x_gamma,
                         gamma_dist_type& y_gamma)
    {
        const result_type x = x_gamma(engine);
        const result_type y = y_gamma(engine);
        const result_type sum = x + y;
        if (sum > result_type(0)) {
            return x / sum;
        }

        const result_type shape_x = x_gamma.alpha();
        const result_type shape_y = y_gamma.alpha();
        std::bernoulli_distribution lands_on_one(
            static_cast<double>(shape_x / (shape_x + shape_y)));
        return lands_on_one(engine) ? result_type(1) : result_type(0);
    }
};
/**
 * Writes a textual form of @p beta to @p os as "~Beta(a,b)", where a and b
 * are the distribution's shape parameters formatted with the stream's
 * current settings.
 *
 * @return @p os, to allow chaining.
 */
template <typename CharT, typename RealType>
std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os,
                                      const beta_distribution<RealType>& beta)
{
    const RealType first_shape = beta.a();
    const RealType second_shape = beta.b();
    return os << "~Beta(" << first_shape << "," << second_shape << ")";
}
/**
 * Reads a distribution written in the form "~Beta(a,b)" from @p is into
 * @p beta.
 *
 * Leading whitespace is skipped when the stream's skipws flag is set, as
 * with other formatted extractions, so several whitespace-separated
 * distributions can be read back one after another.
 *
 * On any mismatch failbit is set on @p is and @p beta is left unchanged;
 * characters consumed before the mismatch are not put back.
 *
 * NOTE(review): the tag is read into a std::string, so this overload
 * presumably only instantiates for char-based streams -- confirm before
 * relying on it with wide streams.
 *
 * @return @p is, to allow chaining.
 */
template <typename CharT, typename RealType>
std::basic_istream<CharT>& operator>>(std::basic_istream<CharT>& is,
                                      beta_distribution<RealType>& beta)
{
    // std::getline is an unformatted read and would otherwise store any
    // leading whitespace in the tag, making the "~Beta" comparison fail.
    if (is.flags() & std::ios_base::skipws) {
        is >> std::ws;
    }

    std::string str;
    RealType a = RealType();
    RealType b = RealType();
    if (std::getline(is, str, '(') && str == "~Beta" &&
        is >> a && is.get() == ',' && is >> b && is.get() == ')') {
        beta = beta_distribution<RealType>(a, b);
    } else {
        is.setstate(std::ios::failbit);
    }
    return is;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment