import TensorFlow

// A wrapper around `[Element]`, used below as the tangent/cotangent space for `Array`.
public struct ProductSpaceVector<Element> {
  public var elements: [Element]
  public init(_ elements: [Element]) { self.elements = elements }
}

extension ProductSpaceVector : Equatable where Element : Equatable {
  public static func == (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> Bool {
    return lhs.elements == rhs.elements
  }
}

extension ProductSpaceVector : AdditiveArithmetic where Element : AdditiveArithmetic {
  // `zero` is the empty vector; `+` and `-` treat missing elements as `.zero`,
  // so vectors of different lengths can still be combined.
  public static var zero: ProductSpaceVector<Element> { return ProductSpaceVector([]) }
  public static func + (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> ProductSpaceVector {
    var result: [Element] = []
    for i in 0..<max(lhs.elements.count, rhs.elements.count) {
      let l = i < lhs.elements.count ? lhs.elements[i] : .zero
      let r = i < rhs.elements.count ? rhs.elements[i] : .zero
      result.append(l + r)
    }
    return ProductSpaceVector(result)
  }
  public static func - (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> ProductSpaceVector {
    var result: [Element] = []
    for i in 0..<max(lhs.elements.count, rhs.elements.count) {
      let l = i < lhs.elements.count ? lhs.elements[i] : .zero
      let r = i < rhs.elements.count ? rhs.elements[i] : .zero
      result.append(l - r)
    }
    return ProductSpaceVector(result)
  }
}
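
// Illustrative sanity check (hand-computed expectation, not verified output):
// adding vectors of different lengths should behave as if the shorter one were
// padded with zeros, so the result here should hold [11.0, 2.0, 3.0].
print(ProductSpaceVector<Float>([1, 2, 3]) + ProductSpaceVector<Float>([10]))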

extension ProductSpaceVector : Differentiable where Element : Differentiable {
  public typealias TangentVector = ProductSpaceVector<Element.TangentVector>
  public typealias CotangentVector = ProductSpaceVector<Element.CotangentVector>
  public typealias AllDifferentiableVariables = ProductSpaceVector<Element.AllDifferentiableVariables>
  public var allDifferentiableVariables: AllDifferentiableVariables {
    get {
      return elements.allDifferentiableVariables
    }
    set(v) {
      elements.allDifferentiableVariables = v
    }
  }
  // Unimplemented stubs: this experiment never exercises these requirements.
  public func moved(along direction: TangentVector) -> ProductSpaceVector {
    fatalError("not implemented")
  }
  public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
    fatalError("not implemented")
  }
}

extension Array : Differentiable where Element : Differentiable {
  public typealias TangentVector = ProductSpaceVector<Element.TangentVector>
  public typealias CotangentVector = ProductSpaceVector<Element.CotangentVector>
  public typealias AllDifferentiableVariables = ProductSpaceVector<Element.AllDifferentiableVariables>
  public var allDifferentiableVariables: AllDifferentiableVariables {
    get {
      return AllDifferentiableVariables(map { $0.allDifferentiableVariables })
    }
    set(v) {
      precondition(count == v.elements.count,
                   "new value must have the same number of elements as the array")
      for i in indices {
        self[i].allDifferentiableVariables = v.elements[i]
      }
    }
  }
  // Unimplemented stubs, as above.
  public func moved(along direction: TangentVector) -> Array {
    fatalError("not implemented")
  }
  public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
    fatalError("not implemented")
  }
}
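
// Illustrative round-trip check of the `Array` conformance (hand-computed
// expectations, assuming `Float.AllDifferentiableVariables == Float` as in the
// contemporary toolchain): the getter should wrap the elements in a
// ProductSpaceVector, and the setter should write them back.
var roundTrip: [Float] = [1, 2, 3]
print(roundTrip.allDifferentiableVariables)  // should wrap [1.0, 2.0, 3.0]
roundTrip.allDifferentiableVariables = ProductSpaceVector([4, 5, 6])
print(roundTrip)  // should print [4.0, 5.0, 6.0]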
// - MARK: "Sum first three" test. | |
extension Array where Element : Differentiable { | |
@differentiable(wrt: (self), vjp: vjpDifferentiableSubscript) | |
func differentiableSubscript(_ i: Int) -> Element { | |
return self[i] | |
} | |
func vjpDifferentiableSubscript(_ i: Int) -> (Element, (Element.CotangentVector) -> ProductSpaceVector<Element.CotangentVector>) { | |
let result = self[i] | |
func pullback(_ v: Element.CotangentVector) -> ProductSpaceVector<Element.CotangentVector> { | |
var r = Array<Element.CotangentVector>(repeating: .zero, count: i + 1) | |
r[i] = v | |
return ProductSpaceVector(r) | |
} | |
return (result, pullback) | |
} | |
} | |

func sumFirstThree(_ array: [Float]) -> Float {
  return array.differentiableSubscript(0) + array.differentiableSubscript(1)
    + array.differentiableSubscript(2)
}
print(gradient(at: [0, 0, 0, 0], in: sumFirstThree))
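// Hand-computed expectation: each subscript's pullback contributes a 1 at its
// index, so the printed gradient should hold elements [1.0, 1.0, 1.0] (the
// pullbacks never allocate an entry for index 3).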
// - MARK: The "Parameter" test. | |
struct Parameter : Equatable { | |
@differentiable(wrt: (self), vjp: vjpX) | |
let x: Float | |
func vjpX() -> (Float, (Float) -> Parameter) { | |
return (x, { dx in Parameter(x: dx) } ) | |
} | |
} | |
extension Parameter { | |
func squared() -> Float { | |
return x * x | |
} | |
static func * (_ a: Parameter, _ b: Parameter) -> Float { | |
return a.x * b.x | |
} | |
} | |

extension Parameter : Differentiable, VectorNumeric {
  typealias TangentVector = Parameter
  typealias CotangentVector = Parameter
  typealias Scalar = Float
  typealias Shape = ()
  init(repeating repeatedValue: Float, shape: ()) {
    self.init(x: repeatedValue)
  }
  static func + (lhs: Parameter, rhs: Parameter) -> Parameter {
    return Parameter(x: lhs.x + rhs.x)
  }
  static func - (lhs: Parameter, rhs: Parameter) -> Parameter {
    return Parameter(x: lhs.x - rhs.x)
  }
  static func * (lhs: Scalar, rhs: Parameter) -> Parameter {
    return Parameter(x: lhs * rhs.x)
  }
  static var zero: Parameter { return Parameter(x: 0) }
}
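
// Illustrative check of the conformance above (hand-computed expectation):
// 2 * (Parameter(x: 1) + Parameter(x: 3)) should be Parameter(x: 8.0).
print(2 * (Parameter(x: 1) + Parameter(x: 3)))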

func f(_ p: [Parameter]) -> Float {
  return 100 * p.differentiableSubscript(0).squared()
}
print(gradient(at: [Parameter(x: 2)], in: f).elements)
print(gradient(at: [Parameter(x: 20)], in: f).elements)
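// Hand-computed expectation: d/dx (100 * x^2) = 200 * x, so the two gradients
// should be single-element arrays containing Parameter(x: 400.0) and
// Parameter(x: 4000.0) respectively.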

// - MARK: Does it work with an optimizer?
struct MyModel : Layer {
  var p: [Float]
  init() {
    p = [1, 2, 3, 4]
  }
  // `Layer` requires an `applied(to:in:)` method; it is a dummy here because the
  // loss below is computed directly from the parameters instead.
  @differentiable
  func applied(to input: Float, in context: Context) -> Float {
    return 0
  }
}
extension Float {
  func squared() -> Float {
    return self * self
  }
}
func loss(_ model: MyModel) -> Float {
  return (model.p.differentiableSubscript(0) - model.p.differentiableSubscript(1)).squared() +
    (model.p.differentiableSubscript(2) - model.p.differentiableSubscript(3)).squared()
}

var model = MyModel()
print(model)
print(loss(model))
let grad = gradient(at: model, in: loss)
print(grad)
let optimizer = SGD<MyModel, Float>(learningRate: 0.1)
optimizer.update(&model.allDifferentiableVariables, along: grad)
print(model)
print(loss(model))
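// Hand-computed expectation (assuming the pullbacks compose as intended): the
// gradient of the loss with respect to p at [1, 2, 3, 4] is [-2, 2, -2, 2], so a
// single SGD step with learning rate 0.1 should move p to [1.2, 1.8, 3.2, 3.8]
// and drop the loss from 2.0 to 0.72.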