Skip to content

Instantly share code, notes, and snippets.

@rxwei
Last active February 11, 2019 06:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rxwei/061ef3c70be3adcfe9bea67369d44701 to your computer and use it in GitHub Desktop.
Save rxwei/061ef3c70be3adcfe9bea67369d44701 to your computer and use it in GitHub Desktop.
Function conforming to Differentiable
// NOTE: The where clause is needed because of SR-9595.
/// A function from `T` to `U` stored together with two derivative functions.
///
/// The three members are independent stored closures; nothing in this type
/// checks that `jvp` and `vjp` are consistent with `original`.
struct Fn<T : Differentiable, U : Differentiable>
    where T.TangentVector : AdditiveArithmetic,
          U.TangentVector : AdditiveArithmetic,
          T.CotangentVector : AdditiveArithmetic,
          U.CotangentVector : AdditiveArithmetic {
    /// The plain function from `T` to `U`.
    let original: (T) -> U

    /// Given an input, returns a `value` and a `differential` that maps
    /// tangent vectors of `T` to tangent vectors of `U`.
    let jvp: (T) -> (value: U, differential: (T.TangentVector) -> U.TangentVector)

    /// Given an input, returns a `value` and a `pullback` that maps
    /// cotangent vectors of `U` to cotangent vectors of `T`.
    let vjp: (T) -> (value: U, pullback: (U.CotangentVector) -> T.CotangentVector)
}
/// Nominal `Equatable` conformance for `Fn`.
///
/// The operator ignores both operands and always answers `false`, so even a
/// value compared with itself is reported as unequal.
extension Fn : Equatable where U : Equatable {
    static func == (_: Fn, _: Fn) -> Bool {
        return false
    }
}
/// `AdditiveArithmetic` conformance for `Fn`, available when `U` is additive.
///
/// Every operation is defined pointwise: the resulting function evaluates the
/// operands at the same input and combines their results. The differentials
/// and pullbacks of the operands are combined with the same operator.
extension Fn : AdditiveArithmetic where U : AdditiveArithmetic {
    /// The function that maps every input to `U.zero`; its differential and
    /// pullback both return `.zero` for every argument.
    static var zero: Fn {
        func orig(_ x: T) -> U { return .zero }
        func jvp(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
            return (orig(x), { _ in .zero })
        }
        func vjp(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
            return (orig(x), { _ in .zero })
        }
        return Fn(original: orig, jvp: jvp, vjp: vjp)
    }

    /// Pointwise sum: `(lhs + rhs)(x) == lhs(x) + rhs(x)`.
    ///
    /// The differential (resp. pullback) of the sum is the sum of the
    /// operands' differentials (resp. pullbacks) at the same input.
    static func + (lhs: Fn, rhs: Fn) -> Fn {
        func orig(_ x: T) -> U {
            return lhs.original(x) + rhs.original(x)
        }
        func jvp(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
            let (value: ly, differential: ldf) = lhs.jvp(x)
            let (value: ry, differential: rdf) = rhs.jvp(x)
            return (ly + ry, { v in ldf(v) + rdf(v) })
        }
        func vjp(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
            let (value: ly, pullback: lpb) = lhs.vjp(x)
            let (value: ry, pullback: rpb) = rhs.vjp(x)
            return (ly + ry, { v in lpb(v) + rpb(v) })
        }
        return Fn(original: orig, jvp: jvp, vjp: vjp)
    }

    /// Pointwise difference: `(lhs - rhs)(x) == lhs(x) - rhs(x)`.
    ///
    /// The differential (resp. pullback) of the difference is the difference
    /// of the operands' differentials (resp. pullbacks) at the same input.
    static func - (lhs: Fn, rhs: Fn) -> Fn {
        func orig(_ x: T) -> U {
            return lhs.original(x) - rhs.original(x)
        }
        func jvp(_ x: T) -> (U, (T.TangentVector) -> U.TangentVector) {
            let (value: ly, differential: ldf) = lhs.jvp(x)
            let (value: ry, differential: rdf) = rhs.jvp(x)
            return (ly - ry, { v in ldf(v) - rdf(v) })
        }
        func vjp(_ x: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
            let (value: ly, pullback: lpb) = lhs.vjp(x)
            let (value: ry, pullback: rpb) = rhs.vjp(x)
            return (ly - ry, { v in lpb(v) - rpb(v) })
        }
        return Fn(original: orig, jvp: jvp, vjp: vjp)
    }
}
/// `Differentiable` conformance for `Fn`.
///
/// The associated vector types are again `Fn` values with the same input
/// type `T`, whose outputs are the corresponding vector types of `U`.
extension Fn : Differentiable {
    typealias TangentVector = Fn<T, U.TangentVector>
    typealias CotangentVector = Fn<T, U.CotangentVector>
    typealias AllDifferentiableVariables = Fn<T, U.AllDifferentiableVariables>

    /// A function that evaluates `self` and then projects the result to its
    /// `allDifferentiableVariables`. The differential and pullback of `self`
    /// are passed through unchanged.
    ///
    /// Assigning to this property is not supported: the setter traps.
    var allDifferentiableVariables: Fn<T, U.AllDifferentiableVariables> {
        get {
            func projected(_ input: T) -> U.AllDifferentiableVariables {
                let output = self.original(input)
                return output.allDifferentiableVariables
            }
            func projectedJVP(_ input: T) -> (U.AllDifferentiableVariables, (T.TangentVector) -> U.TangentVector) {
                let (value: output, differential: derivative) = self.jvp(input)
                return (output.allDifferentiableVariables, derivative)
            }
            func projectedVJP(_ input: T) -> (U.AllDifferentiableVariables, (U.CotangentVector) -> T.CotangentVector) {
                let (value: output, pullback: backward) = self.vjp(input)
                return (output.allDifferentiableVariables, backward)
            }
            return Fn<T, U.AllDifferentiableVariables>(
                original: projected,
                jvp: projectedJVP,
                vjp: projectedVJP
            )
        }
        set {
            fatalError()
        }
    }

    /// Returns a function whose value at `x` is `self`'s value at `x` moved
    /// along `direction`'s value at `x`.
    ///
    /// NOTE(review): the differential and pullback of `self` are reused as-is;
    /// any dependence of `direction` on the input is not differentiated.
    func moved(along direction: TangentVector) -> Fn {
        func shifted(_ input: T) -> U {
            let base = self.original(input)
            let offset = direction.original(input)
            return base.moved(along: offset)
        }
        func shiftedJVP(_ input: T) -> (U, (T.TangentVector) -> U.TangentVector) {
            let (value: base, differential: derivative) = self.jvp(input)
            let offset = direction.original(input)
            return (base.moved(along: offset), derivative)
        }
        func shiftedVJP(_ input: T) -> (U, (U.CotangentVector) -> T.CotangentVector) {
            let (value: base, pullback: backward) = self.vjp(input)
            let offset = direction.original(input)
            return (base.moved(along: offset), backward)
        }
        return Fn(original: shifted, jvp: shiftedJVP, vjp: shiftedVJP)
    }

    /// Returns a function whose value at `x` is the tangent vector that
    /// `self`'s value at `x` derives from `cotangent`'s value at `x`.
    ///
    /// The differential and pullback of `self` are passed through unchanged.
    func tangentVector(from cotangent: CotangentVector) -> TangentVector {
        func converted(_ input: T) -> U.TangentVector {
            let base = self.original(input)
            let covector = cotangent.original(input)
            return base.tangentVector(from: covector)
        }
        func convertedJVP(_ input: T) -> (U.TangentVector, (T.TangentVector) -> U.TangentVector) {
            let (value: base, differential: derivative) = self.jvp(input)
            let covector = cotangent.original(input)
            return (base.tangentVector(from: covector), derivative)
        }
        func convertedVJP(_ input: T) -> (U.TangentVector, (U.CotangentVector) -> T.CotangentVector) {
            let (value: base, pullback: backward) = self.vjp(input)
            let covector = cotangent.original(input)
            return (base.tangentVector(from: covector), backward)
        }
        return Fn<T, U.TangentVector>(
            original: converted,
            jvp: convertedJVP,
            vjp: convertedVJP
        )
    }
}
/// Applies the wrapped function of `f` to `x` and returns its result.
func apply<T, U>(_ f: Fn<T, U>, _ x: T) -> U {
    let function = f.original
    return function(x)
}
/// Forward-mode derivative of `apply` with respect to both the function and
/// its argument.
///
/// - Returns: the value reported by `f.jvp(x)`, and a differential taking a
///   tangent for the function (`vf`) and a tangent for the argument (`vx`).
///   It yields `vf` evaluated at `x` plus `f`'s own differential applied
///   to `vx`.
func applyJVP<T, U>(_ f: Fn<T, U>, _ x: T) -> (value: U, differential: (Fn<T, U.TangentVector>, T.TangentVector) -> U.TangentVector) {
    let (result, argumentDifferential) = f.jvp(x)
    func totalDifferential(vf: Fn<T, U.TangentVector>, vx: T.TangentVector) -> U.TangentVector {
        // Contribution of perturbing the function itself, evaluated at `x`.
        let functionTerm = vf.original(x)
        // Contribution of perturbing the argument.
        let argumentTerm = argumentDifferential(vx)
        return functionTerm + argumentTerm
    }
    return (result, totalDifferential)
}
/// Reverse-mode derivative of `apply` with respect to both the function and
/// its argument.
///
/// - Returns: the value reported by `f.vjp(x)`, and a pullback that maps a
///   cotangent `v` of the result to a pair: a cotangent for the function (a
///   constant function returning `v`) and a cotangent for the argument
///   (`f`'s own pullback applied to `v`).
///
/// NOTE(review): the second element of the returned tuple is labelled
/// `differential` although it is a pullback; the label is kept so existing
/// callers keep compiling.
func applyVJP<T, U>(_ f: Fn<T, U>, _ x: T) -> (value: U, differential: (U.CotangentVector) -> (Fn<T, U.CotangentVector>, T.CotangentVector)) {
    let (result, argumentPullback) : (U, (U.CotangentVector) -> T.CotangentVector) = f.vjp(x)
    func pullback(_ v: U.CotangentVector) -> (Fn<T, U.CotangentVector>, T.CotangentVector) {
        // The function cotangent ignores its input and always yields `v`,
        // so both of its derivative closures return zero.
        func constant(_: T) -> U.CotangentVector {
            return v
        }
        func constantJVP(_: T) -> (U.CotangentVector, (T.TangentVector) -> U.CotangentVector) {
            return (v, { _ in .zero })
        }
        func constantVJP(_: T) -> (U.CotangentVector, (U.TangentVector) -> T.CotangentVector) {
            return (v, { _ in .zero })
        }
        let functionCotangent = Fn<T, U.CotangentVector>(
            original: constant,
            jvp: constantJVP,
            vjp: constantVJP
        )
        let argumentCotangent = argumentPullback(v)
        return (functionCotangent, argumentCotangent)
    }
    return (result, pullback)
}
// NOTE: The where clause is needed because of SR-9595.
/// A differentiable pair of two differentiable values.
///
/// Its tangent and cotangent vectors are pairs of the components' tangent
/// and cotangent vectors, and both operations below act component-wise.
struct Pair<T : Differentiable, U : Differentiable> : Differentiable
    where T.TangentVector : AdditiveArithmetic,
          U.TangentVector : AdditiveArithmetic,
          T.CotangentVector : AdditiveArithmetic,
          U.CotangentVector : AdditiveArithmetic {
    typealias TangentVector = Pair<T.TangentVector, U.TangentVector>
    typealias CotangentVector = Pair<T.CotangentVector, U.CotangentVector>

    /// First component.
    var x: T
    /// Second component.
    var y: U

    /// Creates a pair from its two components, in order.
    init(_ x: T, _ y: U) {
        self.x = x
        self.y = y
    }

    /// Moves each component along the matching component of `direction`.
    func moved(along direction: TangentVector) -> Pair {
        let movedX = x.moved(along: direction.x)
        let movedY = y.moved(along: direction.y)
        return Pair(movedX, movedY)
    }

    /// Converts each component of `cotangent` using the matching component
    /// of `self`.
    func tangentVector(from cotangent: CotangentVector) -> TangentVector {
        let tangentX = x.tangentVector(from: cotangent.x)
        let tangentY = y.tangentVector(from: cotangent.y)
        return TangentVector(tangentX, tangentY)
    }
}
// Declares `Pair` as `Equatable` whenever both components are. The body is
// empty, so no `==` is written out here (presumably compiler-synthesized — confirm).
extension Pair : Equatable where T : Equatable, U : Equatable {}
// Declares `Pair` as `AdditiveArithmetic` whenever both components are. The
// body is empty, so `zero`, `+` and `-` are not written out here
// (presumably derived by the toolchain — TODO confirm).
extension Pair : AdditiveArithmetic where T : AdditiveArithmetic, U : AdditiveArithmetic {}
/// Turns a function of a `Pair<T, U>` into a function of `T` that returns a
/// function of `U`.
///
/// The inner function carries derivatives with respect to its `U` argument,
/// taken from `f`'s derivatives with the `T` component held fixed. The
/// derivatives of the outer function are not implemented and trap when
/// called.
func curry<T, U, V>(_ f: Fn<Pair<T, U>, V>) -> Fn<T, Fn<U, V>> {
    func partiallyApplied(x: T) -> Fn<U, V> {
        func evaluate(y: U) -> V {
            let point = Pair(x, y)
            return f.original(point)
        }
        // These fall out of differentiation.
        func evaluateJVP(y: U) -> (value: V, differential: (U.TangentVector) -> V.TangentVector) {
            let point = Pair(x, y)
            let (result, fullDifferential): (V, (Pair<T, U>.TangentVector) -> V.TangentVector) = f.jvp(point)
            // Only the second component is perturbed; the first gets zero.
            return (result, { tangent in fullDifferential(Pair(.zero, tangent)) })
        }
        func evaluateVJP(y: U) -> (value: V, pullback: (V.CotangentVector) -> U.CotangentVector) {
            let point = Pair(x, y)
            let (result, fullPullback): (V, (V.CotangentVector) -> Pair<T, U>.CotangentVector) = f.vjp(point)
            // Keep only the cotangent of the second component.
            return (result, { cotangent in fullPullback(cotangent).y })
        }
        return Fn<U, V>(original: evaluate, jvp: evaluateJVP, vjp: evaluateVJP)
    }
    // These fall out of differentiation.
    func partiallyAppliedJVP(x: T) -> (Fn<U, V>, (T.TangentVector) -> Fn<U, V.TangentVector>) {
        fatalError()
    }
    func partiallyAppliedVJP(x: T) -> (Fn<U, V>, (Fn<U, V.CotangentVector>) -> T.CotangentVector) {
        fatalError()
    }
    return Fn<T, Fn<U, V>>(
        original: partiallyApplied,
        jvp: partiallyAppliedJVP,
        vjp: partiallyAppliedVJP
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment