Created May 15, 2019 23:54
Save marcrasi/bf5e3d192d750d04ca353a3dcf8fa023 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A minimal stand-in for a tensor type: a single stored `value` whose
/// element type `T` is equatable and supports addition/subtraction.
///
/// No `==`, `+`, `-`, or `zero` is written here, so the `Equatable` and
/// `AdditiveArithmetic` conformances rely on compiler synthesis over the
/// stored property.
/// NOTE(review): AdditiveArithmetic synthesis may require a toolchain with
/// differentiable programming enabled — confirm for your compiler.
struct MyTensor<T: Equatable & AdditiveArithmetic>: Equatable & AdditiveArithmetic {
  /// The wrapped element.
  var value: T
}
/// Makes `MyTensor` differentiable whenever its element type is.
///
/// The tangent, cotangent, and differentiable-variables types are all
/// `MyTensor` itself. Converting a cotangent into a tangent is therefore
/// the identity.
///
/// NOTE(review): `CotangentVector`, `AllDifferentiableVariables`, and
/// `tangentVector(from:)` look like requirements of the 2019-era Swift for
/// TensorFlow `Differentiable` protocol. They are not part of the current
/// `_Differentiation` module, so this presumably only builds on a toolchain
/// from that period — confirm before porting.
extension MyTensor : Differentiable where T : AdditiveArithmetic & Differentiable {
  typealias TangentVector = MyTensor
  typealias CotangentVector = MyTensor
  typealias AllDifferentiableVariables = MyTensor

  /// Converts a cotangent vector into a tangent vector.
  ///
  /// - Parameter cotangentVector: The cotangent to convert.
  /// - Returns: `cotangentVector` unchanged, because both vector spaces are
  ///   `MyTensor`.
  func tangentVector(from cotangentVector: CotangentVector) -> TangentVector {
    return cotangentVector
  }
}
/// A generic single-field container that is `Differentiable` whenever its
/// payload type `T` is.
///
/// No conformance members are written here. The `Differentiable`
/// requirements are presumably synthesized from the stored `value`
/// property.
struct Wrapper<T: Differentiable>: Differentiable {
  /// The wrapped payload.
  var value: T
}
/// A toy cell whose differentiable `call(_:)` passes a wrapped `MyTensor`
/// through unchanged.
///
/// NOTE(review): the cell has no stored parameters, and `call(_:)` only
/// re-wraps its input. This looks like a deliberately minimal reproducer
/// for differentiating through nested generic `Differentiable` types —
/// confirm before treating it as real model code.
struct MyRNNCell<Scalar: AdditiveArithmetic & Differentiable>: Differentiable {
  /// Returns a new `Wrapper` that holds the same `MyTensor` as `input`.
  ///
  /// - Parameter input: The wrapped tensor to pass through.
  /// - Returns: `Wrapper(value: input.value)`, i.e. a value equal to
  ///   `input`.
  @differentiable
  func call(_ input: Wrapper<MyTensor<Scalar>>) -> Wrapper<MyTensor<Scalar>> {
    return Wrapper(value: input.value)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.