Skip to content

Instantly share code, notes, and snippets.

@johnynek
Last active May 10, 2018 04:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save johnynek/68051e06f17847c87409e4add0cb5645 to your computer and use it in GitHub Desktop.
Save johnynek/68051e06f17847c87409e4add0cb5645 to your computer and use it in GitHub Desktop.
package gradfn
import shapeless.{HNil, :: => HCons, HList, Nat}
import shapeless.ops.nat.ToInt
import Nat.{_0, _1, _2}
import Vect.VectOps
/**
 * A transformation from `F[A]` to `G[A]` that is defined uniformly for every element type `A`.
 *
 * @tparam F the source type constructor
 * @tparam G the target type constructor
 */
trait NatTrans[F[_], G[_]] {

  /**
   * Converts a value wrapped in `F` into a value of the same element type wrapped in `G`.
   *
   * @param fn the `F`-wrapped input value
   * @return the corresponding `G`-wrapped value
   */
  def apply[A](fn: F[A]): G[A]
}
/**
 * Evidence that the HList `HF` consists of `F[_]`-wrapped elements and that `HG` holds the
 * same element types wrapped in `G[_]`, so that a [[NatTrans]] can be applied to each
 * position in turn.
 */
trait Alignment[F[_], G[_], HF <: HList, HG <: HList] {

  /** Applies `nt` to every element of `fs`, producing the aligned `HG`. */
  def apply(fs: HF)(nt: NatTrans[F, G]): HG

  /** The same alignment with the roles of `F` and `G` exchanged. */
  def reverse: Alignment[G, F, HG, HF]
}

object Alignment {

  /** Base case: the empty HList aligns with itself for any `F` and `G`. */
  implicit def nilAlign[F[_], G[_]]: Alignment[F, G, HNil, HNil] =
    new Alignment[F, G, HNil, HNil] {
      def apply(fs: HNil)(nt: NatTrans[F, G]) = HNil
      def reverse = nilAlign
    }

  /**
   * Inductive case: given an alignment for the tails, `F[A] :: FTail` aligns with
   * `G[A] :: GTail`.
   */
  implicit def consAlign[A, F[_], G[_], FTail <: HList, GTail <: HList](
    implicit tailAlign: Alignment[F, G, FTail, GTail]
  ): Alignment[F, G, F[A] HCons FTail, G[A] HCons GTail] =
    new Alignment[F, G, F[A] HCons FTail, G[A] HCons GTail] {
      def apply(fs: F[A] HCons FTail)(nt: NatTrans[F, G]): G[A] HCons GTail = {
        // Transform the head first, then let the tail evidence handle the remainder.
        val mappedHead = nt(fs.head)
        val mappedTail = tailAlign(fs.tail)(nt)
        HCons(mappedHead, mappedTail)
      }
      def reverse = consAlign(tailAlign.reverse)
    }
}
/** The possible result shapes of an [[FnGen]]: a concrete vector or a further function. */
sealed trait Result

object Result {

  /** A terminal result: a length-`N` vector of doubles. */
  case class Value[N <: Nat](data: Vect.T[N, Double]) extends Result

  /** A continuation: a function of `N` arguments that yields a further result `R`. */
  case class Cont[N <: Nat, R <: Result](fn: FnGen[N, R]) extends Result
}
/**
 * A function of `N` real arguments producing an `R`, together with its gradient.
 *
 * @tparam N the number of arguments
 * @tparam R the kind of result produced
 */
sealed abstract class FnGen[N <: Nat, R <: Result] {

  /** Evaluates the function at `args`. */
  def apply(args: Vect.T[N, Double]): R

  /** One function per argument position, each of the same type as this function. */
  def grad: Vect.T[N, FnGen[N, R]]
}

object FnGen {

  /**
   * The function that ignores its `N` arguments and returns a length-`M` vector of zeros.
   * Every entry of its gradient is this same function.
   */
  def zero[N <: Nat, M <: Nat](implicit toIntN: ToInt[N], toIntM: ToInt[M]): FnGen[N, Result.Value[M]] =
    new FnGen[N, Result.Value[M]] {
      def apply(args: Vect.T[N, Double]): Result.Value[M] =
        Result.Value[M](Vect.T.fill[Double, M](0.0))

      // Lazy because the gradient refers back to this very instance.
      lazy val grad: Vect.T[N, FnGen[N, Result.Value[M]]] =
        Vect.T.fill[FnGen[N, Result.Value[M]], N](this)
    }
}
//
// FnGen[_1, FnGen[_1, R._1]]
// example [x] => [y] => [x*y]
// Vect.T[_1, FnGen[_1, FnGen[_1, R._1]]]
// grad: [[x] => [y] => [y]]
//
// FnGen[_2, FnGen[_2, R._4]]
// [x0, x1] => [y0, y1] => [x0*y0, x1*y0, x0*y1, x1*y1]
// Vect.T[_2, FnGen[_2, FnGen[_2, R._4]]]
// grad [[x0, x1] => [y0, y1] => [y0, 0, y1, 0]], [[x0, x1] => [y0, y1] => [0, y0, 0, y1]]
package gradfn
import shapeless.{ Nat, Succ }
import shapeless.ops.nat
import scala.reflect.ClassTag
/**
 * Fixed-length vectors whose length is tracked at the type level by a shapeless `Nat`.
 *
 * The representation is hidden behind the abstract type member `NewType#V`; the single
 * implementation `T` backs it with a plain `Array`, but because `T` is typed as `NewType`
 * callers cannot observe that.
 */
object Vect {

  /** Operations available on length-indexed vectors `V[L, A]`. */
  trait NewType {

    /** A vector of `A` whose length is the type-level natural `N`. */
    type V[N <: Nat, A]

    /** Applies `fn` to every element, keeping the length. */
    def map[A, B, L <: Nat](v: V[L, A])(fn: A => B)(implicit ct: ClassTag[B], toInt: nat.ToInt[L]): V[L, B]

    /** Combines two vectors of the same length element by element. */
    def map2[A, B, C, L <: Nat](v1: V[L, A], v2: V[L, B])(fn: (A, B) => C)(implicit ct: ClassTag[C], toInt: nat.ToInt[L]): V[L, C]

    /**
     * Returns the element at type-level index `N`. The `diff` evidence is only required to
     * exist; it is presumably what restricts `N` to valid positions of a length-`L` vector.
     */
    def get[A, L <: Nat, N <: Nat](v: V[L, A])(implicit diff: nat.Diff[L, Succ[N]], toIEv: nat.ToInt[N]): A

    /** Same as [[get]], with the index supplied as a `Nat` value rather than a type argument. */
    def getIdx[A, L <: Nat](v: V[L, A], n: Nat)(implicit diff: nat.Diff[L, Succ[n.N]], toIEv: nat.ToInt[n.N]): A =
      get[A, L, n.N](v)

    /** Builds a length-`L` vector from a copy of (a prefix of) `as`. */
    def copyFrom[A, L <: Nat](as: Array[A])(implicit toInt: nat.ToInt[L]): V[L, A]

    /** Same as [[copyFrom]], with the length supplied as a `Nat` value. */
    def copyFromSized[A](as: Array[A], n: Nat)(implicit toInt: nat.ToInt[n.N]): V[n.N, A] =
      copyFrom[A, n.N](as)

    /** A length-`L` vector whose slots hold the default value of a freshly allocated array. */
    def empty[A, L <: Nat](implicit ct: ClassTag[A], toIEv: nat.ToInt[L]): V[L, A]

    /** Same as [[empty]], with the length supplied as a `Nat` value. */
    def emptySized[A](n: Nat)(implicit ct: ClassTag[A], toIEv: nat.ToInt[n.N]): V[n.N, A] =
      empty[A, n.N]

    /** Wraps `as` as a length-`L` vector; the length is checked by the implementation. */
    def unsafeFrom[A, L <: Nat](as: Array[A])(implicit toInt: nat.ToInt[L]): V[L, A]

    /** A length-`L` vector with every slot set to `a`. */
    def fill[A, L <: Nat](a: A)(implicit ct: ClassTag[A], toInt: nat.ToInt[L]): V[L, A]

    /** Left-to-right reduction of a vector with `fn`. */
    def reduce[A, L <: Nat](v: V[L, A], fn: (A, A) => A)(implicit diff: nat.Diff[L, Nat._1], toInt: nat.ToInt[L]): A
  }

  /** The only implementation of [[NewType]]; vectors are arrays of the matching length. */
  val T: NewType = new NewType {
    type V[N <: Nat, A] = Array[A]

    def get[A, L <: Nat, N <: Nat](v: V[L, A])(implicit diff: nat.Diff[L, Succ[N]], toIEv: nat.ToInt[N]): A =
      v(toIEv())

    // No copy is made: the returned vector is `as` itself, so later writes to `as` are visible.
    def unsafeFrom[A, L <: Nat](as: Array[A])(implicit toInt: nat.ToInt[L]): V[L, A] = {
      require(as.length == toInt(), s"length of Array is ${as.length} not ${toInt()}")
      as
    }

    // Copies at most L elements. When `as` is shorter than L the remaining slots keep the
    // default value of a freshly allocated array; when it is longer the excess is dropped.
    def copyFrom[A, L <: Nat](as: Array[A])(implicit toInt: nat.ToInt[L]): V[L, A] = {
      // Recover a ClassTag from the runtime array so the copy has the same component type.
      implicit val classA: ClassTag[A] = ClassTag(as.getClass.getComponentType)
      val res = new Array[A](toInt())
      Array.copy(as, 0, res, 0, math.min(as.length, res.length))
      res
    }

    def empty[A, L <: Nat](implicit ct: ClassTag[A], toIEv: nat.ToInt[L]): V[L, A] =
      new Array[A](toIEv())

    def map[A, B, L <: Nat](v: V[L, A])(fn: A => B)(implicit ct: ClassTag[B], toInt: nat.ToInt[L]): V[L, B] = {
      val out = empty[B, L]
      var idx = 0
      while (idx < out.length) {
        out(idx) = fn(v(idx))
        idx += 1
      }
      out
    }

    def map2[A, B, C, L <: Nat](v1: V[L, A], v2: V[L, B])(fn: (A, B) => C)(implicit ct: ClassTag[C], toInt: nat.ToInt[L]): V[L, C] = {
      val out = empty[C, L]
      var idx = 0
      while (idx < out.length) {
        out(idx) = fn(v1(idx), v2(idx))
        idx += 1
      }
      out
    }

    def fill[A, L <: Nat](a: A)(implicit ct: ClassTag[A], toInt: nat.ToInt[L]): V[L, A] =
      Array.fill(toInt())(a)

    def reduce[A, L <: Nat](v: V[L, A], fn: (A, A) => A)(implicit diff: nat.Diff[L, Nat._1], toInt: nat.ToInt[L]): A = {
      var idx = 0
      // we know the size >= 1
      var res = v(idx)
      idx += 1
      val sz = toInt()
      while (idx < sz) {
        res = fn(res, v(idx))
        idx += 1
      }
      res
    }
  }

  /** Method syntax for vectors; each method forwards to the same-named operation on `T`. */
  implicit class VectOps[L <: Nat, A](val v: T.V[L, A]) extends AnyVal {
    def map[B](fn: A => B)(implicit ct: ClassTag[B], toInt: nat.ToInt[L]): T.V[L, B] =
      T.map(v)(fn)

    def get[N <: Nat](implicit diff: nat.Diff[L, Succ[N]], toIEv: nat.ToInt[N]): A =
      T.get(v)

    def apply(n: Nat)(implicit diff: nat.Diff[L, Succ[n.N]], toIEv: nat.ToInt[n.N]): A =
      T.getIdx(v, n)

    def reduce(fn: (A, A) => A)(implicit diff: nat.Diff[L, Nat._1], toInt: nat.ToInt[L]): A =
      T.reduce(v, fn)
  }

  /** Convenience alias so callers can write `Vect.T[N, A]` for the vector type. */
  type T[N <: Nat, A] = T.V[N, A]
}
package gradfn
import shapeless.{HNil, :: => HCons, HList, Nat}
import shapeless.ops.nat.ToInt
import Nat.{_0, _1, _2}
import scala.reflect.ClassTag
import Vect.VectOps
/**
 * A mapping from `F[A]` to `G[A]` that works the same way for every element type `A`.
 *
 * @tparam F the type constructor being consumed
 * @tparam G the type constructor being produced
 */
trait NatTrans[F[_], G[_]] {

  /**
   * Turns an `F`-wrapped value into a `G`-wrapped value of the same element type.
   *
   * @param fn the value to transform
   * @return the transformed value
   */
  def apply[A](fn: F[A]): G[A]
}
/**
 * Witness that `HF` is an HList of `F[_]`-wrapped values and `HG` is the HList of the same
 * element types wrapped in `G[_]`, which lets a [[NatTrans]] be mapped over the whole list.
 */
trait Alignment[F[_], G[_], HF <: HList, HG <: HList] {

  /** Maps `nt` over each element of `fs`. */
  def apply(fs: HF)(nt: NatTrans[F, G]): HG

  /** The alignment in the opposite direction, from `HG` back to `HF`. */
  def reverse: Alignment[G, F, HG, HF]
}

object Alignment {

  /** The empty list is aligned with itself whatever `F` and `G` are. */
  implicit def nilAlign[F[_], G[_]]: Alignment[F, G, HNil, HNil] =
    new Alignment[F, G, HNil, HNil] {
      def apply(fs: HNil)(nt: NatTrans[F, G]) = HNil
      def reverse = nilAlign
    }

  /** Extends an alignment of the tails by one more aligned head element. */
  implicit def consAlign[A, F[_], G[_], FTail <: HList, GTail <: HList](
    implicit tailAlign: Alignment[F, G, FTail, GTail]
  ): Alignment[F, G, F[A] HCons FTail, G[A] HCons GTail] =
    new Alignment[F, G, F[A] HCons FTail, G[A] HCons GTail] {
      def apply(fs: F[A] HCons FTail)(nt: NatTrans[F, G]): G[A] HCons GTail = {
        // Head first, then the rest through the tail evidence.
        val newHead = nt(fs.head)
        val newTail = tailAlign(fs.tail)(nt)
        HCons(newHead, newTail)
      }
      def reverse = consAlign(tailAlign.reverse)
    }
}
/**
 * A type with a binary `plus` operation.
 *
 * @tparam A the carrier type (specialized for `Double`)
 */
trait Semi[@specialized(Double) A] {

  /** Combines `a` and `b`. */
  def plus(a: A, b: A): A
}

object Semi {

  /** `Double` under ordinary addition. */
  implicit val doubleSem: Semi[Double] =
    new Semi[Double] {
      def plus(a: Double, b: Double): Double = a + b
    }

  /** Combines `a` and `b` using whichever [[Semi]] instance is in implicit scope. */
  def plus[A](a: A, b: A)(implicit sg: Semi[A]): A = sg.plus(a, b)
}
/**
 * A [[Semi]] that additionally provides a `zero` value.
 *
 * @tparam A the carrier type (specialized for `Double`)
 */
trait Monoid[@specialized(Double) A] extends Semi[A] {

  /** The distinguished zero value of this monoid. */
  def zero: A
}

object Monoid {

  /** `Double` under addition, with `0.0` as zero. */
  implicit val doubleMonoid: Monoid[Double] =
    new Monoid[Double] {
      def zero: Double = 0.0
      def plus(a: Double, b: Double): Double = a + b
    }

  /** The zero of whichever [[Monoid]] instance is in implicit scope. */
  def zero[A](implicit sg: Monoid[A]): A = sg.zero

  /**
   * Element-wise monoid on length-`L` vectors: `zero` is a vector filled with `A`'s zero and
   * `plus` adds matching positions.
   */
  implicit def vectMonoid[A: Monoid: ClassTag, L <: Nat: ToInt]: Monoid[Vect.T[L, A]] =
    new Monoid[Vect.T[L, A]] {
      // Built once per instance and returned on every `zero` call.
      val zero: Vect.T[L, A] = Vect.T.fill[A, L](Monoid.zero[A])

      def plus(a: Vect.T[L, A], b: Vect.T[L, A]): Vect.T[L, A] =
        Vect.T.map2(a, b)((x, y) => Semi.plus(x, y))
    }
}
/**
 * A [[Monoid]] that additionally provides a binary `times` operation.
 *
 * @tparam A the carrier type (specialized for `Double`)
 */
trait SemiRing[@specialized(Double) A] extends Monoid[A] {

  /** Multiplies `a` by `b`. */
  def times(a: A, b: A): A
}

object SemiRing {

  /** `Double` with ordinary addition and multiplication. */
  implicit val doubleSemiRing: SemiRing[Double] =
    new SemiRing[Double] {
      def times(a: Double, b: Double): Double = a * b
      def plus(a: Double, b: Double): Double = a + b
      def zero: Double = 0.0
    }

  /** Multiplies `a` by `b` using whichever [[SemiRing]] instance is in implicit scope. */
  def times[A](a: A, b: A)(implicit sr: SemiRing[A]): A = sr.times(a, b)
}
/**
 * Values of type `C` that can be scaled by values of type `F`.
 *
 * @tparam F the scalar type; must form a [[SemiRing]]
 * @tparam A the algebraic structure required of `C` (for example [[Monoid]])
 * @tparam C the type being scaled
 */
trait Space[F, A[_], C] {

  /** The semiring structure of the scalars. */
  implicit def ring: SemiRing[F]

  /** The algebraic structure of the scaled values. */
  implicit def algebra: A[C]

  /** Scales `c` by `v`. */
  def scale(v: F, c: C): C
}

object Space {

  /**
   * Intended to lift a space on elements to a space on vectors of them.
   * Not implemented yet: resolving this implicit throws `NotImplementedError`.
   */
  implicit def spaceForVect[A, M <: Nat, R](
    implicit space: Space[A, Monoid, R]
  ): Space[Vect.T[M, A], Monoid, Vect.T[M, R]] = ???
}
/**
 * A function of `N` real arguments producing an `R`, together with its gradient.
 *
 * @tparam N the number of `Double` arguments
 * @tparam R the result type (specialized for `Double`)
 */
sealed abstract class FnGen[N <: Nat, @specialized(Double) R] {

  /** Evaluates the function at `args`. */
  def apply(args: Vect.T[N, Double]): R

  /** One function per argument position, each of the same type as this function. */
  def grad: Vect.T[N, FnGen[N, R]]
}
/**
 * Constructors and combinators for [[FnGen]]. Each case class is one node of an expression
 * tree that knows how to evaluate itself (`apply`) and how to build its gradient (`grad`).
 */
object FnGen {

  /** Monoid on functions: `zero` is the [[Zero]] function and `plus` builds an [[Add]] node. */
  implicit def fnMonoid[R: Monoid, L <: Nat: ToInt]: Monoid[FnGen[L, R]] =
    new Monoid[FnGen[L, R]] {
      val zero: FnGen[L, R] = Zero[L, R]()
      def plus(a: FnGen[L, R], b: FnGen[L, R]): FnGen[L, R] =
        Add(a, b)
    }

  /**
   * Intended to lift a space on results to a space on functions returning them.
   * Not implemented yet: resolving this implicit throws `NotImplementedError`.
   */
  implicit def spaceForFn[A, M <: Nat, R](implicit space: Space[A, Monoid, R]): Space[A, Monoid, FnGen[M, R]] = ???

  /**
   * A one-argument function given directly by `f`, whose derivative is produced by `g`.
   * `g` is a thunk and `grad` is lazy, so the derivative is only built on first access.
   */
  case class Instance1(f: Double => Double, g: () => FnGen[_1, Double]) extends FnGen[_1, Double] {
    def apply(args: Vect.T[_1, Double]) = f(args.get[_0])
    lazy val grad = Vect.T.fill(g())
  }

  /** Builds an [[Instance1]]; the by-name `g` allows self-referential or endless derivative chains. */
  def instance1(f: Double => Double, g: => FnGen[_1, Double]): FnGen[_1, Double] =
    Instance1(f, () => g)

  /** The constant function `_ => x`; its derivative is the constant function `0.0`. */
  def const1(x: Double): FnGen[_1, Double] =
    instance1(_ => x, const1(0.0))

  /** The identity function; its derivative is the constant function `1.0`. */
  val ident1: FnGen[_1, Double] =
    instance1(x => x, const1(1.0))

  /**
   * The power function `t => t^x`.
   *
   * The derivative is `x * t^(x - 1)`. The coefficient `x` is carried by [[scaledPow]],
   * so that higher derivatives accumulate the right factors as well.
   *
   * @param x the exponent
   */
  def pow(x: Double): FnGen[_1, Double] =
    if (x == 0.0) const1(1.0)
    else if (x == 1.0) ident1
    else instance1(math.pow(_, x), scaledPow(x, x - 1.0))

  /**
   * The function `t => c * t^k`, with derivative `(c * k) * t^(k - 1)`.
   * Once the coefficient reaches zero the chain collapses to the constant zero function.
   */
  private def scaledPow(c: Double, k: Double): FnGen[_1, Double] =
    if (c == 0.0) const1(0.0)
    else instance1(t => c * math.pow(t, k), scaledPow(c * k, k - 1.0))

  /** The function that always returns the monoid zero of `R`; every gradient entry is itself. */
  case class Zero[N <: Nat: ToInt, R: Monoid]() extends FnGen[N, R] {
    def apply(args: Vect.T[N, Double]) = Monoid.zero
    // Lazy because the gradient refers back to this instance.
    lazy val grad = Vect.T.fill(this)
  }

  /** Copies its single argument into every slot of a length-`N` vector. */
  case class Fill[N <: Nat: ToInt]() extends FnGen[_1, Vect.T[N, Double]] {
    def apply(args: Vect.T[_1, Double]) = Vect.T.fill(args.get[_0])
    val grad = Vect.T.fill(Const(Vect.T.fill(1.0)))
  }

  /** The function that ignores its arguments and returns `result`; its gradient is all [[Zero]]. */
  case class Const[N <: Nat: ToInt, R: Monoid](result: R) extends FnGen[N, R] {
    def apply(args: Vect.T[N, Double]) = result
    val grad = Vect.T.fill(Zero[N, R]())
  }

  /** Pointwise sum of two functions; the gradient is the entry-wise sum of the gradients. */
  case class Add[N <: Nat: ToInt, R: Semi](f1: FnGen[N, R], f2: FnGen[N, R]) extends FnGen[N, R] {
    def apply(args: Vect.T[N, Double]) = Semi.plus(f1(args), f2(args))
    lazy val grad = Vect.T.map2(f1.grad, f2.grad)(Add(_, _))
  }

  /** Evaluates every function in `fns` at the same arguments and collects the results. */
  case class All[N <: Nat, M <: Nat, R: ClassTag](fns: Vect.T[N, FnGen[M, R]]) extends FnGen[M, Vect.T[N, R]] {
    def apply(args: Vect.T[M, Double]) = Vect.T.map(fns)(_(args))
    // Not implemented yet: accessing `grad` throws `NotImplementedError`.
    lazy val grad = ???
    // implicit val vct = Vect.T.vclassTag[gradfn.FnGen[M,R], M]
    // Vect.T.map(fns) { fn => fn.grad }
    // }
  }

  // (xy)' = (x'y) + (xy')
  /** Product of `f1` and `f2`, where the value of `f1` scales the value of `f2` via `space`. */
  case class Multiply[N <: Nat: ToInt, R1, R2](f1: FnGen[N, R1], f2: FnGen[N, R2])(implicit space: Space[R1, Monoid, R2]) extends FnGen[N, R2] {
    def apply(args: Vect.T[N, Double]): R2 = space.scale(f1(args), f2(args))
    lazy val grad = Vect.T.map2(f1.grad, f2.grad) { (d1, d2) =>
      // Brings the Monoid[R2] into implicit scope, which `Add` requires.
      import space.algebra
      Add(Multiply(d1, f2), Multiply(f1, d2))
    }
  }

  /** `fn` scaled by the fixed value `const`; the gradient scales each entry of `fn.grad`. */
  case class Scale[N <: Nat: ToInt, R1, R2](const: R1, fn: FnGen[N, R2])(implicit space: Space[R1, Monoid, R2]) extends FnGen[N, R2] {
    def apply(args: Vect.T[N, Double]): R2 = space.scale(const, fn(args))
    lazy val grad = Vect.T.map(fn.grad)(Scale(const, _))
  }

  /**
   * Scales each component of `vector` by the matching argument and sums the results.
   * The gradient entry for position `i` is the constant function returning `vector(i)`.
   *
   * NOTE(review): `Vect.T.ifEmpty` and `Vect.T.map2reduce` are not declared by the
   * `Vect.NewType` shown in this file; confirm they exist in the `Vect` this is built against.
   */
  case class Dot[N <: Nat: ToInt, R: ClassTag: Monoid](vector: Vect.T[N, R])(implicit space: Space[Double, Monoid, R]) extends FnGen[N, R] {
    def apply(args: Vect.T[N, Double]): R =
      Vect.T.ifEmpty(vector, Monoid.zero) { implicit diff =>
        Vect.T.map2reduce(args, vector)(space.scale, space.algebra.plus(_, _))
      }
    lazy val grad = Vect.T.map(vector)(Const[N, R](_))
  }

  /**
   * Sums the components of the vector produced by `vectFn`.
   *
   * NOTE(review): relies on `Vect.T.ifEmpty`, which is not declared by the `Vect.NewType`
   * shown in this file; confirm it exists in the `Vect` this is built against.
   */
  case class Trace[M <: Nat, N <: Nat, R: Monoid](vectFn: FnGen[M, Vect.T[N, R]]) extends FnGen[M, R] {
    /** The sum of `vec`, or the monoid zero when `vec` is empty. */
    def sum[L <: Nat](vec: Vect.T[L, R]): R =
      Vect.T.ifEmpty(vec, Monoid.zero) { implicit diff =>
        vec.reduce(Semi.plus(_, _))
      }
    def apply(args: Vect.T[M, Double]): R = sum(vectFn(args))
    lazy val grad = vectFn.grad.map(Trace(_))
  }

  // grad(g(f(x)) = (grad(f) * (grad(g)(f(x)))
  /** The composition `f1(f2(args))`, differentiated with the chain rule. */
  case class Compose[M <: Nat: ToInt, N <: Nat: ToInt, R: ClassTag](
    f1: FnGen[N, R],
    f2: FnGen[M, Vect.T[N, Double]]
  )(implicit space: Space[Double, Monoid, R]) extends FnGen[M, R] {
    def apply(args: Vect.T[M, Double]): R = f1(f2(args))
    lazy val grad = {
      import space.algebra
      val gradf1 = f1.grad
      // Each partial derivative of f1, evaluated at f2(args).
      val g1f2: Vect.T[N, FnGen[M, R]] = gradf1.map(Compose[M, N, R](_, f2))
      val all: FnGen[M, Vect.T[N, R]] = All(g1f2)
      // For each input of f2: scale those partials by f2's derivative and sum over N.
      f2.grad.map { gf2: FnGen[M, Vect.T[N, Double]] =>
        Trace(Multiply(gf2, all))
      }
    }
  }
}
//
// FnGen[_1, FnGen[_1, R._1]]
// example [x] => [y] => [x*y]
// Vect.T[_1, FnGen[_1, FnGen[_1, R._1]]]
// grad: [[x] => [y] => [y]]
//
// FnGen[_2, FnGen[_2, R._4]]
// [x0, x1] => [y0, y1] => [x0*y0, x1*y0, x0*y1, x1*y1]
// Vect.T[_2, FnGen[_2, FnGen[_2, R._4]]]
// grad [[x0, x1] => [y0, y1] => [y0, 0, y1, 0]], [[x0, x1] => [y0, y1] => [0, y0, 0, y1]]
/**
 * Scalar-valued functions of `N` real arguments that carry their own gradient, plus the
 * usual differentiation rules expressed as combinators.
 */
object VectFn {

  /** A function from a length-`N` vector of doubles to a double, with one partial per argument. */
  sealed trait Fn[N <: Nat] {
    def apply(args: Vect.T[N, Double]): Double
    def grad: Vect.T[N, Fn[N]]
  }

  object Fn {

    /** Field structure on functions, delegating every operation to the combinators of [[VectFn]]. */
    implicit def fieldFn[N <: Nat: ToInt]: Field[Fn[N]] =
      new Field[Fn[N]] {
        val zero: Fn[N] = VectFn.zero[N]
        val one: Fn[N] = VectFn.one[N]

        def plus(a: Fn[N], b: Fn[N]): Fn[N] = VectFn.add(a, b)

        def negate(a: Fn[N]): Fn[N] = VectFn.negate(a)

        def times(a: Fn[N], b: Fn[N]): Fn[N] = VectFn.multiply(a, b)

        def inv(a: Fn[N]): Fn[N] = VectFn.compose(VectFn.pow(-1.0), a)

        def div(a: Fn[N], b: Fn[N]): Fn[N] = VectFn.divide(a, b)
      }
  }

  /** Functions of a single argument. */
  type Fn1 = Fn[_1]

  /**
   * A one-argument function defined by `f` whose derivative is `g`.
   * `g` is by-name and `grad` is lazy, so the derivative is only built on first access.
   */
  def instance1(f: Double => Double, g: => Fn[_1]): Fn[_1] =
    new Fn[_1] {
      def apply(args: Vect.T[_1, Double]): Double = {
        val x = args.get[_0]
        f(x)
      }

      lazy val grad: Vect.T[_1, Fn[_1]] = Vect.T.fill[Fn[_1], _1](g)
    }

  /** The identity function: x' = 1. */
  val identity: Fn1 =
    instance1(v => v, const[_1](1.0))

  /** The constant function returning `c`: c' = 0. */
  def const[N <: Nat](c: Double)(implicit toInt: ToInt[N]): Fn[N] =
    new Fn[N] {
      def apply(args: Vect.T[N, Double]): Double = c

      lazy val grad: Vect.T[N, Fn[N]] = Vect.T.fill[Fn[N], N](zero[N])
    }

  /** The constant function `0.0`. */
  def zero[N <: Nat](implicit toInt: ToInt[N]): Fn[N] = const[N](0.0)

  /** The constant function `1.0`. */
  def one[N <: Nat](implicit toInt: ToInt[N]): Fn[N] = const[N](1.0)

  /** Pointwise negation; the gradient is the negated gradient of `f1`. */
  def negate[N <: Nat: ToInt](f1: Fn[N]): Fn[N] =
    new Fn[N] {
      def apply(args: Vect.T[N, Double]): Double = -f1(args)

      lazy val grad = Vect.T.map(f1.grad)(d => negate(d))
    }

  /** Sum rule: (x + y)' = x' + y'. */
  def add[N <: Nat: ToInt](f1: Fn[N], f2: Fn[N]): Fn[N] =
    new Fn[N] {
      def apply(args: Vect.T[N, Double]): Double = f1(args) + f2(args)

      lazy val grad = Vect.T.map2(f1.grad, f2.grad) { (d1, d2) => add(d1, d2) }
    }

  /** Difference, expressed through the sum: (x - y)' = (x + (-y))'. */
  def subtract[N <: Nat: ToInt](f1: Fn[N], f2: Fn[N]): Fn[N] = {
    val negated = negate(f2)
    add(f1, negated)
  }

  /** Product rule: (xy)' = (x'y) + (xy'). */
  def multiply[N <: Nat: ToInt](f1: Fn[N], f2: Fn[N]): Fn[N] =
    new Fn[N] {
      def apply(args: Vect.T[N, Double]): Double = f1(args) * f2(args)

      lazy val grad = Vect.T.map2(f1.grad, f2.grad) { (d1, d2) =>
        val left = multiply(d1, f2)
        val right = multiply(f1, d2)
        add(left, right)
      }
    }

  /** Quotient, expressed through the product: (x/y)' = (x * (1/y))'. */
  def divide[N <: Nat: ToInt](f1: Fn[N], f2: Fn[N]): Fn[N] = {
    val reciprocal = compose(pow(-1.0), f2)
    multiply(f1, reciprocal)
  }

  /** Chain rule: f1(f2(x))' = f1'(f2(x)) * f2'(x). */
  def compose[N <: Nat: ToInt](f1: Fn1, f2: Fn[N]): Fn[N] =
    new Fn[N] {
      def apply(args: Vect.T[N, Double]): Double = {
        val inner = f2(args)
        f1(Vect.T.fill[Double, _1](inner))
      }

      lazy val grad = {
        // f1'(f2(x)): the single partial of f1, composed with f2.
        val outerDerivative = compose(f1.grad.get[_0], f2)
        f2.grad.map(d => multiply(outerDerivative, d))
      }
    }

  /** Power rule: (x^k)' = k * x^(k - 1). */
  def pow(k: Double): Fn1 =
    if (k == 0.0) one[_1]
    else if (k == 1.0) identity
    else {
      // Kept as a `def` so it stays unevaluated until `instance1` forces its by-name argument;
      // a strict `val` would recurse without end for non-integer exponents.
      def derivative: Fn1 = multiply(const[_1](k), pow(k - 1))
      instance1(x => math.pow(x, k), derivative)
    }

  /** Natural logarithm: log(x)' = 1/x. */
  val log: Fn1 =
    instance1(x => math.log(x), pow(-1.0))

  /** Exponential: (e^x)' = e^x. Lazy because it is its own derivative. */
  lazy val exp: Fn1 =
    instance1(x => math.exp(x), exp)
}
/**
 * The operations of a field over `A`: addition and multiplication with their identities
 * and inverses.
 *
 * @tparam A the carrier type
 */
trait Field[A] {

  /** The additive identity. */
  def zero: A

  /** The multiplicative identity. */
  def one: A

  /** Addition. */
  def plus(a: A, b: A): A

  /** Additive inverse. */
  def negate(a: A): A

  /** Multiplication. */
  def times(a: A, b: A): A

  /** Multiplicative inverse. */
  def inv(a: A): A

  /** Division of `a` by `b`. */
  def div(a: A, b: A): A
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment