Last active
February 11, 2019 06:24
-
-
Save rxwei/061ef3c70be3adcfe9bea67369d44701 to your computer and use it in GitHub Desktop.
Function conforming to Differentiable
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// NOTE: The where clause is needed because of SR-9595.
/// A function from `T` to `U` bundled with its two derivative functions.
///
/// The declaration order of the stored properties is significant: the
/// synthesized memberwise initializer `Fn(original:jvp:vjp:)` takes its
/// arguments in this order.
struct Fn<T : Differentiable, U : Differentiable>
  where T.TangentVector : AdditiveArithmetic,
        T.CotangentVector : AdditiveArithmetic,
        U.TangentVector : AdditiveArithmetic,
        U.CotangentVector : AdditiveArithmetic {
  /// The plain function being wrapped.
  let original: (T) -> U

  /// Given an input, returns the value together with a differential that
  /// maps tangent vectors of `T` to tangent vectors of `U`.
  let jvp: (T) -> (value: U, differential: (T.TangentVector) -> U.TangentVector)

  /// Given an input, returns the value together with a pullback that maps
  /// cotangent vectors of `U` to cotangent vectors of `T`.
  let vjp: (T) -> (value: U, pullback: (U.CotangentVector) -> T.CotangentVector)
}
/// Conditional `Equatable` conformance for `Fn`.
///
/// NOTE(review): `==` reports inequality for every pair of operands, including
/// a value compared with itself, so this conformance is not reflexive.
/// Presumably it exists only to satisfy a protocol requirement -- confirm
/// before relying on equality of `Fn` values.
extension Fn : Equatable where U : Equatable {
  static func == (lhs: Fn, rhs: Fn) -> Bool {
    // The stored properties are closures, which Swift cannot compare.
    return false
  }
}
/// Pointwise additive arithmetic on differentiable functions.
///
/// `zero` is the constant-zero function. `+` and `-` combine two functions
/// pointwise: the value, the differential and the pullback of the result are
/// the sum (or difference) of the corresponding pieces of the operands.
///
/// NOTE(review): the derivative rules used for `+` and `-` assume that `U`'s
/// own `+`/`-` are linear with respect to its tangent/cotangent structure
/// (as they are for ordinary vector-like types) -- confirm for exotic `U`.
extension Fn : AdditiveArithmetic where U : AdditiveArithmetic {
  /// The function that maps every input to `U.zero`; both of its derivative
  /// functions map every vector to zero.
  static var zero: Fn {
    func orig(_ x: T) -> U { return .zero }
    func jvp(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
      return (orig(x), { _ in .zero })
    }
    func vjp(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
      return (orig(x), { _ in .zero })
    }
    return Fn(original: orig, jvp: jvp, vjp: vjp)
  }

  /// Returns the function `x -> lhs(x) + rhs(x)`.
  static func + (lhs: Fn, rhs: Fn) -> Fn {
    func sumOrig(_ x: T) -> U {
      return lhs.original(x) + rhs.original(x)
    }
    func sumJVP(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
      let (lhsValue, lhsDifferential) = lhs.jvp(x)
      let (rhsValue, rhsDifferential) = rhs.jvp(x)
      // The differential of a sum is the sum of the differentials.
      return (lhsValue + rhsValue, { v in lhsDifferential(v) + rhsDifferential(v) })
    }
    func sumVJP(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
      let (lhsValue, lhsPullback) = lhs.vjp(x)
      let (rhsValue, rhsPullback) = rhs.vjp(x)
      // The same cotangent flows back through both operands and accumulates.
      return (lhsValue + rhsValue, { v in lhsPullback(v) + rhsPullback(v) })
    }
    return Fn(original: sumOrig, jvp: sumJVP, vjp: sumVJP)
  }

  /// Returns the function `x -> lhs(x) - rhs(x)`.
  static func - (lhs: Fn, rhs: Fn) -> Fn {
    func differenceOrig(_ x: T) -> U {
      return lhs.original(x) - rhs.original(x)
    }
    func differenceJVP(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
      let (lhsValue, lhsDifferential) = lhs.jvp(x)
      let (rhsValue, rhsDifferential) = rhs.jvp(x)
      // The differential of a difference is the difference of the differentials.
      return (lhsValue - rhsValue, { v in lhsDifferential(v) - rhsDifferential(v) })
    }
    func differenceVJP(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
      let (lhsValue, lhsPullback) = lhs.vjp(x)
      let (rhsValue, rhsPullback) = rhs.vjp(x)
      // The subtrahend contributes with a negative sign.
      return (lhsValue - rhsValue, { v in lhsPullback(v) - rhsPullback(v) })
    }
    return Fn(original: differenceOrig, jvp: differenceJVP, vjp: differenceVJP)
  }
}
/// Makes `Fn` itself `Differentiable`.
///
/// The tangent, cotangent and all-differentiable-variables types of a
/// `Fn<T, U>` are again `Fn` values with the same input type `T`, whose output
/// type is the corresponding associated type of `U`.
extension Fn : Differentiable {
  typealias TangentVector = Fn<T, U.TangentVector>
  typealias CotangentVector = Fn<T, U.CotangentVector>
  typealias AllDifferentiableVariables = Fn<T, U.AllDifferentiableVariables>

  /// A function that evaluates this one and then projects the result to its
  /// `allDifferentiableVariables`. The differential and pullback are passed
  /// through unchanged.
  ///
  /// The setter is not implemented and traps with `fatalError()`.
  var allDifferentiableVariables: Fn<T, U.AllDifferentiableVariables> {
    get {
      func projectedOriginal(_ x: T) -> U.AllDifferentiableVariables {
        let y = original(x)
        return y.allDifferentiableVariables
      }
      func projectedJVP(_ x: T) -> (U.AllDifferentiableVariables, (T.TangentVector) -> U.TangentVector) {
        let result = jvp(x)
        return (result.value.allDifferentiableVariables, result.differential)
      }
      func projectedVJP(_ x: T) -> (U.AllDifferentiableVariables, (U.CotangentVector) -> T.CotangentVector) {
        let result = vjp(x)
        return (result.value.allDifferentiableVariables, result.pullback)
      }
      return Fn<T, U.AllDifferentiableVariables>(
        original: projectedOriginal,
        jvp: projectedJVP,
        vjp: projectedVJP
      )
    }
    set {
      fatalError()
    }
  }

  /// Returns a function whose value at `x` is this function's value at `x`
  /// moved along `direction`'s value at `x`.
  ///
  /// The differential and pullback of the result are those of `self`;
  /// `direction`'s own derivative functions are not consulted.
  func moved(along direction: Fn<T, U.TangentVector>) -> Fn<T, U> {
    func shiftedOriginal(_ x: T) -> U {
      let y = original(x)
      return y.moved(along: direction.original(x))
    }
    func shiftedJVP(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
      let result = jvp(x)
      return (result.value.moved(along: direction.original(x)), result.differential)
    }
    func shiftedVJP(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
      let result = vjp(x)
      return (result.value.moved(along: direction.original(x)), result.pullback)
    }
    return Fn<T, U>(original: shiftedOriginal, jvp: shiftedJVP, vjp: shiftedVJP)
  }

  /// Returns a function whose value at `x` is the tangent vector that this
  /// function's value at `x` derives from `cotangent`'s value at `x`.
  ///
  /// The differential and pullback of the result are those of `self`;
  /// `cotangent`'s own derivative functions are not consulted.
  func tangentVector(from cotangent: Fn<T, U.CotangentVector>) -> Fn<T, U.TangentVector> {
    func convertedOriginal(_ x: T) -> U.TangentVector {
      let y = original(x)
      return y.tangentVector(from: cotangent.original(x))
    }
    func convertedJVP(_ x: T) -> (U.TangentVector, (T.TangentVector) -> U.TangentVector) {
      let result = jvp(x)
      return (result.value.tangentVector(from: cotangent.original(x)), result.differential)
    }
    func convertedVJP(_ x: T) -> (U.TangentVector, (U.CotangentVector) -> T.CotangentVector) {
      let result = vjp(x)
      return (result.value.tangentVector(from: cotangent.original(x)), result.pullback)
    }
    return Fn<T, U.TangentVector>(
      original: convertedOriginal,
      jvp: convertedJVP,
      vjp: convertedVJP
    )
  }
}
/// Applies the wrapped function `f` to `x`.
///
/// - Returns: `f.original(x)`.
func apply<T, U>(_ f: Fn<T, U>, _ x: T) -> U {
  let result = f.original(x)
  return result
}
/// Forward-mode derivative of `apply` with respect to both the function and
/// its argument.
///
/// - Returns: the value reported by `f.jvp(x)` and a differential that, given
///   a tangent function `vf` and a tangent vector `vx`, yields
///   `vf.original(x) + df(vx)`, where `df` is `f`'s differential at `x`.
func applyJVP<T, U>(_ f: Fn<T, U>, _ x: T) -> (value: U, differential: (Fn<T, U.TangentVector>, T.TangentVector) -> U.TangentVector) {
  let primal = f.jvp(x)
  let df = primal.differential
  // Contribution from perturbing the function plus contribution from
  // perturbing the argument.
  let differential: (Fn<T, U.TangentVector>, T.TangentVector) -> U.TangentVector = { vf, vx in
    return vf.original(x) + df(vx)
  }
  return (primal.value, differential)
}
/// Reverse-mode derivative of `apply` with respect to both the function and
/// its argument.
///
/// - Returns: the value reported by `f.vjp(x)` and a pullback. Given a
///   cotangent `v`, the pullback returns a pair: a constant function that
///   always yields `v` (the cotangent for `f`), and `f`'s own pullback applied
///   to `v` (the cotangent for `x`).
///
/// NOTE(review): the second element of the returned tuple is labelled
/// `differential` although it is a pullback; the label is part of the
/// signature and is left untouched for compatibility.
func applyVJP<T, U>(_ f: Fn<T, U>, _ x: T) -> (value: U, differential: (U.CotangentVector) -> (Fn<T, U.CotangentVector>, T.CotangentVector)) {
  let (y, pb) : (U, (U.CotangentVector) -> T.CotangentVector) = f.vjp(x)

  func pullback(_ v: U.CotangentVector) -> (Fn<T, U.CotangentVector>, T.CotangentVector) {
    // The function cotangent is a constant function, so both of its
    // derivative functions map everything to zero.
    func constantOriginal(_: T) -> U.CotangentVector {
      return v
    }
    func constantJVP(_: T) -> (U.CotangentVector, (T.TangentVector) -> U.CotangentVector) {
      return (v, { _ in .zero })
    }
    func constantVJP(_: T) -> (U.CotangentVector, (U.TangentVector) -> T.CotangentVector) {
      return (v, { _ in .zero })
    }
    let functionCotangent = Fn<T, U.CotangentVector>(
      original: constantOriginal,
      jvp: constantJVP,
      vjp: constantVJP
    )
    return (functionCotangent, pb(v))
  }

  return (y, pullback)
}
// NOTE: The where clause is needed because of SR-9595.
/// A differentiable pair of two differentiable values.
///
/// Tangent and cotangent vectors are pairs of the components' tangent and
/// cotangent vectors, and every operation is carried out component-wise.
struct Pair<T : Differentiable, U : Differentiable> : Differentiable
  where T.TangentVector : AdditiveArithmetic,
        T.CotangentVector : AdditiveArithmetic,
        U.TangentVector : AdditiveArithmetic,
        U.CotangentVector : AdditiveArithmetic {
  typealias TangentVector = Pair<T.TangentVector, U.TangentVector>
  typealias CotangentVector = Pair<T.CotangentVector, U.CotangentVector>

  /// First component.
  var x: T
  /// Second component.
  var y: U

  /// Creates a pair from its two components (unlabelled, in order).
  init(_ x: T, _ y: U) {
    self.x = x
    self.y = y
  }

  /// Moves each component along the matching component of `d`.
  func moved(along d: TangentVector) -> Pair {
    let movedX = x.moved(along: d.x)
    let movedY = y.moved(along: d.y)
    return Pair(movedX, movedY)
  }

  /// Converts each component of `cotangent` with the matching component of
  /// this pair.
  func tangentVector(from cotangent: CotangentVector) -> TangentVector {
    let tangentX = x.tangentVector(from: cotangent.x)
    let tangentY = y.tangentVector(from: cotangent.y)
    return TangentVector(tangentX, tangentY)
  }
}

// Conditional conformances; both bodies are empty, so the requirements are
// left to the compiler to synthesize.
extension Pair : Equatable where T : Equatable, U : Equatable {}
extension Pair : AdditiveArithmetic where T : AdditiveArithmetic, U : AdditiveArithmetic {}
/// Curries a differentiable function of a `Pair<T, U>` into a function of `T`
/// that returns a differentiable function of `U`.
///
/// The inner function (with `x` fixed) is fully differentiable: its
/// differential feeds a zero tangent for the `x` component, and its pullback
/// keeps only the `y` component of `f`'s pullback.
///
/// The derivative functions of the outer function are NOT implemented: calling
/// the returned value's `jvp` or `vjp` traps with `fatalError()`.
func curry<T, U, V>(_ f: Fn<Pair<T, U>, V>) -> Fn<T, Fn<U, V>> {
  func partiallyApplied(_ x: T) -> Fn<U, V> {
    func value(_ y: U) -> V {
      return f.original(Pair(x, y))
    }
    // Derivatives with respect to `y` only; `x` is held constant.
    func valueJVP(_ y: U) -> (value: V, differential: (U.TangentVector) -> V.TangentVector) {
      let (z, fullDifferential): (V, (Pair<T, U>.TangentVector) -> V.TangentVector) = f.jvp(Pair(x, y))
      return (z, { dy in fullDifferential(Pair(.zero, dy)) })
    }
    func valueVJP(_ y: U) -> (value: V, pullback: (V.CotangentVector) -> U.CotangentVector) {
      let (z, fullPullback): (V, (V.CotangentVector) -> Pair<T, U>.CotangentVector) = f.vjp(Pair(x, y))
      return (z, { dz in fullPullback(dz).y })
    }
    return Fn<U, V>(original: value, jvp: valueJVP, vjp: valueVJP)
  }

  // Derivatives of the outer function: placeholders that always trap.
  func partiallyAppliedJVP(_ x: T) -> (Fn<U, V>, (T.TangentVector) -> Fn<U, V.TangentVector>) {
    fatalError()
  }
  func partiallyAppliedVJP(_ x: T) -> (Fn<U, V>, (Fn<U, V.CotangentVector>) -> T.CotangentVector) {
    fatalError()
  }

  return Fn<T, Fn<U, V>>(
    original: partiallyApplied,
    jvp: partiallyAppliedJVP,
    vjp: partiallyAppliedVJP
  )
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment