Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created April 17, 2019 04:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcrasi/b7618e3e7a8a8a920b73e662a40c4c7b to your computer and use it in GitHub Desktop.
Save marcrasi/b7618e3e7a8a8a920b73e662a40c4c7b to your computer and use it in GitHub Desktop.
extension Array where Element: Differentiable {
  /// Views the array as the differentiable product manifold of `Element` with itself `count` times.
  ///
  /// An empty view of tangent or cotangent vectors is accepted as the zero vector for any count.
  /// This matches the convention of this file's `AdditiveArithmetic` conformance, where
  /// `Array.DifferentiableView([])` is the zero in the product spaces of all counts.
  public struct DifferentiableView: Differentiable {
    /// The array that we are viewing.
    public var base: [Element]

    /// Construct a view of the given array.
    public init(_ base: [Element]) { self.base = base }

    // MARK: - Differentiable conformance.

    public typealias TangentVector = Array<Element.TangentVector>.DifferentiableView
    public typealias CotangentVector = Array<Element.CotangentVector>.DifferentiableView
    public typealias AllDifferentiableVariables = Array<Element.AllDifferentiableVariables>.DifferentiableView

    /// The differentiable variables of each element, collected into a view.
    ///
    /// Reading maps every element to its `allDifferentiableVariables`. Writing assigns the
    /// i-th component of the new value to the i-th element of `base`, in place.
    /// - Precondition: When setting, the new value has the same count as `base`.
    public var allDifferentiableVariables: AllDifferentiableVariables {
      get {
        return AllDifferentiableVariables(base.map { $0.allDifferentiableVariables })
      }
      set(v) {
        precondition(base.count == v.base.count, "count mismatch")
        for i in base.indices {
          base[i].allDifferentiableVariables = v.base[i]
        }
      }
    }

    /// Returns the view obtained by moving each element along the matching component of `direction`.
    ///
    /// An empty `direction` is the zero tangent vector, so `self` is returned unchanged.
    /// - Precondition: A non-empty `direction` has the same count as `base`.
    public func moved(along direction: TangentVector) -> DifferentiableView {
      // The zero tangent is an empty view for every count. Without this early exit,
      // moving a non-empty array along `.zero` would trip the count check below.
      if direction.base.isEmpty {
        return self
      }
      precondition(base.count == direction.base.count, "count mismatch")
      return DifferentiableView(zip(base, direction.base).map { $0.moved(along: $1) })
    }

    /// Converts `cotangentVector` to a tangent vector element-wise.
    ///
    /// Each component is converted with the corresponding element's own `tangentVector(from:)`.
    /// An empty `cotangentVector` is treated as the zero cotangent and maps to an empty
    /// (zero) tangent vector.
    /// - Precondition: A non-empty `cotangentVector` has the same count as `base`.
    public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
      // Same zero convention as `moved(along:)`: an empty cotangent is valid for any count.
      if cotangentVector.base.isEmpty {
        return TangentVector([])
      }
      precondition(base.count == cotangentVector.base.count, "count mismatch")
      return TangentVector(zip(base, cotangentVector.base).map {
        (selfElement, cotangentVectorElement) in
        selfElement.tangentVector(from: cotangentVectorElement)
      })
    }
  }
}
/// Declares `KeyPathIterable` conformance for `Array.DifferentiableView` with an empty body.
/// The requirements are presumably satisfied by compiler synthesis or default
/// implementations — TODO confirm.
extension Array.DifferentiableView: KeyPathIterable {
}
/// Two views compare equal exactly when the arrays they wrap compare equal.
extension Array.DifferentiableView: Equatable where Element: Equatable {
  public static func == (lhs: Array.DifferentiableView, rhs: Array.DifferentiableView) -> Bool {
    let (leftElements, rightElements) = (lhs.base, rhs.base)
    return leftElements == rightElements
  }
}
/// Makes `Array.DifferentiableView` additive as the product space.
///
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces of all counts.
/// Consequently, an empty operand is treated as the zero vector of whatever count the other
/// operand has, rather than as a count mismatch.
extension Array.DifferentiableView: AdditiveArithmetic where Element: AdditiveArithmetic {
  /// The zero vector: an empty view, which acts as zero for every count.
  public static var zero: Array.DifferentiableView { return Array.DifferentiableView([]) }

  /// Element-wise sum.
  ///
  /// If either operand is empty (zero), the other operand is returned.
  /// - Precondition: Non-empty operands have equal counts.
  public static func + (lhs: Array.DifferentiableView, rhs: Array.DifferentiableView) -> Array.DifferentiableView {
    precondition(lhs.base.isEmpty || rhs.base.isEmpty || lhs.base.count == rhs.base.count, "count mismatch")
    if lhs.base.isEmpty {
      return rhs
    }
    if rhs.base.isEmpty {
      return lhs
    }
    return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
  }

  /// Element-wise difference.
  ///
  /// `x - 0` is `x`, and `0 - x` is the element-wise negation of `x`.
  /// - Precondition: Non-empty operands have equal counts.
  public static func - (lhs: Array.DifferentiableView, rhs: Array.DifferentiableView) -> Array.DifferentiableView {
    precondition(lhs.base.isEmpty || rhs.base.isEmpty || lhs.base.count == rhs.base.count, "count mismatch")
    if rhs.base.isEmpty {
      // x - 0 == x (this also covers 0 - 0 == 0).
      return lhs
    }
    if lhs.base.isEmpty {
      // 0 - x == -x. Returning `rhs` unchanged here would flip the sign of the result.
      return Array.DifferentiableView(rhs.base.map { Element.zero - $0 })
    }
    return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
  }
}
/// Makes `Array` differentiable as the product manifold of `Element` with itself `count` times.
///
/// This is a convenience wrapper around `Array.DifferentiableView`. Each requirement wraps
/// `self` in a view and forwards to the view's implementation.
extension Array: Differentiable where Element: Differentiable {
  public typealias TangentVector = Array<Element.TangentVector>.DifferentiableView
  public typealias CotangentVector = Array<Element.CotangentVector>.DifferentiableView
  public typealias AllDifferentiableVariables = Array<Element.AllDifferentiableVariables>.DifferentiableView

  /// The differentiable variables of all elements, read and written through a view of `self`.
  public var allDifferentiableVariables: AllDifferentiableVariables {
    get {
      let view = DifferentiableView(self)
      return view.allDifferentiableVariables
    }
    set {
      // Mutate a temporary view, then copy its updated storage back into `self`.
      var scratch = DifferentiableView(self)
      scratch.allDifferentiableVariables = newValue
      self = scratch.base
    }
  }

  /// Returns the array obtained by moving `self`, viewed as a `DifferentiableView`, along `direction`.
  public func moved(along direction: TangentVector) -> Array {
    let movedView = DifferentiableView(self).moved(along: direction)
    return movedView.base
  }

  /// Converts `cotangentVector` to a tangent vector by forwarding to `DifferentiableView`.
  public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
    let view = DifferentiableView(self)
    return view.tangentVector(from: cotangentVector)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment