@Blaisorblade
Last active August 29, 2015 14:06
Trying to implement a typed tree transform ([T]Exp[T] => Exp[∆[T]]) in Scala, where ∆[_] is a type function, with some successes
// Also at https://gist.github.com/Blaisorblade/1622b0809effb9a56061
package ilc
/**
* Incremental lambda calculus - an attempt at a *typed* Scala implementation.
* Should you want more details on the particular transform, see http://inc-lc.github.io/ and our
* PLDI paper at http://www.informatik.uni-marburg.de/~pgiarrusso/papers/pldi14-ilc-author-final.pdf
* — the transform is in Fig. 4(g).
*
* However, I **quickly** wrote a minimal introduction here so that you can make sense of the code.
* I am not really explaining why this is useful — there's the paper for it.
*
* The goal is to turn a program `t : σ → τ` into its incremental version/derivative `Derive(t)`,
* which takes an input value and its change to an output change.
* Hence `Derive(t)` has type `∆ (σ → τ) = σ → ∆σ → ∆τ`. More generally,
* you can say `Derive`'s type is `[T] T => ∆T`.
*
* (I won't explain here what happens to the context, because the HOAS representation of functions
* hides that away anyway, so contexts don't show up in the code. The type I gave is true for closed terms).
*
* At the same time, the derivative has the type of a change (because it is a nil change, but
* I won't try explaining that here).
*
* This transform takes a term t : T to its nil change/derivative (it's the same) Derive(t) : ∆T.
* ∆T is the type of changes to T, and it's a type function. On functions it's:
*
* ∆ (σ → τ) = σ → ∆σ → ∆τ
*
* Here's the transform on terms:
*
* Derive(λx. t) = λx dx. Derive(t)
* Derive(s t) = Derive(s) t Derive(t)
* Derive(x) = dx
*
* You also need extra cases for base types and constants. Below are examples for integers and arithmetic.
*
* Tested with Scala 2.11.0-M8 and 2.11.2; 2.10.x would unsoundly accept some code that is rejected here (see below).
* You can even run it to observe the transform in action!
*/
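/*
 * A minimal *untyped* sketch of the three Derive rules above, plus the
 * extra cases for integer constants and addition. This is a hypothetical
 * illustration, not the typed encoding developed below: terms are a plain
 * first-order ADT with string variable names (UntypedDeriveSketch, Term,
 * V, Lam, Ap, Lit and Add are all names made up for this sketch), so it only
 * shows how the rules rewrite syntax, not how their types line up.
 * For example, show(derive(Lam("x", V("x")))) gives "(λx. (λdx. dx))".
 */
object UntypedDeriveSketch {
  sealed trait Term
  final case class V(name: String) extends Term
  final case class Lam(x: String, body: Term) extends Term
  final case class Ap(fun: Term, arg: Term) extends Term
  final case class Lit(n: Int) extends Term
  final case class Add(a: Term, b: Term) extends Term

  def derive(t: Term): Term = t match {
    case V(x)         => V("d" + x)                          // Derive(x)     = dx
    case Lam(x, body) => Lam(x, Lam("d" + x, derive(body)))  // Derive(λx. t) = λx dx. Derive(t)
    case Ap(s, u)     => Ap(Ap(derive(s), u), derive(u))     // Derive(s t)   = Derive(s) t Derive(t)
    case Lit(_)       => Lit(0)                              // a constant's nil change is 0
    case Add(a, b)    => Add(derive(a), derive(b))           // + is linear
  }

  // Fully parenthesized pretty-printer, just for inspecting results.
  def show(t: Term): String = t match {
    case V(x)         => x
    case Lam(x, body) => s"(λ$x. ${show(body)})"
    case Ap(s, u)     => s"(${show(s)} ${show(u)})"
    case Lit(n)       => n.toString
    case Add(a, b)    => s"(${show(a)} + ${show(b)})"
  }
}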
/*
* We represent the universe of simple types in Scala at the type level.
* Implementation inspired by http://apocalisp.wordpress.com/2010/06/13/type-level-programming-in-scala-part-3-boolean/.
*/
sealed trait Type {
type Eval
type DT <: Type
}
sealed trait BaseType[BaseT] extends Type {
type Eval = BaseT
}
sealed trait BaseInt extends BaseType[Int] {
type DT = BaseInt
}
/*
* I tried adding subtyping here using Scala's subtyping, but it does not work.
* Eval uses BaseS and BaseT invariantly.
* Note that this would have been (unsoundly) accepted from 2.8.x to 2.10.x.
* To correctly encode subtyping, the next step might be to also embed
* subtyping judgements for lambda_<:, but this is not really practical
* *for integration with LMS*. However, it would still allow showing that derive
* works for lambda_sub, especially with a meta-interpreter of subtyping judgements.
*/
sealed trait =>:[SPar <: Type, TPar <: Type] extends Type {
type S = SPar
type T = TPar
type BaseS = S#Eval
type BaseT = T#Eval
type Eval = BaseS => BaseT
type DT = S =>: S#DT =>: T#DT
}
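/*
 * A compile-time sanity check of the type function above, sketched under
 * the assumption that it is compiled with Scala 2.x like the rest of this
 * file (Scala 3 rejects type projections such as S#DT on abstract types).
 * It restates a trimmed copy of the universe under hypothetical names
 * (TypeUniverseSketch, TyU, IntU, ->:) so that it stands alone, then asks
 * the compiler to confirm that ∆ (Int → Int) = Int → ∆Int → ∆Int = Int → Int → Int.
 */
object TypeUniverseSketch {
  sealed trait TyU { type Eval; type DT <: TyU }
  sealed trait IntU extends TyU { type Eval = Int; type DT = IntU }
  // Like =>:, ->: ends in a colon, so it is right-associative in types.
  sealed trait ->:[S <: TyU, T <: TyU] extends TyU {
    type Eval = S#Eval => T#Eval
    type DT = S ->: S#DT ->: T#DT
  }

  // Compiles only if the type-level ∆ reduces as expected.
  val dtCheck = implicitly[(IntU ->: IntU)#DT =:= (IntU ->: IntU ->: IntU)]

  // The derivative of x => x + 1, typed via the projection: it ignores x
  // and returns the input change dx unchanged.
  val dinc: (IntU ->: IntU)#DT#Eval = (x: Int) => (dx: Int) => dx
}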
//A first attempt at writing the type.
//def derive[T <: Type](v: T#Eval): T#DT#Eval = ???
//Cool! It worked! Let's move on to the real thing.
// A typeclass for (erased) change structures.
trait ΔBase[ReprT <: Type] {
type T = ReprT#Eval
type DT = ReprT#DT#Eval
def ⊕(t: T, dt: DT): T
def ⊖(tNew: T, tOld: T): DT
def ∘(dt1: DT, dt2: DT): DT
}
object Ops {
implicit class InfixChangeValueOps[ReprT <: Type](dt: ReprT#DT#Eval)(implicit Δt: ΔBase[ReprT]) {
type DT = ReprT#DT#Eval
def ∘(dt2: DT): DT = Δt.∘(dt, dt2)
}
implicit class InfixBaseValueOps[ReprT <: Type](t: ReprT#Eval)(implicit Δt: ΔBase[ReprT]) {
type T = ReprT#Eval
type DT = ReprT#DT#Eval
def ⊕(dt: DT): T = Δt.⊕(t, dt)
def ⊖(tOld: T): DT = Δt.⊖(t, tOld)
}
}
import Ops._
class ΔInt extends ΔBase[BaseInt] {
def ⊕(t: T, dt: DT): T = t + dt
def ⊖(tNew: T, tOld: T): DT = tNew - tOld
def ∘(dt1: DT, dt2: DT): DT = dt1 + dt2
}
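/*
 * A quick value-level check of the operations ΔInt implements, restated on
 * plain Ints so the snippet stands alone (IntChangeSketch is a hypothetical
 * name used only here). It exercises the property ⊕ and ⊖ are meant to
 * satisfy, old ⊕ (new ⊖ old) == new, and checks that ∘ turns two
 * consecutive changes into one equivalent change.
 */
object IntChangeSketch {
  def ⊕(t: Int, dt: Int): Int = t + dt
  def ⊖(tNew: Int, tOld: Int): Int = tNew - tOld
  def ∘(dt1: Int, dt2: Int): Int = dt1 + dt2

  val old = 10
  val updated = 17
  val dt = ⊖(updated, old)          // the change from old to updated
  val roundTrip = ⊕(old, dt)        // should equal updated
  val twoSteps = ⊕(⊕(old, 3), 4)    // apply +3, then +4
  val composed = ⊕(old, ∘(3, 4))    // apply the composed change once
}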
class ΔFun[ReprS <: Type, ReprT <: Type](implicit Δs: ΔBase[ReprS], Δt: ΔBase[ReprT]) extends ΔBase[ReprS =>: ReprT] {
def ⊕(t: T, dt: DT): T =
x => t(x) ⊕ dt(x)(x ⊖ x)
def ⊖(tNew: T, tOld: T): DT =
x => dx => tNew(x ⊕ dx) ⊖ tOld(x)
def ∘(dt1: DT, dt2: DT): DT =
// ??? Is this definition correct?
x => dx => dt1(x)(dx) ∘ dt2(x)(dx)
}
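/*
 * The same round-trip check for function changes, specialised by hand to
 * Int => Int (FunChangeSketch is a hypothetical name). A function change
 * maps an input and an input change to an output change; ⊖ and ⊕ below
 * transcribe ΔFun's definitions, with Int arithmetic in place of the
 * element-level change structures and 0 in place of the nil change `x ⊖ x`.
 * The composition ∘ is deliberately left out.
 */
object FunChangeSketch {
  type F = Int => Int
  type DF = Int => Int => Int

  // fNew ⊖ fOld = x => dx => fNew(x ⊕ dx) ⊖ fOld(x)
  def ⊖(fNew: F, fOld: F): DF = x => dx => fNew(x + dx) - fOld(x)
  // f ⊕ df = x => f(x) ⊕ df(x)(nil change)
  def ⊕(f: F, df: DF): F = x => f(x) + df(x)(0)

  val fOld: F = x => 2 * x
  val fNew: F = x => x * x
  val df: DF = ⊖(fNew, fOld)
  val patched: F = ⊕(fOld, df)  // should behave like fNew
}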
//Theoretically useful, but I never needed this.
//type Δ[T <: Type, DTP] = ΔBase[T] { type DT = DTP }
//Can we synthesize an instance by looking at the type? That would require macros. Luckily, we might not need that.
//def f[T <: Type]: ΔBase[T] = ???
trait Exp[TP <: Type] {
type T = TP
def derive: Exp[T#DT]
}
final case class Num(t: Int) extends Exp[BaseInt] {
def derive: Exp[T#DT] = Num(0)
}
case class Plus(a: Exp[BaseInt], b: Exp[BaseInt]) extends Exp[BaseInt] {
def derive: Exp[T#DT] = Plus(a.derive, b.derive)
}
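/*
 * Why Plus(a, b).derive can simply be Plus(a.derive, b.derive): integer
 * changes are differences and addition is linear, so
 *   ((a + da) + (b + db)) - (a + b) == da + db
 * for every input and change (this also holds under Int wraparound).
 * PlusDerivativeSketch, plus and dplus are hypothetical names; the object
 * just checks that identity exhaustively on a small range of values.
 */
object PlusDerivativeSketch {
  def plus(a: Int, b: Int): Int = a + b
  // Derivative of plus: it ignores the base values and adds the changes.
  def dplus(a: Int, da: Int, b: Int, db: Int): Int = da + db

  val samples =
    for { a <- -2 to 2; da <- -2 to 2; b <- -2 to 2; db <- -2 to 2 } yield (a, da, b, db)

  val holds: Boolean = samples.forall { case (a, da, b, db) =>
    plus(a + da, b + db) - plus(a, b) == dplus(a, da, b, db)
  }
}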
case class App[S <: Type, T <: Type](fun: Exp[S =>: T], arg: Exp[S]) extends Exp[T] {
def derive: Exp[T#DT] =
App(App(fun.derive, arg), arg.derive)
}
trait Name {
def name: String
def derive: Name = DerivedName(this)
//XXX hack, we'd need a proper pretty-printer, but that's so much boilerplate that
//I won't bother for now.
override def toString = name
}
case class BaseName(name: String) extends Name
case class DerivedName(base: Name) extends Name {
def name = "d" + base.name
}
case class Var[T <: Type](name: Name) extends Exp[T] {
def derive: Var[T#DT] = Var(name.derive)
}
case class Fun[SPar <: Type, UPar <: Type](v: Var[SPar], body: Exp[UPar]) extends Exp[SPar =>: UPar] {
type S = SPar
type U = UPar
def derive: Exp[T#DT] =
Fun(v, Fun(v.derive, body.derive))
}
//Don't write S, T as params, since T is shadowed inside.
// We also need to record the variable we want...
case class HOASFun[SPar <: Type, TPar <: Type](v: Var[SPar], fun: Exp[SPar] => Exp[TPar]) extends Exp[SPar =>: TPar] {
def derive: Exp[T#DT] =
// ... so that we can specify it here. Without that, the code would
// typecheck, but pick the "wrong" variable when converting to a first-order
// representation, because derivation on variables is too "nominal" to be
// expressed in HOAS otherwise.
HOASFun(v, x => HOASFun(v.derive, dx => fun(x).derive))
override def toString = toFun.toString
def toFun: Exp[T] = Fun(v, fun(v))
}
trait HoasWrappers {
private var counter = -1
def fresh[S <: Type](): Var[S] = {
counter += 1
Var(BaseName("x" + counter))
}
def funBase[S <: Type, T <: Type](fun: Exp[S] => Exp[T]): Exp[S =>: T] =
HOASFun(fresh[S](), fun)
//Curried type application.
//Complete "signature":
// fun[S <: Type][T <: Type](fun: Exp[S] => Exp[T]): Exp[S =>: T] = funBase[S, T](fun)
//
//but typically, you only write:
// fun[S](x => body)
//and Exp[S] is used as type annotation for x.
def fun[S <: Type] = new CurriedFun[S]
class CurriedFun[S <: Type] {
def apply[T <: Type](fun: Exp[S] => Exp[T]) = funBase(fun)
}
}
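/*
 * The CurriedFun trick above, in isolation. Scala lets you pass either all
 * type arguments of a method explicitly or none, so `fun[S]` returns an
 * intermediate object whose `apply` then infers T from the lambda body.
 * A minimal sketch with hypothetical names (CurriedTypeAppSketch, Box,
 * lambdaBase, lambda, CurriedLambda), using a plain Box[A] instead of Exp.
 */
object CurriedTypeAppSketch {
  final case class Box[A](describe: String)

  def lambdaBase[S, T](f: Box[S] => Box[T]): Box[S => T] = {
    val body = f(Box[S]("x"))  // apply the HOAS body to a placeholder variable
    Box[S => T]("fn(" + body.describe + ")")
  }

  def lambda[S] = new CurriedLambda[S]
  class CurriedLambda[S] {
    def apply[T](f: Box[S] => Box[T]): Box[S => T] = lambdaBase(f)
  }

  // S = String is given explicitly; x gets type Box[String] and T = Int is inferred.
  val ex: Box[String => Int] = lambda[String](x => Box[Int]("len(" + x.describe + ")"))
}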
case class Fix[T <: Type](body: Exp[T =>: T]) extends Exp[T] {
def derive: Exp[T#DT] =
Fix(App(body.derive, this))
//equivalent to:
//Fix(App(body.derive, Fix(body)))
}
//XXX: not a real test.
object DeriveTest extends scala.App with HoasWrappers {
//The order of terms is vals first, defs after, to ensure fresh vars are
//generated in the same order as they appear.
//
//This can be useful for inspecting that fresh variable generation happens in
//the operationally expected way (since no reduction is done, no fresh
//variable should be generated and discarded).
val id = fun[BaseInt](x => x)
val power = fix[BaseInt =>: BaseInt =>: BaseInt](power => fun(n => fun(exp => n /*more interesting body needed*/)))
def ap[S <: Type, T <: Type] = fun[S =>: T](f => fun[S](arg => App(f, arg)))
def fix[T <: Type](body: Exp[T] => Exp[T]): Exp[T] = Fix(fun(body))
println(id)
println(id.derive)
println(power)
println(power.derive)
val apInt = ap[BaseInt, BaseInt]
println(apInt)
println(apInt.derive)
}
//Writing derive as a pattern-matching method exposes a bug in GADT type refinement. See BugReport.scala
trait TryExternalDerive {
//So we need unsafeCoerce here:
def unsafeCoerce[T <: Type, U <: Type](a: Exp[T]): Exp[U] = a.asInstanceOf[Exp[U]]
//That's what GADTs are translated to anyway, when Scalac manages. Luckily, we know better.
def deriveVar[T <: Type](v: Var[T]): Var[T#DT] = Var(v.name.derive)
def derive[T <: Type](term: Exp[T]): Exp[T#DT] = {
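term match {
//Note: Plus, HOASFun and Fix are not handled here and would throw a MatchError.
case Num(_) =>
//The derivative of a constant is its nil change 0, as in Num.derive.
unsafeCoerce(Num(0))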
case Fun(v, body) =>
unsafeCoerce(Fun(v, Fun(deriveVar(v), derive(body))))
case App(fun, arg) =>
App(App(derive(fun), arg), derive(arg))
case v: Var[_] =>
deriveVar(v)
}
}
}
/*
// A related bug report
object BugReportReduced {
trait Type {
type DT <: Type
}
trait BaseInt extends Type {
type DT = BaseInt
}
trait Exp[TP <: Type]
class Num extends Exp[BaseInt]
def derive[T <: Type](term: Exp[T]): Any = {
val res0: Exp[T] = term match { case _: Num => (??? : Exp[T#DT]) }
val res1: Exp[T#DT] = term match { case _: Num => (??? : Exp[T#DT]) }
res1: Exp[T#DT] // fails: the prefix of the type projection in the case body underwent the desired GADT refinement, but this is not reflected here.
val res1b = term match { case _ : Num => (new Num : Exp[T#DT]) } //works
res1b: Exp[T#DT] //fails
type X = T#DT
val res2 = term match { case _: Num => (??? : Exp[X]) }
res2: Exp[T#DT] // works
val res2b = term match { case _: Num => (new Num : Exp[X]) } //fails
res2b: Exp[T#DT] // this part works
}
def derive2[T <: Type](term: Exp[T]): Exp[T] = {
term match {
case _ : Num => (??? : Exp[BaseInt])
case _ : Num => (??? : Exp[T])
}
}
def derive2Harder[T <: Type](term: Exp[T]): Exp[T] = {
val res = term match {
//case _ : Num => (??? : Exp[BaseInt])
case _ : Num => (??? : Exp[T]) //also broken!
}
res
}
def derive3[T <: Type](term: Exp[T]): Exp[BaseInt] = {
term match {
case _ : Num => (??? : Exp[BaseInt])
case _ : Num => (??? : Exp[T])
}
}
def derive3Harder[T <: Type](term: Exp[T]): Exp[BaseInt] = {
val res = term match {
//case _ : Num => (??? : Exp[BaseInt])
case _ : Num => (??? : Exp[T])
}
res
}
}
*/