Skip to content

Instantly share code, notes, and snippets.

@rxwei
Last active October 7, 2019 22:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rxwei/99214f5bf3df44496f0a016c3c0a25f7 to your computer and use it in GitHub Desktop.
Zero tangent vector initializer
import TensorFlow
protocol Differentiable {
...
  /// A closure that, when called, produces a zero tangent vector for this value.
  ///
  /// The closure's result type is `TangentVector`, not `Self`. The two differ
  /// in general: e.g. a struct with `@noDerivative` stored properties has a
  /// tangent type that omits those properties. For types whose tangent type
  /// is themselves (such as floating-point `Tensor`), the two coincide.
  var zeroTangentVectorInitializer: () -> TangentVector { get }
}
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Returns a closure that builds an all-zeros tensor with this tensor's shape.
  ///
  /// Only the shape value is captured by the returned closure — not `self` —
  /// so the closure does not hold on to this tensor.
  var zeroTangentVectorInitializer: () -> Self {
    let capturedShape = self.shape
    return { Tensor(zeros: capturedShape) }
  }
}
struct Foo : Differentiable {
  var x: Tensor<Float>
  var y: Tensor<Float>
  @noDerivative var flag: Bool

  /// Returns a closure that builds a zero `TangentVector` for this value.
  ///
  /// The per-member initializers for `x` and `y` are fetched eagerly (in that
  /// order) and captured directly, so the returned closure captures those
  /// initializers rather than `self`. `flag` is `@noDerivative` and has no
  /// tangent component.
  var zeroTangentVectorInitializer: () -> TangentVector {
    return { [makeZeroX = x.zeroTangentVectorInitializer,
              makeZeroY = y.zeroTangentVectorInitializer] in
      TangentVector(x: makeZeroX(), y: makeZeroY())
    }
  }
}
/// Returns the `y` member of `a`; `x` and `flag` are ignored.
func foo(_ a: Foo) -> Tensor<Float> {
  return a.y
}
@differentiating(foo)
func foo_vjp(_ a: Foo) -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Foo.TangentVector) {
  // Only the zero-initializer for `x` is captured by the pullback, not `a`
  // itself: `foo` ignores `x`, so its tangent component is always zero.
  let makeZeroX = a.x.zeroTangentVectorInitializer

  /// Maps an incoming cotangent `v` of `foo`'s result to a tangent of `Foo`:
  /// `v` flows straight to `y`, and `x` receives a zero tangent.
  func pullback(_ v: Tensor<Float>) -> Foo.TangentVector {
    return Foo.TangentVector(x: makeZeroX(), y: v)
  }

  return (foo(a), pullback)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment