Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created June 4, 2019 17:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcrasi/989b2014416bc2948ed1c03b1352e5c8 to your computer and use it in GitHub Desktop.
Save marcrasi/989b2014416bc2948ed1c03b1352e5c8 to your computer and use it in GitHub Desktop.
// MARK: - Some differentiable array manipulation functions used in the algorithms.
extension Array where Element: Differentiable {
  /// Returns a copy of the array with the elements at indices `i` and `j` exchanged.
  ///
  /// Traps (via `swapAt`) if either index is out of bounds.
  @differentiable(vjp: _vjpSwappedAt)
  func swappedAt(_ i: Int, _ j: Int) -> Array {
    var copy = self
    copy.swapAt(i, j)
    return copy
  }

  /// VJP of `swappedAt(_:_:)`.
  ///
  /// A swap is a permutation that is its own inverse, so the pullback applies the same
  /// swap to the incoming cotangent.
  func _vjpSwappedAt(_ i: Int, _ j: Int) -> (Array, (TangentVector) -> TangentVector) {
    return (swappedAt(i, j), { v in
      // An empty cotangent is how `Array.DifferentiableView` represents zero; swapping
      // inside it would trap, and the pullback of zero is zero.
      guard !v.base.isEmpty else { return v }
      return TangentVector(v.base.swappedAt(i, j))
    })
  }

  /// Returns the array without its first element; an empty array stays empty.
  @differentiable(vjp: _vjpDroppedFirst)
  func droppedFirst() -> Array {
    return Array(dropFirst())
  }

  /// VJP of `droppedFirst()`.
  ///
  /// The dropped element never reaches the output, so its cotangent is zero; every other
  /// component passes straight through, shifted one position to the right.
  func _vjpDroppedFirst() -> (Array, (TangentVector) -> TangentVector) {
    return (droppedFirst(), { v in
      // Keep an empty (zero) cotangent as zero rather than turning it into a
      // one-element array whose length would not match the input.
      guard !v.base.isEmpty else { return v }
      return TangentVector([Element.TangentVector.zero] + v.base)
    })
  }

  /// Returns a copy of the array with `element` appended.
  @differentiable(vjp: _vjpAppending)
  func appending(_ element: Element) -> Array {
    var copy = self
    copy.append(element)
    return copy
  }

  /// VJP of `appending(_:)`.
  ///
  /// The pullback splits the cotangent into the part for the original array (all but the
  /// last component) and the part for `element` (the last component).
  func _vjpAppending(_ element: Element) -> (Array, (TangentVector) -> (TangentVector, Element.TangentVector)) {
    func pullback(_ v: TangentVector) -> (TangentVector, Element.TangentVector) {
      // An empty (zero) cotangent has no last component to index; both results are zero.
      guard let last = v.base.last else { return (.zero, .zero) }
      return (TangentVector([Element.TangentVector](v.base.dropLast())), last)
    }
    return (appending(element), pullback)
  }

  /// Returns a one-element array containing `element`.
  @differentiable(vjp: _vjpMakeSingle)
  static func makeSingle(_ element: Element) -> Array {
    return [element]
  }

  /// VJP of `makeSingle(_:)`: the cotangent of `element` is the single component of the
  /// output cotangent.
  static func _vjpMakeSingle(_ element: Element) -> (Array, (TangentVector) -> Element.TangentVector) {
    return ([element], { v in
      // An empty cotangent represents zero, so `element` receives a zero cotangent.
      guard let only = v.base.first else { return .zero }
      precondition(v.base.count == 1, "makeSingle pullback expects a one-element cotangent")
      return only
    })
  }
}
// MARK: - Custom VJP for stdlib sort.
/// Returns `array` sorted in ascending order.
///
/// Differentiable through the custom VJP `_vjpSorted`, which routes each cotangent
/// component back to the input position its sorted value came from.
@differentiable(vjp: _vjpSorted)
func sorted(_ array: [Double]) -> [Double] {
  return array.sorted()
}

/// VJP of `sorted(_:)`.
///
/// Sorting only permutes its input, so the pullback scatters each component of the
/// output cotangent to the input index that produced that output position.
func _vjpSorted(_ array: [Double]) -> ([Double], (Array<Double>.DifferentiableView) -> Array<Double>.DifferentiableView) {
  // Sort (offset, value) pairs by value so each output remembers where it came from.
  let pairs = array.enumerated().sorted(by: { $0.element < $1.element })
  let values = pairs.map { $0.element }
  // permutation[i] is the index in `array` of `values[i]`.
  let permutation = pairs.map { $0.offset }
  return (values, { v in
    // An empty cotangent is how `Array.DifferentiableView` represents zero; indexing it
    // would trap, and the pullback of zero is zero.
    guard !v.base.isEmpty else { return v }
    precondition(v.base.count == permutation.count,
                 "sorted pullback: cotangent length must match the input length")
    // Size the gradient from the input (via the permutation), not from the cotangent.
    var gradient = [Double](repeating: 0, count: permutation.count)
    for (i, j) in permutation.enumerated() {
      gradient[j] = v.base[i]
    }
    return Array<Double>.DifferentiableView(gradient)
  })
}
let arrayToSort: [Double] = [7, 2, 4, 1, 8, 3, 0, 9]
// One one-hot cotangent per output position: entry `hot` is 1, every other entry is 0.
// Pulling back the i-th vector gives the gradient of output i with respect to the input.
var vectorsToPullBack: [[Double]] = arrayToSort.indices.map { hot in
  arrayToSort.indices.map { position in position == hot ? 1.0 : 0.0 }
}
// Reference result: the custom-VJP sort, whose value and pullback the later demos
// compare against.
let (value, pb) = valueWithPullback(at: arrayToSort, in: sorted)
print("USING CUSTOM DERIVATIVE FOR SORT")
print(value)
vectorsToPullBack
  .map { Array.DifferentiableView($0) }
  .forEach { print(pb($0)) }
print("")
// MARK: - Selection sort.
/// Returns the index of the first largest element of `array`.
///
/// Ties resolve to the earliest index because only a strictly greater value replaces the
/// current best. Traps if `array` is empty.
func argMax(_ array: [Double]) -> Int {
  var best = (index: 0, value: array[0])
  // Index 0 is already the starting best, so comparison starts with the second element.
  for candidate in array.enumerated().dropFirst() where candidate.element > best.value {
    best = (index: candidate.offset, value: candidate.element)
  }
  return best.index
}
/// Sorts `array` in ascending order with a recursive selection sort.
///
/// Each step swaps the largest element to the front, sorts the remaining elements
/// recursively, then appends the largest element at the end. `array` is only touched
/// through the `@differentiable` helpers `swappedAt`, `droppedFirst`, and `appending`
/// (plus a subscript). The index choice is made on a `withoutDerivative()` copy.
func selectionSort(_ array: [Double]) -> [Double] {
  guard array.count > 1 else {
    return array
  }
  let largestIndex = argMax(array.withoutDerivative())
  let largestFirst = array.swappedAt(0, largestIndex)
  let sortedRest = selectionSort(largestFirst.droppedFirst())
  return sortedRest.appending(largestFirst[0])
}
let (value2, pb2) = valueWithPullback(at: arrayToSort, in: selectionSort)
print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF SELECTION SORT")
print(value2)
// Check the value and each pullback against the custom-derivative reference.
if value2 != value {
  print(" oh no, that one is wrong")
}
for oneHot in vectorsToPullBack {
  let cotangent = Array.DifferentiableView(oneHot)
  let automatic = pb2(cotangent)
  print(automatic)
  if automatic != pb(cotangent) {
    print(" oh no, that one is wrong")
  }
}
print("")
// MARK: - Quicksort.
extension Array where Element : Differentiable {
  /// Returns the elements at indices `start...` that satisfy `predicate`, in their
  /// original order.
  ///
  /// Written recursively in terms of `makeSingle(_:)` and array `+` rather than the
  /// standard-library `filter`, presumably so automatic differentiation can see through
  /// it. A `start` at or past the end yields an empty array.
  func filter(_ predicate: (Element) -> Bool, _ start: Int) -> Array {
    if start >= count {
      return []
    }
    if predicate(self[start]) {
      // Prepend the kept element so the result preserves input order (appending it after
      // the recursively filtered tail would reverse the order).
      return Array.makeSingle(self[start]) + filter(predicate, start + 1)
    } else {
      return filter(predicate, start + 1)
    }
  }
}
/// Sorts `array` in ascending order with a recursive quicksort that pivots on the first
/// element.
///
/// The remaining elements are split into those strictly below the pivot and those at or
/// above it (so duplicates of the pivot land in the upper part). Both parts are sorted
/// recursively and concatenated around the pivot.
func qsort(_ array: [Double]) -> [Double] {
  guard array.count > 1 else {
    return array
  }
  let pivot = array[0]
  // The partition predicates only compare values, so they capture a
  // `withoutDerivative()` copy; `pivot` itself is what flows into the result.
  let pivotValue = pivot.withoutDerivative()
  let below = array.filter({ $0 < pivotValue }, 1)
  let atOrAbove = array.filter({ $0 >= pivotValue }, 1)
  return qsort(below) + Array.makeSingle(pivot) + qsort(atOrAbove)
}
let (value3, pb3) = valueWithPullback(at: arrayToSort, in: qsort)
print("USING AUTOMATICALLY COMPUTED DERIVATIVE OF QUICK SORT")
print(value3)
// Check the value and each pullback against the custom-derivative reference.
if value3 != value {
  print(" oh no, that one is wrong")
}
let quicksortCotangents = vectorsToPullBack.map { Array.DifferentiableView($0) }
for cotangent in quicksortCotangents {
  let gradient = pb3(cotangent)
  print(gradient)
  if gradient != pb(cotangent) {
    print(" oh no, that one is wrong")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment