Created April 16, 2019 00:59
Save marcrasi/654589791a8a6683f00f0d08a9f4d579 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A vector whose components are stored, in order, in `elements`.
public struct ProductSpaceVector<Element> {
    /// The components of the vector, in order.
    public var elements: [Element]

    /// Creates a vector whose components are exactly `elements`.
    public init(_ elements: [Element]) {
        self.elements = elements
    }
}
extension ProductSpaceVector : Equatable where Element : Equatable {
    /// Returns `true` when both vectors hold equal `elements` arrays, i.e. the
    /// same number of components and pairwise-equal components.
    ///
    /// NOTE(review): no zero-padding is applied here, so an empty vector (the
    /// `zero` of the `AdditiveArithmetic` conformance in this file) does not
    /// compare equal to a non-empty all-zero vector such as the result of
    /// `v - v`. Confirm whether callers rely on `== .zero` checks.
    public static func == (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> Bool {
        let left = lhs.elements
        let right = rhs.elements
        return left == right
    }
}
extension ProductSpaceVector : AdditiveArithmetic where Element : AdditiveArithmetic {
    /// The additive identity: a vector with no components.
    ///
    /// Because `+` and `-` treat missing trailing components as `Element.zero`,
    /// the empty vector acts as the zero vector of every length.
    public static var zero: ProductSpaceVector<Element> { return ProductSpaceVector([]) }

    /// Component-wise sum; the shorter operand is padded with `Element.zero`.
    public static func + (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> ProductSpaceVector {
        return combined(lhs, rhs) { $0 + $1 }
    }

    /// Component-wise difference; the shorter operand is padded with `Element.zero`.
    public static func - (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> ProductSpaceVector {
        return combined(lhs, rhs) { $0 - $1 }
    }

    /// Applies `operation` to each pair of corresponding components, treating a
    /// component missing from the shorter operand as `Element.zero`.
    ///
    /// - Returns: A vector whose length is the longer of the two operand lengths.
    private static func combined(
        _ lhs: ProductSpaceVector,
        _ rhs: ProductSpaceVector,
        by operation: (Element, Element) -> Element
    ) -> ProductSpaceVector {
        let length = max(lhs.elements.count, rhs.elements.count)
        var result: [Element] = []
        result.reserveCapacity(length)
        for i in 0..<length {
            let l = i < lhs.elements.count ? lhs.elements[i] : Element.zero
            let r = i < rhs.elements.count ? rhs.elements[i] : Element.zero
            result.append(operation(l, r))
        }
        return ProductSpaceVector(result)
    }
}
extension ProductSpaceVector : Differentiable where Element : Differentiable {
    public typealias TangentVector = ProductSpaceVector<Element.TangentVector>
    public typealias CotangentVector = ProductSpaceVector<Element.CotangentVector>
    public typealias AllDifferentiableVariables = ProductSpaceVector<Element.AllDifferentiableVariables>

    /// The differentiable variables of every component, in order.
    ///
    /// Forwards to the `Differentiable` conformance of the `elements` array, so
    /// the setter has that conformance's element-count requirement.
    public var allDifferentiableVariables: AllDifferentiableVariables {
        get {
            return elements.allDifferentiableVariables
        }
        set(v) {
            elements.allDifferentiableVariables = v
        }
    }

    /// Moves each component along the corresponding component of `direction`.
    ///
    /// Consistent with the zero-padding used by `+`/`-`, components missing from
    /// the end of `direction` are taken to be `Element.TangentVector.zero`; in
    /// particular the empty `zero` vector is a valid direction.
    ///
    /// - Precondition: `direction` has no more components than `self`.
    public func moved(along direction: TangentVector) -> ProductSpaceVector {
        let steps = direction.elements
        precondition(steps.count <= elements.count,
                     "moved(along:): direction has more components than the vector")
        return ProductSpaceVector(elements.indices.map { i in
            elements[i].moved(along: i < steps.count ? steps[i] : Element.TangentVector.zero)
        })
    }

    /// Converts a cotangent to a tangent component by component.
    ///
    /// Components missing from the end of `cotangentVector` are taken to be
    /// `Element.CotangentVector.zero`.
    ///
    /// - Precondition: `cotangentVector` has no more components than `self`.
    public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
        let cotangents = cotangentVector.elements
        precondition(cotangents.count <= elements.count,
                     "tangentVector(from:): cotangent has more components than the vector")
        return TangentVector(elements.indices.map { i in
            elements[i].tangentVector(from: i < cotangents.count ? cotangents[i] : Element.CotangentVector.zero)
        })
    }
}
extension Array : Differentiable where Element : Differentiable {
    public typealias TangentVector = ProductSpaceVector<Element.TangentVector>
    public typealias CotangentVector = ProductSpaceVector<Element.CotangentVector>
    public typealias AllDifferentiableVariables = ProductSpaceVector<Element.AllDifferentiableVariables>

    /// The differentiable variables of every element, in order.
    ///
    /// - Precondition: when setting, the new value has exactly `count` components.
    public var allDifferentiableVariables: AllDifferentiableVariables {
        get {
            return AllDifferentiableVariables(map { $0.allDifferentiableVariables })
        }
        set(v) {
            precondition(count == v.elements.count,
                         "allDifferentiableVariables: new value must have the same number of components as the array")
            for i in indices {
                self[i].allDifferentiableVariables = v.elements[i]
            }
        }
    }

    /// Moves each element along the corresponding component of `direction`.
    ///
    /// Consistent with the zero-padding used by `ProductSpaceVector`'s `+`/`-`,
    /// components missing from the end of `direction` are taken to be
    /// `Element.TangentVector.zero`; in particular an empty direction is valid.
    ///
    /// - Precondition: `direction` has no more components than the array.
    public func moved(along direction: TangentVector) -> Array {
        let steps = direction.elements
        precondition(steps.count <= count,
                     "moved(along:): direction has more components than the array")
        return indices.map { i in
            self[i].moved(along: i < steps.count ? steps[i] : Element.TangentVector.zero)
        }
    }

    /// Converts a cotangent to a tangent element by element.
    ///
    /// Components missing from the end of `cotangentVector` are taken to be
    /// `Element.CotangentVector.zero`.
    ///
    /// - Precondition: `cotangentVector` has no more components than the array.
    public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
        let cotangents = cotangentVector.elements
        precondition(cotangents.count <= count,
                     "tangentVector(from:): cotangent has more components than the array")
        return TangentVector(indices.map { i in
            self[i].tangentVector(from: i < cotangents.count ? cotangents[i] : Element.CotangentVector.zero)
        })
    }
}
/// A small model with one scalar field and one array field, used below to
/// compare the key paths reachable from the model and from its
/// `allDifferentiableVariables`.
struct Foo : Differentiable & KeyPathIterable {
    /// A single scalar field.
    var scalar: Float = 42.0
    /// An array-valued field with four components.
    var arr: [Float] = [1.0, 2.0, 3.0, 4.0]
}
// Compare which Float key paths are reachable from `Foo` itself versus from
// its `allDifferentiableVariables`.
let foo = Foo()
print(foo)
// => Foo(scalar: 42.0, arr: [1.0, 2.0, 3.0, 4.0])

// Five key paths: `scalar` plus the four components of `arr`.
print(foo.recursivelyAllWritableKeyPaths(to: Float.self))
// => [Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>, Swift.WritableKeyPath<kpi.Foo, Swift.Float>]

let fooVariables = foo.allDifferentiableVariables
print(fooVariables)
// => AllDifferentiableVariables(scalar: 42.0, arr: kpi.ProductSpaceVector<Swift.Float>(elements: [1.0, 2.0, 3.0, 4.0]))

// Only one key path is found here. Presumably it is `scalar`: the `arr` field
// is now a `ProductSpaceVector`, which this file does not declare
// `KeyPathIterable`, so its components are not visited.
print(fooVariables.recursivelyAllWritableKeyPaths(to: Float.self))
// => [Swift.WritableKeyPath<kpi.Foo.AllDifferentiableVariables, Swift.Float>]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.