Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@DavidRdgz
Last active August 18, 2018 21:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DavidRdgz/873b6ed08ac44cbd482e1b84ac1c8087 to your computer and use it in GitHub Desktop.
Save DavidRdgz/873b6ed08ac44cbd482e1b84ac1c8087 to your computer and use it in GitHub Desktop.
[Rainier] Example truncating combinator on base Discrete distributions.
package com.stripe.rainier.core
import com.stripe.rainier.compute.{Evaluator, If, Real}
trait Discrete extends Distribution[Int] {
self: Discrete =>
val emptyEvaluator = new Evaluator(Map.empty)
def logDensity(v: Real): Real
def zeroInflated(psi: Real) =
constantInflated(0.0, psi)
def constantInflated(constant: Real, psi: Real) =
DiscreteMixture(Map(DiscreteConstant(constant) -> psi, self -> (1 - psi)))
def pmf(v: Real) =
self.logDensity(v).exp
def cmf(a: Real, b: Real) =
(emptyEvaluator.toInt(a) to emptyEvaluator.toInt(b))
.map(pmf(_))
.reduce(_ + _)
def truncate(a: Real, b: Real) =
TruncatedDiscrete(a, b, self)
}
object Discrete {
implicit val likelihood =
Likelihood.from[Discrete, Int, Real] {
case (d, v) => d.logDensity(v)
}
}
/**
* Truncated Discrete distribution removing observations below `a` and above `b`
*
* @param a The lower bound of observations (strictly greater than)
* @param b The upper bound of observations (possibly equal to)
* @param d The base distribution to be truncated
*/
final case class TruncatedDiscrete(a: Real, b: Real, d: Discrete)
extends Discrete {
val generator: Generator[Int] = {
Generator.require(Set(a, b)) { (c, n) =>
val u = c.standardUniform
var lowerBound = n.toInt(a)
val upperBound = n.toInt(b)
var t = 0.0
while (t <= u && t <= upperBound) {
val p = logDensity(lowerBound).exp
t += n.toDouble(p)
lowerBound += 1
}
lowerBound
}
}
def logDensity(v: Real): Real =
If(v > a,
If(v <= b, (d.pmf(v) / d.cmf(a, b)).log, Real.negInfinity),
Real.negInfinity)
}
/*
scala> import com.stripe.rainier.core._
import com.stripe.rainier.core._
scala> import com.stripe.rainier.repl._
import com.stripe.rainier.repl._
scala> import com.stripe.rainier.compute._
import com.stripe.rainier.compute._
scala> val sales = List(4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,5,5,5,6,7,11,8,10)
sales: List[Int] = List(4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 6, 7, 11, 8, 10)
scala> val prior = for {
| lambda <- Normal(1, 4).param
| _ <- Poisson(lambda).truncate(3,13).fit(sales)
| } yield lambda
prior: com.stripe.rainier.core.RandomVariable[com.stripe.rainier.compute.Real] = com.stripe.rainier.core.RandomVariable@32cde79d
scala> plot1D(prior.sample())
Initializing RNG with seed 1534628305626
370 |
| ∘
| ○
| ○
| ○·· ··○
280 | ∘○ ∘○○○ ○○○∘
| · · ○○○○○○○○○○○○○·
| ○ ○·○○○○○○○○○○○○○○
| ○ ∘○○○○○○○○○○○○○○○○○○∘ ··
| ○○○○○○○○○○○○○○○○○○○○○○∘○○
180 | ○○○○○○○○○○○○○○○○○○○○○○○○○○·
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○
| ∘ ○○○○○○○○○○○○○○○○○○○○○○○○○○○ ··
| · ·○··○○○○○○○○○○○○○○○○○○○○○○○○○○○∘○○○∘
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○
90 | ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○
| ∘∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ·○
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○∘·
| ∘·○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ·
| · ∘∘·∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○○ ·○
0 |· ·· · ∘∘·○∘∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○·○○·○○∘···
|--------|--------|--------|--------|--------|--------|--------|--------|--------
2.55 2.85 3.14 3.43 3.73 4.02 4.31 4.61 4.90
*/
/* Compared to
scala> val prior = for {
| lambda <- Normal(1, 4).param
| _ <- Poisson(lambda).fit(sales)
| } yield lambda
prior: com.stripe.rainier.core.RandomVariable[com.stripe.rainier.compute.Real] = com.stripe.rainier.core.RandomVariable@3c9cc48e
scala> plot1D(prior.sample())
Initializing RNG with seed 1534628604325
370 |
| ∘ ∘∘
| ○○·○○∘
| ∘· ○○○○○○
| ○○ ○○○○○○○··
280 | ○○○○○○○○○○○○
| ○○○○○○○○○○○○○○ ∘
| · ○○○○○○○○○○○○○○ ○○
| ··○ ∘○○○○○○○○○○○○○○ ○○·
| ○○○ ○○○○○○○○○○○○○○○∘○○○○ ∘
180 | ○○○○○○○○○○○○○○○○○○○○○○○○ ○
| · ○○○○○○○○○○○○○○○○○○○○○○○○○○
| ○ ○○○○○○○○○○○○○○○○○○○○○○○○○○ ·∘∘
| ·○∘·○○○○○○○○○○○○○○○○○○○○○○○○○○∘○○○·
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○
90 | ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘ ·
| ·○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘ ○
| ∘○○∘∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○ ·
| ∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘○∘○· ○
| · ·· ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○·∘· ○
0 |·∘·· ·○∘∘○○○∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘○·○ ·· ·○
|--------|--------|--------|--------|--------|--------|--------|--------|--------
3.57 3.84 4.11 4.38 4.65 4.92 5.19 5.46 5.73
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment