Skip to content

Instantly share code, notes, and snippets.

@rxwei
Last active January 21, 2020 06:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rxwei/edcbc7c06cff8d8fa47806974cdd976e to your computer and use it in GitHub Desktop.
Save rxwei/edcbc7c06cff8d8fa47806974cdd976e to your computer and use it in GitHub Desktop.
Incremental computation with differentiable programming https://arxiv.org/pdf/1312.0658.pdf
// To be shared on Swift Forums.
// Compile with:
// swiftc -Xllvm -enable-experimental-cross-file-derivative-registration -enable-experimental-forward-mode-differentiation main.swift
// MARK: - Make integers differentiable
/// Conforms `Int` to `Differentiable`, using `Int` itself as the tangent
/// (change) type, so a change to an integer is expressed as another integer.
// NOTE(review): no `move(along:)` is declared here; presumably the toolchain
// supplies it for types whose `TangentVector` is `Self` — confirm.
extension Int: Differentiable {
public typealias TangentVector = Int
}
extension Int {
  /// Forward-mode derivative (JVP) registered for integer `+`.
  ///
  /// - Parameters:
  ///   - x: The left operand.
  ///   - y: The right operand.
  /// - Returns: The sum `x + y`, paired with a differential that maps the
  ///   operand changes `(dx, dy)` to the resulting change `dx + dy`.
  @derivative(of: +)
  @usableFromInline
  static func derivativeOfAdd(_ x: Int, _ y: Int) -> (value: Int, differential: (Int, Int) -> Int) {
    let sum = x + y
    return (value: sum, differential: { $0 + $1 })
  }

  /// Reverse-mode derivative (VJP) registered for integer `+`.
  ///
  /// - Parameters:
  ///   - x: The left operand.
  ///   - y: The right operand.
  /// - Returns: The sum `x + y`, paired with a pullback that hands the
  ///   incoming value unchanged to both operands.
  @derivative(of: +)
  @usableFromInline
  static func derivativeOfAdd(_ x: Int, _ y: Int) -> (value: Int, pullback: (Int) -> (Int, Int)) {
    let sum = x + y
    return (value: sum, pullback: { upstream in (upstream, upstream) })
  }
}
// MARK: - Add missing derivatives
extension Array where Element: Differentiable {
  /// Forward-mode derivative (JVP) of `differentiableReduce(_:_:)` with
  /// respect to the array and the initial result.
  ///
  /// The reduction is replayed step by step, recording the differential of
  /// `nextPartialResult` at every step. The returned differential then pushes
  /// an input change through those recorded per-step differentials in order.
  ///
  /// - Parameters:
  ///   - initialResult: The value the reduction starts from.
  ///   - nextPartialResult: The differentiable combining function.
  /// - Returns: The reduced value and a differential taking
  ///   `(dSelf, dInitial)` to the change in the reduced value.
  @derivative(of: Array.differentiableReduce, wrt: (self, initialResult))
  @usableFromInline
  func derivativeOfDifferentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> (
    value: Result,
    differential: (Array.TangentVector, Result.TangentVector) -> Result.TangentVector
  ) {
    var differentials: [(Result.TangentVector, Element.TangentVector) -> Result.TangentVector] = []
    differentials.reserveCapacity(count)
    var result = initialResult
    for element in self {
      let (y, df) = Swift.valueWithDifferential(at: result, element, in: nextPartialResult)
      result = y
      differentials.append(df)
    }
    return (value: result, differential: { dSelf, dInitial in
      // A zero array tangent is represented by an empty base. Zipping an
      // empty base would skip every step and return `dInitial` untouched,
      // which is wrong whenever a step's differential is not the identity in
      // its first argument. Expand the empty tangent to explicit zeros so
      // `dInitial` is still propagated through every step.
      let dElements: [Element.TangentVector] = dSelf.base.isEmpty
        ? [Element.TangentVector](repeating: .zero, count: differentials.count)
        : dSelf.base
      var dResult = dInitial
      for (dElement, df) in zip(dElements, differentials) {
        dResult = df(dResult, dElement)
      }
      return dResult
    })
  }

  /// Reverse-mode derivative (VJP) of `differentiableReduce(_:_:)`.
  ///
  /// The reduction is replayed step by step, recording the pullback of
  /// `nextPartialResult` at every step. The returned pullback walks those
  /// recorded pullbacks from last to first, collecting one element tangent
  /// per step and threading the result tangent back to the initial result.
  ///
  /// - Parameters:
  ///   - initialResult: The value the reduction starts from.
  ///   - nextPartialResult: The differentiable combining function.
  /// - Returns: The reduced value and a pullback producing the tangents for
  ///   the array (in element order) and for `initialResult`.
  @usableFromInline
  @derivative(of: differentiableReduce)
  internal func _vjpDifferentiableReduce<Result: Differentiable>(
    _ initialResult: Result,
    _ nextPartialResult: @differentiable (Result, Element) -> Result
  ) -> (
    value: Result,
    pullback: (Result.TangentVector) -> (Array.TangentVector, Result.TangentVector)
  ) {
    var pullbacks: [(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] = []
    let count = self.count
    pullbacks.reserveCapacity(count)
    var result = initialResult
    for element in self {
      let (y, pb) = Swift.valueWithPullback(at: result, element, in: nextPartialResult)
      result = y
      pullbacks.append(pb)
    }
    return (value: result, pullback: { tangent in
      var resultTangent = tangent
      var elementTangents: [Element.TangentVector] = []
      elementTangents.reserveCapacity(count)
      // Pullbacks run last-to-first, so tangents are collected in reverse
      // element order and flipped once at the end.
      for pullback in pullbacks.reversed() {
        let (newResultTangent, elementTangent) = pullback(resultTangent)
        resultTangent = newResultTangent
        elementTangents.append(elementTangent)
      }
      elementTangents.reverse()
      return (TangentVector(elementTangents), resultTangent)
    })
  }
}
// MARK: - Formalize distances
/// A differentiable type whose values can report the change that leads from
/// one value to another, expressed as a `TangentVector`.
protocol Distanceable: Differentiable {
/// Returns the change from `self` to `other` as a tangent vector.
///
/// `incrementalized(at:in:)` feeds this result straight into a differential.
// NOTE(review): the `Array` conformance in this file computes `other - self`;
// presumably every conformance should follow that direction — confirm.
func distance(to other: Self) -> TangentVector
}
// `Int` declares no `distance(to:)` here, so the requirement must be met by
// an existing member; presumably the standard library's
// `Int.distance(to:) -> Int` (from `Strideable`), with `TangentVector == Int`.
extension Int: Distanceable {}
extension Array: Distanceable where Element: Differentiable, Element.TangentVector == Element {
  /// Returns the element-wise change from `self` to `other`, that is
  /// `other[i] - self[i]` for every index `i`.
  ///
  /// - Parameter other: The destination array.
  /// - Returns: A tangent vector holding one difference per element.
  /// - Precondition: `other` has the same number of elements as `self`.
  func distance(to other: Array) -> TangentVector {
    // `zip` silently stops at the shorter sequence, which would yield a
    // truncated, meaningless distance. Trap on the mismatch instead.
    precondition(
      count == other.count,
      "Cannot compute the distance between arrays of different counts")
    return TangentVector(zip(other, self).map(-))
  }
}
// MARK: - Generic incrementalizer
/// Builds an incremental version of `body` anchored at `x`.
///
/// `body` is evaluated and differentiated once, at `x`. The returned function
/// never calls `body` again: for a new input it takes the distance from `x`
/// to that input, pushes it through the stored differential, and moves the
/// stored output along the resulting change.
///
/// - Parameters:
///   - x: The input at which `body` is evaluated and differentiated.
///   - body: The differentiable function to incrementalize.
/// - Returns: A function mapping a new input to the stored output moved along
///   the differential of the input change.
func incrementalized<T: Distanceable, U: Distanceable>(
  at x: T, in body: @differentiable (T) -> U
) -> (T) -> U {
  let (baseOutput, pushforward) = valueWithDifferential(at: x, in: body)
  return { newX in
    let inputChange = x.distance(to: newX)
    var updatedOutput = baseOutput
    updatedOutput.move(along: pushforward(inputChange))
    return updatedOutput
  }
}
// MARK: - Examples
/// Returns `x` increased by two.
///
/// - Parameter x: The integer to offset.
/// - Returns: `2 + x`.
func twoPlus(_ x: Int) -> Int {
  let offset = 2
  return offset + x
}
// An incrementalized `twoPlus(_:)` function, built from the value and
// differential of `twoPlus` at the input 3.
let incrementalizedTwoPlus = incrementalized(at: 3, in: twoPlus)
// Starts from twoPlus(3) = 5 and applies the input change 10 - 3 = 7.
print(incrementalizedTwoPlus(10)) // 12
/// Adds up all integers in `elements` using a differentiable reduction.
///
/// - Parameter elements: The integers to total.
/// - Returns: The total of `elements`, starting from zero.
func sum(of elements: [Int]) -> Int {
  let startingTotal: Int = .zero
  return elements.differentiableReduce(startingTotal, +)
}
// An incrementalized `sum(of:)` function, built from the value and
// differential of `sum` at the input [1, 2, 3, 4, 5].
let incrementalizedSum = incrementalized(at: [1, 2, 3, 4, 5], in: sum)
// Starts from the stored sum 15 and applies the element-wise input changes
// [9, 18, 27, 36, 45], which total 135.
print(incrementalizedSum([10, 20, 30, 40, 50])) // 150
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment