Last active
January 21, 2020 06:42
-
-
Save rxwei/edcbc7c06cff8d8fa47806974cdd976e to your computer and use it in GitHub Desktop.
Incremental computation with differentiable programming (background paper: https://arxiv.org/pdf/1312.0658.pdf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// To be shared on Swift Forums. | |
// Compile with: | |
// swiftc -Xllvm -enable-experimental-cross-file-derivative-registration -enable-experimental-forward-mode-differentiation main.swift | |
// MARK: - Make integers differentiable

/// Lets `Int` take part in differentiation by declaring it to be its own
/// tangent vector: a change to an `Int` is itself an `Int`.
extension Int: Differentiable {
  public typealias TangentVector = Int
}
extension Int {
  /// Forward-mode derivative (JVP) of `Int.+`.
  ///
  /// Addition is linear, so the differential adds the two input tangents.
  @derivative(of: +)
  @usableFromInline
  static func derivativeOfAdd(_ lhs: Int, _ rhs: Int) -> (value: Int, differential: (Int, Int) -> Int) {
    let differential: (Int, Int) -> Int = { lhsTangent, rhsTangent in
      return lhsTangent + rhsTangent
    }
    return (value: lhs + rhs, differential: differential)
  }

  /// Reverse-mode derivative (VJP) of `Int.+`.
  ///
  /// Each operand contributes to the sum with coefficient 1, so the pullback
  /// hands the incoming cotangent unchanged to both operands.
  @derivative(of: +)
  @usableFromInline
  static func derivativeOfAdd(_ lhs: Int, _ rhs: Int) -> (value: Int, pullback: (Int) -> (Int, Int)) {
    let pullback: (Int) -> (Int, Int) = { cotangent in
      return (cotangent, cotangent)
    }
    return (value: lhs + rhs, pullback: pullback)
  }
}
// MARK: - Add missing derivatives

extension Array where Element: Differentiable {
  /// Forward-mode derivative (JVP) of `differentiableReduce`, taken with
  /// respect to both the array and the initial result.
  ///
  /// The reduction is replayed element by element with
  /// `valueWithDifferential`, recording one differential per step. The
  /// returned differential chains those per-step differentials front to back,
  /// in the same order as the original reduction.
  ///
  /// - Parameters:
  ///   - initialResult: The starting accumulator value.
  ///   - nextPartialResult: The differentiable combining function.
  /// - Returns: The reduced value, and a differential mapping
  ///   `(arrayTangent, initialResultTangent)` to the tangent of the result.
  @derivative(of: Array.differentiableReduce, wrt: (self, initialResult))
  @usableFromInline
  func derivativeOfDifferentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> (
    value: Result,
    differential: (Array.TangentVector, Result.TangentVector) -> Result.TangentVector
  ) {
    var differentials: [(Result.TangentVector, Element.TangentVector) -> Result.TangentVector] = []
    differentials.reserveCapacity(count)
    var result = initialResult
    for element in self {
      let (y, df) = Swift.valueWithDifferential(at: result, element, in: nextPartialResult)
      result = y
      differentials.append(df)
    }
    return (value: result, differential: { dSelf, dInitial in
      let elementTangents = dSelf.base
      var dResult = dInitial
      for (index, df) in differentials.enumerated() {
        // A zero array tangent is an empty `base`, so `base` may be shorter
        // than the number of steps. Treat each missing element tangent as
        // `.zero` but still run the step: skipping it (as `zip` would) drops
        // the propagation of `dInitial` through the remaining steps.
        let dElement = index < elementTangents.count ? elementTangents[index] : .zero
        dResult = df(dResult, dElement)
      }
      return dResult
    })
  }

  /// Reverse-mode derivative (VJP) of `differentiableReduce`.
  ///
  /// The reduction is replayed with `valueWithPullback`, recording one pullback
  /// per element. The returned pullback runs those pullbacks back to front.
  /// It threads the result cotangent through every step and collects one
  /// cotangent per element. The collected element cotangents are reversed at
  /// the end to restore the original element order.
  ///
  /// - Parameters:
  ///   - initialResult: The starting accumulator value.
  ///   - nextPartialResult: The differentiable combining function.
  /// - Returns: The reduced value, and a pullback mapping the result cotangent
  ///   to `(arrayCotangent, initialResultCotangent)`.
  @usableFromInline
  @derivative(of: differentiableReduce)
  internal func _vjpDifferentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> (
    value: Result,
    pullback: (Result.TangentVector) -> (Array.TangentVector, Result.TangentVector)
  ) {
    var pullbacks: [(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] = []
    // Copied into a local so the escaping pullback captures only the count,
    // not the whole array.
    let count = self.count
    pullbacks.reserveCapacity(count)
    var result = initialResult
    for element in self {
      let (y, pb) =
        Swift.valueWithPullback(at: result, element, in: nextPartialResult)
      result = y
      pullbacks.append(pb)
    }
    return (value: result, pullback: { tangent in
      var resultTangent = tangent
      var elementTangents = TangentVector([])
      elementTangents.base.reserveCapacity(count)
      for pullback in pullbacks.reversed() {
        let (newResultTangent, elementTangent) = pullback(resultTangent)
        resultTangent = newResultTangent
        elementTangents.base.append(elementTangent)
      }
      return (TangentVector(elementTangents.base.reversed()), resultTangent)
    })
  }
}
// MARK: - Formalize distances

/// A differentiable type that can report the tangent-space change separating
/// one of its values from another.
protocol Distanceable: Differentiable {
  /// Returns the change leading from `self` to `destination`, expressed as a
  /// tangent vector. Presumably `self.move(along:)` with this value is meant
  /// to reach `destination` — conformers should confirm.
  func distance(to destination: Self) -> TangentVector
}
/// `Int` needs no new members here. As a `Strideable` type it already has
/// `distance(to:) -> Int` (returning `other - self`), which satisfies the
/// requirement because `Int` is its own tangent vector.
extension Int: Distanceable {
}
/// Arrays whose elements are their own tangent vectors measure distance
/// element-wise.
extension Array: Distanceable where Element: Differentiable, Element.TangentVector == Element {
  /// Returns the element-wise difference `other[i] - self[i]` as an array
  /// tangent vector.
  ///
  /// - Precondition: `other.count == count`. A tangent vector cannot express a
  ///   change in length, so arrays of different counts are rejected rather
  ///   than silently truncated to the shorter one (which `zip` alone would
  ///   do, yielding a wrong distance).
  func distance(to other: Array) -> TangentVector {
    precondition(
      other.count == count,
      "distance(to:) requires arrays of equal count, got \(count) and \(other.count)")
    return TangentVector(zip(other, self).map(-))
  }
}
// MARK: - Generic incrementalizer

/// Returns a closure that recomputes `body` at new inputs incrementally,
/// reusing the output computed once at `x`.
///
/// `body` is evaluated once, at `x`, together with its differential. The
/// returned closure does not run `body` again. For each new input it:
/// 1. measures the change from `x` with `distance(to:)`,
/// 2. pushes that change through the differential, and
/// 3. moves the cached output along the resulting tangent.
///
/// - Important: This is a first-order update. It equals `body(newX)` exactly
///   only when `body` is affine in its input (linear plus a constant, such as
///   `2 + x` or a sum). For other functions it is an approximation around `x`.
/// - Parameters:
///   - x: The base input at which `body` is evaluated once.
///   - body: The differentiable function to incrementalize. Its output only
///     needs to be `Differentiable`, since it is updated with `move(along:)`.
/// - Returns: A closure mapping a new input to the updated output.
func incrementalized<T: Distanceable, U: Differentiable>(
  at x: T, in body: @differentiable (T) -> U
) -> (T) -> U {
  let (baseOutput, differential) = valueWithDifferential(at: x, in: body)
  return { newX in
    let Δx = x.distance(to: newX)
    let Δy = differential(Δx)
    var output = baseOutput
    output.move(along: Δy)
    return output
  }
}
// MARK: - Examples

/// Returns `x` increased by two.
func twoPlus(_ x: Int) -> Int {
  return x + 2
}
// `twoPlus(_:)` incrementalized around the base input 3; evaluating it at 10
// moves the cached result 5 by the input change 7.
let incrementalizedTwoPlus = incrementalized(at: 3, in: twoPlus)
let twoPlusOfTen = incrementalizedTwoPlus(10)
print(twoPlusOfTen) // 12
/// Returns the total of `elements`.
///
/// Written with `differentiableReduce` so that the function can be
/// differentiated.
func sum(of elements: [Int]) -> Int {
  let start: Int = .zero
  return elements.differentiableReduce(start, +)
}
// `sum(of:)` incrementalized around [1, 2, 3, 4, 5] (cached total 15). The new
// input differs element-wise by [9, 18, 27, 36, 45], which adds 135.
let incrementalizedSum = incrementalized(at: [1, 2, 3, 4, 5], in: sum)
let updatedSum = incrementalizedSum([10, 20, 30, 40, 50])
print(updatedSum) // 150
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment