
@rxwei
Last active February 24, 2019 00:33
Currying differentiable functions in Swift. Full version with transpose here: https://gist.github.com/rxwei/1cfb98027f656adb1ebfa8af56826c97.
// Function-as-a-differentiable-type rule:
// Tangent space: ((T...) -> U...)' = Any
// Cotangent space: ((T...) -> U...)'* = Any
// Why? Because when a function value varies, what actually varies is its
// captured context.
// In general, this needs to be a constrained existential: `AdditiveArithmetic`
// for `.zero` and `+`, and `Differentiable` so that differentials and pullbacks
// can be transposed into each other.
// New associated function type calculation rules:
// original: (T...) -> (U...)
// jvp: (T...) -> (value: (U...), differential: (Any, T...') -> (U...'))
// vjp: (T...) -> (value: (U...), pullback: (U...'*) -> (Any, T...'*))
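func curry<T: Differentiable, U: Differentiable, V: Differentiable>(
    _ 𝑓: @escaping @differentiable (T, U) -> V
) -> @differentiable (T) -> @differentiable (U) -> V {
    // Outer function.
    let f: @differentiable (T) -> @differentiable (U) -> (V) = makeDifferentiable { x in
        // Inner function. It captures `x`, so `x` is its varying context.
        let g: @differentiable (U) -> V = makeDifferentiable { y in
            let (z, φ٭ᶻ) = valueWithPullback(at: x, y, in: 𝑓)
            // Pullback of `g`: the cotangent for the captured `x` is returned,
            // type-erased to `Any`, as the cotangent of `g`'s context.
            let φ٭ᵍ: (V.CotangentVector) -> (Any, U.CotangentVector) = { z̅ in
                let (x̅, y̅) = φ٭ᶻ(z̅)
                return (x̅ as Any, y̅)
            }
            return (value: z, pullback: φ٭ᵍ)
        }
        // Pullback of `f`: its own captured context (`𝑓`) is not differentiated,
        // so its context cotangent is `()`. The incoming `g̅` is `g`'s context
        // cotangent (the `x̅` produced above), so it is cast back to `T.CotangentVector`.
        let φ٭ᶠ: (Any) -> (Any, T.CotangentVector) = { g̅ in
            return ((), g̅ as! T.CotangentVector)
        }
        return (value: g, pullback: φ٭ᶠ)
    }
    return f
}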
// Turns the VJP for a thick function into a `@differentiable` function.
func makeDifferentiable<T: Differentiable, U: Differentiable>(
    from vjp: (T) -> (value: U, pullback: (U.CotangentVector) -> (Any, T.CotangentVector))
) -> @differentiable (T) -> U {
    // Placeholder: left unimplemented in this gist.
    fatalError()
}
// Turns the VJP for a thick function whose result is a `@differentiable` function
// into a `@differentiable` function.
func makeDifferentiable<T: Differentiable, U: Differentiable, V: Differentiable>(
    from vjp: (T) -> (value: @differentiable (U) -> V, pullback: (Any) -> (Any, T.CotangentVector))
) -> @differentiable (T) -> @differentiable (U) -> V {
    // Placeholder: left unimplemented in this gist.
    fatalError()
}
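The `curry` function above relies on `@differentiable` function types, and the gist's `makeDifferentiable` helpers are stubs that call `fatalError()`, so it cannot be exercised as written. The block below is a framework-free sketch of the same rule under stated assumptions. Every value is a `Double`, whose tangent and cotangent are also `Double`. VJPs and JVPs are plain closures. The inner function's pullback returns the cotangent of the captured `x` as `Any`, and the outer function's pullback casts it back. In forward mode, the outer differential turns `ẋ` into the inner function's context tangent. The names `BinaryVJP`, `BinaryJVP`, `ContextualFunction`, `CurriedFunction`, and `curry(vjp:jvp:)` are hypothetical and are not Swift for TensorFlow APIs.

```swift
// Framework-free sketch (hypothetical names): a function's (co)tangent is its
// type-erased captured context (`Any`). Assumes all values are `Double`.

// VJP of a binary function: value plus a pullback z̅ -> (x̅, y̅).
typealias BinaryVJP = (Double, Double) -> (value: Double, pullback: (Double) -> (Double, Double))
// JVP of a binary function: value plus a differential (ẋ, ẏ) -> ż.
typealias BinaryJVP = (Double, Double) -> (value: Double, differential: (Double, Double) -> Double)

// A unary function that may capture context: its pullback also returns the
// context cotangent, and its differential also takes the context tangent.
struct ContextualFunction {
    let vjp: (Double) -> (value: Double, pullback: (Double) -> (Any, Double))
    let jvp: (Double) -> (value: Double, differential: (Any, Double) -> Double)
}

// A curried function x ↦ g, where g's tangent/cotangent type is `Any`.
struct CurriedFunction {
    let vjp: (Double) -> (value: ContextualFunction, pullback: (Any) -> (Any, Double))
    let jvp: (Double) -> (value: ContextualFunction, differential: (Any, Double) -> Any)
}

func curry(vjp f: @escaping BinaryVJP, jvp df: @escaping BinaryJVP) -> CurriedFunction {
    // Builds g = f(x, ·); the captured `x` is g's context.
    func makeInner(_ x: Double) -> ContextualFunction {
        ContextualFunction(
            vjp: { y in
                let (z, pb) = f(x, y)
                return (value: z, pullback: { zBar in
                    let (xBar, yBar) = pb(zBar)
                    // The cotangent of the captured `x` becomes g's context cotangent.
                    return (xBar as Any, yBar)
                })
            },
            jvp: { y in
                let (z, diff) = df(x, y)
                // The context tangent is the tangent of the captured `x`.
                return (value: z, differential: { xDot, yDot in diff(xDot as! Double, yDot) })
            })
    }
    return CurriedFunction(
        vjp: { x in
            // The outer closure's own context is not differentiated, so its cotangent is `()`;
            // g's cotangent is its context cotangent, i.e. x̅.
            (value: makeInner(x), pullback: { gBar in (() as Any, gBar as! Double) })
        },
        jvp: { x in
            // Perturbing x perturbs g's context, so ẋ becomes g's tangent.
            (value: makeInner(x), differential: { _, xDot in xDot as Any })
        })
}

// Example: f(x, y) = x * y, so ∂f/∂x = y and ∂f/∂y = x.
let mulVJP: BinaryVJP = { x, y in (value: x * y, pullback: { zBar in (zBar * y, zBar * x) }) }
let mulJVP: BinaryJVP = { x, y in (value: x * y, differential: { xDot, yDot in y * xDot + x * yDot }) }
let curried = curry(vjp: mulVJP, jvp: mulJVP)

// Reverse mode: pull back through g, then feed g's context cotangent to f's pullback.
let (g, pullbackF) = curried.vjp(3)
let (z, pullbackG) = g.vjp(4)
let (contextBar, yBar) = pullbackG(1)   // yBar = x = 3
let (_, xBar) = pullbackF(contextBar)   // xBar = y = 4

// Forward mode: ẋ = 1 becomes g's context tangent; ẏ = 0.
let (g2, differentialF) = curried.jvp(3)
let gDot = differentialF((), 1)
let (_, differentialG) = g2.jvp(4)
let zDot = differentialG(gDot, 0)       // ż = y * ẋ = 4
print(z, xBar, yBar, zDot)
```

Erasing the context to `Any` lets a single function type describe closures with different captures. The cost is a runtime cast (`as!`) in the outer pullback, at the same point where the gist writes `g̅ as! T.CotangentVector`.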