Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created May 15, 2019 23:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcrasi/bf5e3d192d750d04ca353a3dcf8fa023 to your computer and use it in GitHub Desktop.
Save marcrasi/bf5e3d192d750d04ca353a3dcf8fa023 to your computer and use it in GitHub Desktop.
/// A minimal tensor-like container holding a single `value` of type `T`.
///
/// The `AdditiveArithmetic` requirements (`zero`, `+`, `-`) are written out
/// explicitly instead of being left to compiler synthesis. Derived
/// `AdditiveArithmetic` conformance is not available from the standard Swift
/// compiler unless differentiable programming is enabled. Spelling the members
/// out keeps this type valid on any Swift 5 toolchain. All arithmetic is applied
/// element-wise to `value`. `+=` and `-=` come from the protocol's default
/// implementations.
struct MyTensor<T: Equatable & AdditiveArithmetic>: Equatable & AdditiveArithmetic {
    /// The wrapped scalar.
    var value: T

    /// The additive identity: a tensor wrapping `T.zero`.
    static var zero: MyTensor {
        return MyTensor(value: T.zero)
    }

    /// Returns a tensor whose value is `lhs.value + rhs.value`.
    static func + (lhs: MyTensor, rhs: MyTensor) -> MyTensor {
        return MyTensor(value: lhs.value + rhs.value)
    }

    /// Returns a tensor whose value is `lhs.value - rhs.value`.
    static func - (lhs: MyTensor, rhs: MyTensor) -> MyTensor {
        return MyTensor(value: lhs.value - rhs.value)
    }
}
/// Conditionally makes `MyTensor` differentiable when its element type is.
///
/// `MyTensor` serves as its own differentiable-variables, tangent and
/// cotangent type. Mapping a cotangent to a tangent is therefore the identity.
extension MyTensor: Differentiable where T: Differentiable, T: AdditiveArithmetic {
    typealias AllDifferentiableVariables = MyTensor
    typealias TangentVector = MyTensor
    typealias CotangentVector = MyTensor

    /// Converts a cotangent to a tangent. Both are `MyTensor`, so the argument
    /// is returned unchanged.
    func tangentVector(from cotangent: CotangentVector) -> TangentVector {
        let tangent: TangentVector = cotangent
        return tangent
    }
}
/// A single-field box around a differentiable `value`.
///
/// It declares nothing beyond the stored property, so its `Differentiable`
/// conformance is left entirely to compiler synthesis.
struct Wrapper<T>: Differentiable where T: Differentiable {
    /// The boxed value.
    var value: T
}
/// A toy, parameter-free "RNN cell" over wrapped `MyTensor` values.
///
/// It has no stored properties. Its `Differentiable` conformance is left to
/// the compiler. `call(_:)` is marked `@differentiable`, so the compiler must
/// produce a derivative for it.
struct MyRNNCell<Scalar>: Differentiable where Scalar: Differentiable & AdditiveArithmetic {
    /// Returns a freshly constructed `Wrapper` holding the same `MyTensor`
    /// payload as `x`.
    @differentiable
    func call(_ x: Wrapper<MyTensor<Scalar>>) -> Wrapper<MyTensor<Scalar>> {
        let payload = x.value
        return Wrapper(value: payload)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.