Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created April 16, 2019 00:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcrasi/654589791a8a6683f00f0d08a9f4d579 to your computer and use it in GitHub Desktop.
Save marcrasi/654589791a8a6683f00f0d08a9f4d579 to your computer and use it in GitHub Desktop.
/// A vector in a product space, stored as an ordered list of components.
///
/// Each entry of `elements` is one component of the vector. The type imposes no
/// length invariant; conformances declared in extensions decide how vectors of
/// different lengths interact.
public struct ProductSpaceVector<Element> {
    /// The components of the vector, in order.
    public var elements: [Element]

    /// Creates a vector whose components are exactly `elements`.
    public init(_ elements: [Element]) {
        self.elements = elements
    }
}
// MARK: - Equatable

/// Equality is delegated to the underlying `elements` arrays.
///
/// Vectors of different lengths therefore never compare equal, even when the extra
/// components of the longer one would otherwise be considered "zero".
extension ProductSpaceVector: Equatable where Element: Equatable {
    public static func == (left: ProductSpaceVector, right: ProductSpaceVector) -> Bool {
        let (leftComponents, rightComponents) = (left.elements, right.elements)
        return leftComponents == rightComponents
    }
}
// MARK: - AdditiveArithmetic

/// Componentwise addition and subtraction with implicit zero-padding.
///
/// `zero` is represented by a vector with no components. When two operands have
/// different lengths, the shorter one is padded with `Element.zero`, so `.zero`
/// acts as the identity for vectors of any length.
///
/// NOTE(review): `==` compares the raw `elements` arrays. As a result, a vector of
/// explicit zeros such as `[0, 0]` does not compare equal to `.zero` (`[]`).
extension ProductSpaceVector : AdditiveArithmetic where Element : AdditiveArithmetic {
    /// The additive identity: a vector with no stored components.
    public static var zero: ProductSpaceVector<Element> { return ProductSpaceVector([]) }

    /// Componentwise sum; missing components of the shorter operand count as zero.
    public static func + (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> ProductSpaceVector {
        return combinePadded(lhs, rhs) { $0 + $1 }
    }

    /// Componentwise difference; missing components of the shorter operand count as zero.
    public static func - (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> ProductSpaceVector {
        return combinePadded(lhs, rhs) { $0 - $1 }
    }

    /// Applies `combine` to each pair of components, padding the shorter operand with `.zero`.
    ///
    /// The result has as many components as the longer operand.
    private static func combinePadded(
        _ lhs: ProductSpaceVector,
        _ rhs: ProductSpaceVector,
        _ combine: (Element, Element) -> Element
    ) -> ProductSpaceVector {
        let count = max(lhs.elements.count, rhs.elements.count)
        var result: [Element] = []
        // The final size is known up front, so allocate once.
        result.reserveCapacity(count)
        for i in 0..<count {
            let l = i < lhs.elements.count ? lhs.elements[i] : .zero
            let r = i < rhs.elements.count ? rhs.elements[i] : .zero
            result.append(combine(l, r))
        }
        return ProductSpaceVector(result)
    }
}
// MARK: - Differentiable

/// Makes a vector of differentiable components differentiable, componentwise.
///
/// Tangent, cotangent and variable vectors are `ProductSpaceVector`s of the
/// corresponding per-component types. Following the zero-padding convention of the
/// `AdditiveArithmetic` conformance (where `.zero` has no components), a tangent or
/// cotangent may have fewer components than `self`. Its missing trailing components
/// are treated as zero.
extension ProductSpaceVector : Differentiable where Element : Differentiable {
    public typealias TangentVector = ProductSpaceVector<Element.TangentVector>
    public typealias CotangentVector = ProductSpaceVector<Element.CotangentVector>
    public typealias AllDifferentiableVariables = ProductSpaceVector<Element.AllDifferentiableVariables>

    /// The per-component differentiable variables.
    ///
    /// Both getter and setter are delegated to `[Element]`'s `Differentiable` conformance.
    public var allDifferentiableVariables: AllDifferentiableVariables {
        get {
            return elements.allDifferentiableVariables
        }
        set(v) {
            elements.allDifferentiableVariables = v
        }
    }

    /// Returns `self` with each component moved along the matching component of `direction`.
    ///
    /// Components beyond `direction.elements.count` are returned unchanged, because
    /// `direction` is treated as zero there.
    /// - Precondition: `direction` has no more components than `self`.
    public func moved(along direction: TangentVector) -> ProductSpaceVector {
        precondition(direction.elements.count <= elements.count,
                     "ProductSpaceVector.moved(along:): direction has \(direction.elements.count) components, but the vector has only \(elements.count)")
        var result = elements
        for (i, component) in direction.elements.enumerated() {
            result[i] = result[i].moved(along: component)
        }
        return ProductSpaceVector(result)
    }

    /// Converts `cotangentVector` to a tangent vector, component by component.
    ///
    /// The result has one component per component of `cotangentVector`. Trailing
    /// components that the cotangent omits (implicit zeros) are likewise omitted.
    /// - Precondition: `cotangentVector` has no more components than `self`.
    public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
        precondition(cotangentVector.elements.count <= elements.count,
                     "ProductSpaceVector.tangentVector(from:): cotangent has \(cotangentVector.elements.count) components, but the vector has only \(elements.count)")
        return TangentVector(zip(elements, cotangentVector.elements).map { $0.tangentVector(from: $1) })
    }
}
// MARK: - Array: Differentiable

/// Makes an array of differentiable values differentiable, using `ProductSpaceVector`
/// for its tangent, cotangent and variable vectors.
///
/// A tangent or cotangent may have fewer components than the array. Its missing
/// trailing components are treated as zero, matching `ProductSpaceVector.zero`,
/// which has no components.
extension Array : Differentiable where Element : Differentiable {
    public typealias TangentVector = ProductSpaceVector<Element.TangentVector>
    public typealias CotangentVector = ProductSpaceVector<Element.CotangentVector>
    public typealias AllDifferentiableVariables = ProductSpaceVector<Element.AllDifferentiableVariables>

    /// The per-element differentiable variables.
    ///
    /// Setting requires exactly one component per array element.
    public var allDifferentiableVariables: AllDifferentiableVariables {
        get {
            return AllDifferentiableVariables(map { $0.allDifferentiableVariables })
        }
        set(v) {
            precondition(count == v.elements.count,
                         "Array.allDifferentiableVariables: expected \(count) components, got \(v.elements.count)")
            // `indices` is a value captured once, so mutating `self` inside the loop is safe.
            for (i, variables) in zip(indices, v.elements) {
                self[i].allDifferentiableVariables = variables
            }
        }
    }

    /// Returns a copy with each element moved along the matching component of `direction`.
    ///
    /// Elements beyond `direction.elements.count` are returned unchanged, because
    /// `direction` is treated as zero there.
    /// - Precondition: `direction` has no more components than the array has elements.
    public func moved(along direction: TangentVector) -> Array {
        precondition(direction.elements.count <= count,
                     "Array.moved(along:): direction has \(direction.elements.count) components, but the array has only \(count) elements")
        var result = self
        for (i, component) in zip(result.indices, direction.elements) {
            result[i] = result[i].moved(along: component)
        }
        return result
    }

    /// Converts `cotangentVector` to a tangent vector, element by element.
    ///
    /// The result has one component per component of `cotangentVector`. Trailing
    /// components that the cotangent omits (implicit zeros) are likewise omitted.
    /// - Precondition: `cotangentVector` has no more components than the array has elements.
    public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
        precondition(cotangentVector.elements.count <= count,
                     "Array.tangentVector(from:): cotangent has \(cotangentVector.elements.count) components, but the array has only \(count) elements")
        return TangentVector(zip(self, cotangentVector.elements).map { $0.tangentVector(from: $1) })
    }
}
/// Small demo model with one scalar parameter and one array of parameters.
///
/// It is used by the statements below to compare the `Float` key paths reachable
/// from the value itself with those reachable from its `allDifferentiableVariables`.
struct Foo: Differentiable, KeyPathIterable {
    /// A single scalar parameter.
    var scalar: Float = 42.0
    /// An array of scalar parameters.
    var arr: [Float] = [1.0, 2.0, 3.0, 4.0]
}
// Demo: key-path iteration over a `Foo` value vs. over its differentiable variables.
let foo = Foo()
do {
    print(foo)
    // => Foo(scalar: 42.0, arr: [1.0, 2.0, 3.0, 4.0])

    // Recursion reaches `scalar` and each of the four `arr` elements.
    let floatPathsInValue = foo.recursivelyAllWritableKeyPaths(to: Float.self)
    print(floatPathsInValue)
    // => [Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>]

    let variables = foo.allDifferentiableVariables
    print(variables)
    // => AllDifferentiableVariables(scalar: 42.0, arr: kpi.ProductSpaceVector<Swift.Float>(elements: [1.0, 2.0, 3.0, 4.0]))

    // Only `scalar` is found here. This is presumably because `ProductSpaceVector`
    // does not conform to `KeyPathIterable` in this file, so recursion stops at
    // `arr` — confirm.
    print(variables.recursivelyAllWritableKeyPaths(to: Float.self))
    // => [Swift.WritableKeyPath<kpi.Foo.AllDifferentiableVariables, Swift.Float>]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment