Last active
August 18, 2018 21:49
-
-
Save DavidRdgz/873b6ed08ac44cbd482e1b84ac1c8087 to your computer and use it in GitHub Desktop.
[Rainier] Example truncating combinator on base Discrete distributions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.stripe.rainier.core | |
import com.stripe.rainier.compute.{Evaluator, If, Real} | |
/**
  * Base trait for distributions over integer outcomes.
  *
  * Implementations supply `logDensity`; every other member here is a
  * combinator derived from it.
  */
trait Discrete extends Distribution[Int] {
  self: Discrete =>

  // Evaluator built over an empty map; `cmf` uses it to turn its bounds into Ints.
  // NOTE(review): presumably this only works when the bounds are constants
  // (no free variables) — confirm against Evaluator.
  val emptyEvaluator = new Evaluator(Map.empty)

  /** Log of the probability mass at `v`. */
  def logDensity(v: Real): Real

  /** Shorthand for `constantInflated(0.0, psi)`. */
  def zeroInflated(psi: Real) =
    constantInflated(0.0, psi)

  /**
    * Mixture of a `DiscreteConstant(constant)` component weighted `psi` and
    * this distribution weighted `1 - psi`.
    */
  def constantInflated(constant: Real, psi: Real) =
    DiscreteMixture(Map(DiscreteConstant(constant) -> psi, self -> (1 - psi)))

  /** Probability mass at `v`, i.e. `exp(logDensity(v))`. */
  def pmf(v: Real): Real =
    self.logDensity(v).exp

  /**
    * Total probability mass over the integers from `a` to `b`, both inclusive.
    *
    * Folds from zero rather than using `reduce`, so an empty range (`a > b`)
    * yields zero mass instead of throwing.
    */
  def cmf(a: Real, b: Real): Real =
    (emptyEvaluator.toInt(a) to emptyEvaluator.toInt(b))
      .foldLeft(0.0: Real)((total, k) => total + pmf(k))

  /** Restricts this distribution to outcomes `v` with `a < v <= b`. */
  def truncate(a: Real, b: Real): TruncatedDiscrete =
    TruncatedDiscrete(a, b, self)
}
object Discrete {

  /**
    * Likelihood instance for any [[Discrete]] distribution: an observed value
    * is scored with the distribution's own `logDensity`.
    */
  implicit val likelihood =
    Likelihood.from[Discrete, Int, Real] {
      case (dist, observed) => dist.logDensity(observed)
    }
}
/**
  * Truncated Discrete distribution keeping only observations `v` with
  * `a < v <= b`; everything else gets zero probability.
  *
  * @param a The lower bound of observations (exclusive: `v` must be strictly greater)
  * @param b The upper bound of observations (inclusive: `v` may equal it)
  * @param d The base distribution to be truncated
  */
final case class TruncatedDiscrete(a: Real, b: Real, d: Discrete)
    extends Discrete {

  /**
    * Samples by inverting the cumulative mass function: walks the support
    * `a + 1, a + 2, ..., b`, accumulating probability until it exceeds a
    * uniform draw, and returns the value at which that happens.
    */
  val generator: Generator[Int] =
    Generator.require(Set(a, b)) { (c, n) =>
      val u = c.standardUniform
      val upper = n.toInt(b)
      // The lower bound is exclusive, so the first candidate is a + 1.
      var k = n.toInt(a) + 1
      var cumulative = n.toDouble(logDensity(k).exp)
      // Stop at the first k whose cumulative mass exceeds u. The `k < upper`
      // guard keeps the result inside the support even if rounding leaves the
      // total mass marginally below u.
      while (cumulative <= u && k < upper) {
        k += 1
        cumulative += n.toDouble(logDensity(k).exp)
      }
      k
    }

  /**
    * Log mass of the truncated distribution: the base log mass minus the log
    * of the base mass on the support `a + 1 .. b`; negative infinity outside
    * `(a, b]`.
    */
  def logDensity(v: Real): Real =
    If(v > a,
       // Normalise over a + 1 .. b only: `a` itself is excluded from the
       // support, so its mass must not be part of the normaliser.
       If(v <= b, d.logDensity(v) - d.cmf(a + 1, b).log, Real.negInfinity),
       Real.negInfinity)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
scala> import com.stripe.rainier.core._ | |
import com.stripe.rainier.core._ | |
scala> import com.stripe.rainier.repl._ | |
import com.stripe.rainier.repl._ | |
scala> import com.stripe.rainier.compute._ | |
import com.stripe.rainier.compute._ | |
scala> val sales = List(4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,5,5,5,6,7,11,8,10) | |
sales: List[Int] = List(4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 6, 7, 11, 8, 10) | |
scala> val prior = for { | |
| lambda <- Normal(1, 4).param | |
| _ <- Poisson(lambda).truncate(3,13).fit(sales) | |
| } yield lambda | |
prior: com.stripe.rainier.core.RandomVariable[com.stripe.rainier.compute.Real] = com.stripe.rainier.core.RandomVariable@32cde79d | |
scala> plot1D(prior.sample()) | |
Initializing RNG with seed 1534628305626 | |
370 | | |
| ∘ | |
| ○ | |
| ○ | |
| ○·· ··○ | |
280 | ∘○ ∘○○○ ○○○∘ | |
| · · ○○○○○○○○○○○○○· | |
| ○ ○·○○○○○○○○○○○○○○ | |
| ○ ∘○○○○○○○○○○○○○○○○○○∘ ·· | |
| ○○○○○○○○○○○○○○○○○○○○○○∘○○ | |
180 | ○○○○○○○○○○○○○○○○○○○○○○○○○○· | |
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○ | |
| ∘ ○○○○○○○○○○○○○○○○○○○○○○○○○○○ ·· | |
| · ·○··○○○○○○○○○○○○○○○○○○○○○○○○○○○∘○○○∘ | |
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ | |
90 | ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ | |
| ∘∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ·○ | |
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○∘· | |
| ∘·○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ · | |
| · ∘∘·∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○○ ·○ | |
0 |· ·· · ∘∘·○∘∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○·○○·○○∘··· | |
|--------|--------|--------|--------|--------|--------|--------|--------|-------- | |
2.55 2.85 3.14 3.43 3.73 4.02 4.31 4.61 4.90 | |
*/ | |
/* Compared to the same model fit with an untruncated Poisson:
scala> val prior = for { | |
| lambda <- Normal(1, 4).param | |
| _ <- Poisson(lambda).fit(sales) | |
| } yield lambda | |
prior: com.stripe.rainier.core.RandomVariable[com.stripe.rainier.compute.Real] = com.stripe.rainier.core.RandomVariable@3c9cc48e | |
scala> plot1D(prior.sample()) | |
Initializing RNG with seed 1534628604325 | |
370 | | |
| ∘ ∘∘ | |
| ○○·○○∘ | |
| ∘· ○○○○○○ | |
| ○○ ○○○○○○○·· | |
280 | ○○○○○○○○○○○○ | |
| ○○○○○○○○○○○○○○ ∘ | |
| · ○○○○○○○○○○○○○○ ○○ | |
| ··○ ∘○○○○○○○○○○○○○○ ○○· | |
| ○○○ ○○○○○○○○○○○○○○○∘○○○○ ∘ | |
180 | ○○○○○○○○○○○○○○○○○○○○○○○○ ○ | |
| · ○○○○○○○○○○○○○○○○○○○○○○○○○○ | |
| ○ ○○○○○○○○○○○○○○○○○○○○○○○○○○ ·∘∘ | |
| ·○∘·○○○○○○○○○○○○○○○○○○○○○○○○○○∘○○○· | |
| ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ | |
90 | ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘ · | |
| ·○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘ ○ | |
| ∘○○∘∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○ · | |
| ∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘○∘○· ○ | |
| · ·· ○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○ ○·∘· ○ | |
0 |·∘·· ·○∘∘○○○∘○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○○∘○·○ ·· ·○ | |
|--------|--------|--------|--------|--------|--------|--------|--------|-------- | |
3.57 3.84 4.11 4.38 4.65 4.92 5.19 5.46 5.73 | |
*/ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment