Created
May 16, 2019 00:55
-
-
Save marcrasi/5495924a7d07bb755ac2dc50d7abbe9a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import TensorFlow | |
/// A simple recurrent neural network cell.
///
/// At each time step the cell concatenates the time-step input with the previous hidden state
/// along axis 1, applies one affine transform, and squashes the result with `tanh`:
///
///     newState = tanh(matmul(concat([input, state], axis: 1), weight) + bias)
///
/// The new hidden state is returned both as the time-step output and as the state carried into
/// the next time step.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    /// Weight matrix applied to the concatenated `[input, state]` tensor. It is created by
    /// `init(inputSize:hiddenSize:)` with shape `[inputSize + hiddenSize, hiddenSize]`.
    public var weight: Tensor<Scalar>
    /// Bias added after the matrix multiply. It is created with shape `[hiddenSize]`.
    public var bias: Tensor<Scalar>

    /// The shape of the initial hidden state: `[1, hiddenSize]`.
    ///
    /// - Note: The leading dimension is fixed at 1. `call(_:)` concatenates along axis 1, which
    ///   requires axis 0 to match, so `zeroState` only fits inputs whose leading dimension is 1.
    ///   NOTE(review): confirm whether batched inputs need a differently-shaped initial state.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// An all-zeros hidden state of shape `stateShape`.
    public var zeroState: State {
        return State(state: Tensor(zeros: stateShape))
    }

    /// The hidden state carried from one time step to the next.
    public struct State: Differentiable {
        /// The hidden-state tensor.
        // Declared `public var` rather than `let`:
        //  - `var` lets the synthesized `Differentiable` conformance move this property along a
        //    tangent vector, instead of treating it as non-differentiable.
        //  - `public` lets clients outside the module read the cell's output.
        public var state: Tensor<Scalar>

        /// Creates a hidden state wrapping `state`.
        ///
        /// Declared explicitly because the implicit memberwise initializer of a public struct
        /// is only `internal`. Marked `@differentiable` so `call(_:)` can differentiate
        /// through it.
        @differentiable
        public init(state: Tensor<Scalar>) {
            self.state = state
        }
    }

    /// The per-time-step input tensor.
    public typealias TimeStepInput = Tensor<Scalar>
    /// The per-time-step output, which is the new hidden state itself.
    public typealias TimeStepOutput = State
    /// A time-step input paired with the previous hidden state.
    public typealias Input = RNNCellInput<TimeStepInput, State>
    /// A time-step output paired with the next hidden state.
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// The weight is Glorot-uniform initialized with shape `[inputSize + hiddenSize, hiddenSize]`,
    /// and the bias is zero-initialized with shape `[hiddenSize]`.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // A single matrix multiplies the input and the previous state together, so its row
        // count covers both feature widths.
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Applies one recurrent step to the given input.
    ///
    /// - Parameter input: The time-step input together with the previous hidden state. The
    ///   input is concatenated with `input.state.state` along axis 1, so the two must agree on
    ///   axis 0. Its axis-1 size must equal `inputSize` for the matrix multiply to line up.
    /// - Returns: An `Output` whose `output` and `state` are both the new hidden state.
    @differentiable
    public func call(_ input: Input) -> Output {
        // Stack input features and previous-state features side by side so one matmul
        // computes both contributions.
        let concatenatedInput = input.input.concatenated(with: input.state.state, alongAxis: 1)
        let newState = State(state: tanh(matmul(concatenatedInput, weight) + bias))
        return Output(output: newState, state: newState)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment