Skip to content

Instantly share code, notes, and snippets.

import TensorFlow
/// A `Differentiable` type that maps an `Input` value to a `Tensor<Float>`
/// through `forward(_:)`.
protocol MyLayer: Differentiable {
/// The type accepted by `forward(_:)`; it is required to be `Differentiable`.
associatedtype Input: Differentiable
/// Applies the layer to `x` and returns a `Tensor<Float>`.
/// The requirement is marked `@differentiable`, so conforming
/// implementations must be differentiable as well.
@differentiable
func forward(_ x: Input) -> Tensor<Float>
}
extension MyLayer {
// MARK: - Some differentiable array manipulation functions used in the algorithms.
extension Array where Element: Differentiable {
/// Returns a new array equal to this one except that the elements at
/// positions `i` and `j` are exchanged. The receiver itself is not modified.
///
/// - Parameters:
///   - i: A valid index of the array (as required by `swapAt`).
///   - j: A valid index of the array (as required by `swapAt`).
/// - Returns: A copy of `self` with positions `i` and `j` swapped.
///
/// The derivative comes from the custom VJP `_vjpSwappedAt`, which is
/// declared outside this excerpt.
@differentiable(vjp: _vjpSwappedAt)
func swappedAt(_ i: Int, _ j: Int) -> Array {
  // Mutate a local copy so the caller's array is left untouched.
  var swapped = self
  swapped.swapAt(i, j)
  return swapped
}
@marcrasi
marcrasi / dcm.md
Created May 30, 2019 20:01
Differentiating class methods

Differentiating Class Methods

marcrasi@

Last updated: 5/21/19

Problem

When `c` has a class type with a method named `f`, Swift dispatches `c.f()` at runtime by looking up the concrete implementation of `f` in the "vtable" referenced by `c`.

import TensorFlow
/// A Simple RNN Cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
public var weight: Tensor<Scalar>
public var bias: Tensor<Scalar>
/// The shape of the cell state: a single row whose width equals the size of
/// `weight` along its second dimension (`weight.shape[1]`).
@noDerivative public var stateShape: TensorShape {
  let width = weight.shape[1]
  return TensorShape([1, width])
}
// NOTE(review): this looks like a minimal reduction (e.g. for a compiler
// issue) built around conditional conformance — confirm intent before
// reshaping any of these declarations.

/// Marker protocol with no requirements.
protocol A {}
/// Generic box holding a single value `t` of type `T`.
struct Wrapper<T> {
var t: T
}
/// Conditional conformance: `Wrapper<T>` conforms to `A` whenever `T` does.
extension Wrapper: A where T: A {}
/// Generic function constrained to `T: A`. Its body is empty, so calling it
/// only exercises the `A` constraint check at the call site.
func inner<T: A>(_ t: T) {}
/// A single-field wrapper around a value of type `T`.
///
/// It declares `Equatable & AdditiveArithmetic` without writing `==`, `+`,
/// `-`, or `zero` here. Presumably those members are compiler-synthesized
/// from `value` — NOTE(review): confirm the toolchain in use supports
/// `AdditiveArithmetic` synthesis.
struct MyTensor<T: Equatable & AdditiveArithmetic>: Equatable & AdditiveArithmetic {
/// The wrapped value.
var value: T
}
extension MyTensor : Differentiable where T : AdditiveArithmetic & Differentiable {
typealias TangentVector = MyTensor
typealias CotangentVector = MyTensor
typealias AllDifferentiableVariables = MyTensor
func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
return cotangentVector
Incorrect reconstructed type for $sxq_Iegnr_D
Original type:
(sil_function_type type=@differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1)
Reconstructed type:
(sil_function_type type=@callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1)
Stack dump:
0. Program arguments: /usr/local/google/home/marcrasi/swift-base-merge/build/Ninja-RelWithDebInfoAssert/swift-linux-x86_64/bin/swift -frontend -c -filelist /tmp/sources-716834 -supplementary-output-file-map /tmp/supplementaryOutputs-c055d6 -disable-objc-attr-requires-foundation-module -target x86_64-unknown-linux-gnu -disable-objc-interop -sdk / -I /usr/local/google/home/marcrasi/swift-base-merge/build/Ninja-RelWithDebInfoAssert/swift-linux-x86_64/./lib/swift/linux/x86_64 -warn-swift3-objc-inference-complete -warn-implicit-overrides -enable-library-evolution -g -module-cache-path /usr/local/google/home/marcrasi/swift-base-merge/build/Ninja-RelWithDebInfoAssert/swift-linux-x86_64/./module-cache -module-link-name swiftCore -nostdimport -parse-st
extension Array where Element: Differentiable {
/// Views the array as the differentiable product manifold of `Element` with itself `count` times.
public struct DifferentiableView: Differentiable {
/// The array that we are viewing.
public var base: [Element]
/// Construct a view of the given array.
public init(_ base: [Element]) { self.base = base }
// MARK: - Differentiable conformance.
/// A generic value type that stores its components as an array of `Element`.
public struct ProductSpaceVector<Element> {
  /// The stored components.
  public var elements: [Element]

  /// Creates a vector whose components are `elements`.
  public init(_ elements: [Element]) {
    self.elements = elements
  }
}

// MARK: - Equatable

extension ProductSpaceVector : Equatable where Element : Equatable {
  /// Two vectors are equal exactly when their `elements` arrays compare equal
  /// using `Array`'s `==`.
  public static func == (lhs: ProductSpaceVector, rhs: ProductSpaceVector) -> Bool {
    let (left, right) = (lhs.elements, rhs.elements)
    return left == right
  }
}
// NOTE(review): a byte-identical second copy of `ProductSpaceVector` and its
// conditional `Equatable` extension used to appear here. Declaring the same
// type (and the same `==`) twice in one module is an "invalid redeclaration"
// compile error, so only the earlier declaration in this file is kept.