Skip to content

Instantly share code, notes, and snippets.

@ejconlon
Created May 24, 2019 16:46
Show Gist options
  • Save ejconlon/cafad644948388e1807e7b68508be0bb to your computer and use it in GitHub Desktop.
Save ejconlon/cafad644948388e1807e7b68508be0bb to your computer and use it in GitHub Desktop.
import scala.annotation.tailrec
import scala.collection.immutable.SortedMap
import scala.collection.mutable
import scala.language.higherKinds
import scala.util.Random
import scala.util.control.TailCalls.{TailRec, done}
/**
 * One step from `A` to `B` that may or may not produce its result inside `F`.
 *
 * @tparam F type constructor used by effectful steps
 * @tparam A input type of the step
 * @tparam B output type of the step
 */
sealed trait KArrow[F[_], A, B] extends Product with Serializable

object KArrow {

  /** A pure step wrapping an ordinary function `A => B`. */
  final case class KMap[F[_], A, B](f: A => B) extends KArrow[F, A, B]

  /** An effectful step wrapping a function `A => F[B]`. */
  final case class KFlatMap[F[_], A, B](f: A => F[B]) extends KArrow[F, A, B]
}
/**
 * Strategy for applying a single [[KArrow]] over `F` to a value held in a
 * carrier `G`.
 *
 * @tparam F type constructor of the arrows being interpreted
 * @tparam G carrier the arrows are applied to
 */
trait TransBind[F[_], G[_]] {

  /**
   * Applies `arr` to `what`.
   *
   * @param what current carrier value
   * @param arr  step to apply
   * @return the resulting carrier value, wrapped in a `TailRec`
   */
  def apply[A, B](what: G[A], arr: KArrow[F, A, B]): TailRec[G[B]]
}
/** Constructors and the interpreter loop for [[TAL]]. */
object TAL {

  /** Builds a list holding the single arrow `arr`. */
  def apply[F[_], A, B](arr: KArrow[F, A, B]): TAL[F, A, B] =
    new TAL(IndexedSeq(arr))

  /**
   * Feeds `what` through every remaining arrow of `arrIt`, in order, using
   * `transBind` to apply each one.
   *
   * The arrows are stored with their intermediate types forgotten, so each
   * one is cast back to the shape the previous step produced; the final value
   * is likewise cast to `G[B]`.
   *
   * @param what      value produced so far
   * @param transBind how to apply one arrow to a `G`
   * @param arrIt     remaining arrows; advanced as a side effect
   * @return the final value, wrapped in a `TailRec`
   */
  private def consume[F[_], G[_], A, B](
    what: G[A],
    transBind: TransBind[F, G],
    arrIt: Iterator[KArrow[F, _, _]]
  ): TailRec[G[B]] =
    if (arrIt.hasNext) {
      // `next()` mutates the iterator, so it is called with explicit parentheses.
      val arr = arrIt.next().asInstanceOf[KArrow[F, A, B]]
      // The recursive call sits inside `flatMap`, so it is deferred to the
      // TailRec machinery rather than made directly.
      transBind(what, arr).flatMap { whatNow =>
        consume(whatNow, transBind, arrIt)
      }
    } else {
      // No arrows left: `what` already is the final value.
      done(what.asInstanceOf[G[B]])
    }
}
/**
 * Severely cheating at a type-aligned list thanks to erasure: only the outer
 * endpoints `A` and `B` are tracked in the type, while the arrows themselves
 * are stored as `KArrow[F, _, _]`.
 *
 * @param arrs the arrows, in application order
 * @tparam F type constructor of the arrows
 * @tparam A input type of the first arrow
 * @tparam B output type of the last arrow
 */
final class TAL[F[_], A, B](val arrs: IndexedSeq[KArrow[F, _, _]]) extends AnyVal {

  /** Prepends `arr`, so the resulting list starts at `Z` instead of `A`. */
  def +:[Z](arr: KArrow[F, Z, A]): TAL[F, Z, B] = new TAL(arr +: arrs)

  /** Appends `arr`, so the resulting list ends at `C` instead of `B`. */
  def :+[C](arr: KArrow[F, B, C]): TAL[F, A, C] = new TAL(arrs :+ arr)

  /**
   * Runs `what` through every arrow in order.
   *
   * @param what      starting value
   * @param transBind how to apply one arrow to a `G`
   * @return the final value, wrapped in a `TailRec`
   */
  def run[G[_]](what: G[A], transBind: TransBind[F, G]): TailRec[G[B]] =
    TAL.consume(what, transBind, arrs.iterator)
}
// TODO make this a log-prob for accuracy, or rational etc
/**
 * A probability weight, carried as a plain `Double` inside a value class.
 *
 * @param value the underlying number; no range check is performed here
 */
final class Prob(val value: Double) extends AnyVal
/**
 * A distribution over values of type `A` that can be sampled and whose
 * support can be enumerated.
 *
 * `map` and `flatMap` do not run anything: they record the function as a
 * [[KArrow]] on a `Bind` node, extending the existing arrow list when this
 * distribution already is a `Bind`.
 */
sealed trait Dist[A] extends Product with Serializable {
  import Dist._
  import KArrow._

  /** Transforms each outcome with `f`. */
  final def map[B](f: A => B): Dist[B] =
    appendArr(KMap[Dist, A, B](f))

  /** Chooses a follow-up distribution from each outcome. */
  final def flatMap[B](f: A => Dist[B]): Dist[B] =
    appendArr(KFlatMap[Dist, A, B](f))

  /** Records `arr` as the next step, reusing an existing `Bind` when possible. */
  private[this] def appendArr[B](arr: KArrow[Dist, A, B]): Dist[B] =
    this match {
      // Already a Bind: grow its arrow list instead of nesting another Bind.
      case Bind(source, arrows) => Bind(source, arrows :+ arr)
      case _                    => Bind(this, TAL(arr))
    }

  /** Draws one value using `random`, expressed as a `TailRec` computation. */
  protected def sampleTailRec(random: Random): TailRec[A]

  /** Enumerates the support, expressed as a `TailRec` computation. */
  protected def supportTailRec: TailRec[Set[A]]

  /** Draws one value using `random`. */
  final def sample(random: Random): A = {
    val trampoline = sampleTailRec(random)
    trampoline.result
  }

  /** The set of values this distribution reports as possible outcomes. */
  final def support: Set[A] = {
    val trampoline = supportTailRec
    trampoline.result
  }
}
/** Constructors and concrete representations for [[Dist]]. */
object Dist {
  import KArrow._

  /**
   * Thrown by `categorical` / `uniform` when given no elements, and by
   * categorical sampling if no element could be selected.
   */
  case object EmptyException extends Exception("empty distribution")

  /** A distribution with exactly one outcome. */
  private final case class Pure[A](value: A) extends Dist[A] {
    override protected def sampleTailRec(random: Random): TailRec[A] = done(value)
    override protected def supportTailRec: TailRec[Set[A]] = done(Set(value))
  }

  // Carrier used while sampling: a bare value with no wrapper.
  private[this] type Identity[A] = A

  /** Applies arrows to a single sampled value, drawing from `random` as needed. */
  private[this] final class SampleTransBind(random: Random) extends TransBind[Dist, Identity] {
    override def apply[A, B](
      what: Identity[A],
      arr: KArrow[Dist, A, B]
    ): TailRec[Identity[B]] =
      arr match {
        case KMap(f) => done(f(what))
        case KFlatMap(f) =>
          // The arrow yields a distribution; sample it to get the next value.
          val d = f(what)
          d.sampleTailRec(random)
      }
  }

  /** Applies arrows to a whole set of possible values, collecting every result. */
  private[this] object SupportTransBind extends TransBind[Dist, Set] {

    /**
     * Applies `arr` to each remaining element of `aIt`, adding the results to
     * `builder`.
     *
     * @param builder accumulator, mutated in place and returned when done
     * @param arr     arrow to apply to every element
     * @param aIt     remaining inputs; advanced as a side effect
     */
    private[this] def subApply[A, B](
      builder: mutable.Builder[B, Set[B]],
      arr: KArrow[Dist, A, B],
      aIt: Iterator[A]
    ): TailRec[mutable.Builder[B, Set[B]]] =
      if (aIt.hasNext) {
        val a = aIt.next()
        arr match {
          case KMap(f) =>
            val b = f(a)
            builder += b
            // NOTE(review): direct self-call in tail position; stack safety on
            // large supports presumably relies on the compiler's tail-call
            // elimination for this private method — confirm.
            subApply(builder, arr, aIt)
          case KFlatMap(f) =>
            val d = f(a)
            d.supportTailRec.flatMap { bs =>
              builder ++= bs
              subApply(builder, arr, aIt)
            }
        }
      } else {
        done(builder)
      }

    override def apply[A, B](
      what: Set[A],
      arr: KArrow[Dist, A, B]
    ): TailRec[Set[B]] =
      subApply(Set.newBuilder[B], arr, what.iterator).map {
        _.result()
      }
  }

  /**
   * A source distribution followed by a list of arrows to run on its outcomes.
   *
   * @param context the distribution the arrows start from
   * @param tal     the recorded `map` / `flatMap` steps, in order
   */
  private final case class Bind[Z, A](context: Dist[Z], tal: TAL[Dist, Z, A]) extends Dist[A] {
    override protected def sampleTailRec(random: Random): TailRec[A] =
      context.sampleTailRec(random).flatMap { z =>
        tal.run[Identity](z, new SampleTransBind(random))
      }

    override protected def supportTailRec: TailRec[Set[A]] =
      context.supportTailRec.flatMap { zs =>
        if (zs.isEmpty) {
          // Nothing to feed through the arrows.
          done(Set.empty[A])
        } else {
          tal.run(zs, SupportTransBind)
        }
      }
  }

  /**
   * Walks `elemIt`, accumulating mass, and returns the first element whose
   * cumulative mass exceeds `p`.
   *
   * @param p      threshold to exceed
   * @param s      cumulative mass of the elements consumed so far
   * @param last   most recently consumed element, returned if the iterator
   *               runs out before the threshold is exceeded
   * @param elemIt remaining (element, weight) pairs
   * @return the selected element, or `None` only when nothing was ever seen
   */
  @tailrec
  private[this] def subCategoricalSample[A](
    p: Double,
    s: Double,
    last: Option[A],
    elemIt: Iterator[(A, Prob)]
  ): Option[A] =
    if (elemIt.hasNext) {
      val (elem, prob) = elemIt.next()
      val t = s + prob.value
      val seen = Some(elem)
      if (t > p) {
        seen
      } else {
        subCategoricalSample(p, t, seen, elemIt)
      }
    } else {
      last
    }

  /**
   * A distribution over the keys of `elems`, each weighted by its [[Prob]].
   * Weights are treated as relative: sampling scales the random draw by their
   * sum, so they do not have to add up to 1.
   */
  private final case class Categorical[A](elems: SortedMap[A, Prob]) extends Dist[A] {
    // Summed in the same key order the sampler walks, so the sampler's final
    // cumulative mass equals this value exactly.
    private[this] val totalMass: Double = elems.valuesIterator.map(_.value).sum

    override protected def sampleTailRec(random: Random): TailRec[A] =
      // Scaling the [0, 1) draw by the total mass makes every key reachable in
      // proportion to its weight even when the weights are not normalized.
      subCategoricalSample(random.nextDouble() * totalMass, 0, None, elems.iterator) match {
        case None => throw EmptyException
        case Some(v) => done(v)
      }

    // Every key is reported, regardless of its weight.
    override protected def supportTailRec: TailRec[Set[A]] =
      done(elems.keySet)
  }

  /** A distribution that picks one index of `elems` via `random.nextInt`. */
  private final case class Uniform[A](elems: IndexedSeq[A]) extends Dist[A] {
    override protected def sampleTailRec(random: Random): TailRec[A] =
      done(elems(random.nextInt(elems.size)))

    override protected def supportTailRec: TailRec[Set[A]] =
      done(elems.toSet)
  }

  /** The distribution that always yields `value`. */
  def apply[A](value: A): Dist[A] =
    Pure(value)

  // TODO move to IndexedSeq
  /**
   * A distribution over the keys of `elems`, weighted by their [[Prob]]s.
   * Weights are relative and need not sum to 1.
   *
   * @throws EmptyException if `elems` is empty
   */
  def categorical[A](elems: SortedMap[A, Prob]): Dist[A] =
    if (elems.isEmpty) {
      throw EmptyException
    } else {
      Categorical(elems)
    }

  /**
   * A distribution that picks uniformly among the positions of `elems`.
   *
   * @throws EmptyException if `elems` is empty
   */
  def uniform[A](elems: IndexedSeq[A]): Dist[A] =
    if (elems.isEmpty) {
      throw EmptyException
    } else {
      Uniform(elems)
    }
}
/**
 * Demo entry point: builds a two-stage distribution, then prints one seeded
 * sample followed by the distribution's support.
 */
object DistMain {
  def main(args: Array[String]): Unit = {
    val seed = 42
    val rng = new Random(seed)

    // First stage: a sign, -1 or 1.
    val sign = Dist.uniform(IndexedSeq(-1, 1))

    // Second stage: the letters available depend on the sign drawn.
    def letterFor(x: Int): Dist[String] =
      if (x > 0) Dist.uniform(IndexedSeq("a"))
      else Dist.uniform(IndexedSeq("b", "c"))

    val p = for {
      x <- sign
      y <- letterFor(x)
    } yield y

    println(p.sample(rng))
    println(p.support)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment