@rxwei
Last active February 21, 2019 09:44
Currying differentiable functions (with pullback transpose).
// Function-as-a-differentiable-type rule:
// Tangent space: ((T...) -> U...)' = AnyDerivative
// Cotangent space: ((T...) -> U...)'* = AnyDerivative
// Why? Because when a function value varies, what varies is its context
// (the values it captures).
// In the general case, this needs to be a constrained existential with an
// `AdditiveArithmetic` conformance for its `.zero` and `+`, and a `Differentiable`
// conformance so that a differential can be transposed into a pullback and back.
// New associated function type calculation rules:
// original: (T...) -> (U...)
// jvp: (T...) -> (value: (U...), differential: @differentiable (AnyDerivative, T...') -> (U...'))
// vjp: (T...) -> (value: (U...), pullback: @differentiable (U...'*) -> (AnyDerivative, T...'*))
// We make `AnyDerivative` be `Any` for now, but it should really be a constrained existential
// to support use cases other than currying. Swift does not support generalized existentials yet.
// typealias AnyDerivative = Differentiable & AdditiveArithmetic
typealias AnyDerivative = Any
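func curry<T: Differentiable, U: Differentiable, V: Differentiable>(
    _ 𝑓: @escaping @differentiable (T, U) -> V
) -> @differentiable (T) -> @differentiable (U) -> V {
    // Outer function.
    let f: @differentiable (T) -> @differentiable (U) -> (V) = makeDifferentiable { x in
        // Inner function. It captures `x`, so `x` is its context.
        let g: @differentiable (U) -> V = makeDifferentiable { y in
            let (z, φ٭𝑓) = valueWithPullback(at: x, y, in: 𝑓)
            // Inner pullback. The cotangent with respect to the captured `x`
            // is returned in the context (`AnyDerivative`) slot.
            let φ٭ᵍ: (V.CotangentVector) -> (AnyDerivative, U.CotangentVector) = { z̅ in
                let (x̅, y̅) = φ٭𝑓(z̅)
                return (x̅ as AnyDerivative, y̅)
            }
            // Inner differential, transposed from pullback.
            let φᵍ: (AnyDerivative, U.TangentVector) -> V.TangentVector = { x̲, y̲ in
                // The pullback is linear, so its own pullback (taken at `.zero`)
                // is its transpose: (T.TangentVector, U.TangentVector) -> V.TangentVector.
                let φ𝑓 = pullback(at: .zero, in: φ٭𝑓)
                // The context tangent is a tangent of `x`, hence `T.TangentVector`.
                return φ𝑓(x̲ as! T.TangentVector, y̲)
            }
            return (value: z, differential: φᵍ, pullback: φ٭ᵍ)
        }
        // Outer pullback. The outer closure captures nothing that varies, so
        // its context cotangent is `()`. The cotangent of the result `g` is
        // the cotangent of `g`'s context, which is `x`.
        let φ٭ᶠ: (AnyDerivative) -> (AnyDerivative, T.CotangentVector) = { g̅ in
            ((), g̅ as! T.CotangentVector)
        }
        // Outer differential, transposed from pullback. Its second parameter
        // is a tangent, matching the `makeDifferentiable` signature.
        let φᶠ: (AnyDerivative, T.TangentVector) -> AnyDerivative = { f̲, x̲ in
            x̲ as AnyDerivative
        }
        return (value: g, differential: φᶠ, pullback: φ٭ᶠ)
    }
    return f
}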
// Stub: builds a differentiable function returning a differentiable value
// from a value/differential/pullback bundle. No implementation is given here.
func makeDifferentiable<T: Differentiable, U: Differentiable>(
    from bundle: (T) -> (value: U,
                         differential: @differentiable (AnyDerivative, T.TangentVector) -> U.TangentVector,
                         pullback: @differentiable (U.CotangentVector) -> (AnyDerivative, T.CotangentVector))
) -> @differentiable (T) -> U {
    fatalError()
}
// Stub: same as above, for a function whose result is itself a differentiable
// function, so the result's tangent and cotangent are `AnyDerivative`.
func makeDifferentiable<T: Differentiable, U: Differentiable, V: Differentiable>(
    from bundle: (T) -> (value: @differentiable (U) -> V,
                         differential: @differentiable (AnyDerivative, T.TangentVector) -> AnyDerivative,
                         pullback: @differentiable (AnyDerivative) -> (AnyDerivative, T.CotangentVector))
) -> @differentiable (T) -> @differentiable (U) -> V {
    fatalError()
}
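
// Illustrative sketch, not part of the design above: the inner-pullback rule
// from `curry`, written out by hand for `Double` with plain closures so that it
// runs without `@differentiable` support. The names `curriedWithPullbacks`,
// `multiplyWithPullback`, and the tuple labels `context` and `y` are
// hypothetical and exist only for this example. The assumption is that each
// function returns its value together with a pullback, and that the cotangent
// of the captured `x` travels in the context slot, as `AnyDerivative` does above.
func curriedWithPullbacks(
    _ f: @escaping (Double, Double) -> (value: Double, pullback: (Double) -> (Double, Double))
) -> (Double) -> (Double) -> (value: Double, pullback: (Double) -> (context: Double, y: Double)) {
    return { x in
        return { y in
            let (z, pullbackOfF) = f(x, y)
            // Pullback of the inner function: the cotangent of the captured `x`
            // goes in the context slot, the cotangent of `y` in the parameter slot.
            return (value: z, pullback: { zBar in
                let (xBar, yBar) = pullbackOfF(zBar)
                return (context: xBar, y: yBar)
            })
        }
    }
}

// Hand-written value and pullback for `x * y`.
func multiplyWithPullback(
    _ x: Double, _ y: Double
) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    return (value: x * y, pullback: { zBar in (zBar * y, zBar * x) })
}

// Usage:
//   let result = curriedWithPullbacks(multiplyWithPullback)(3)(4)
//   result.value        // 12.0
//   result.pullback(1)  // (context: 4.0, y: 3.0)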