Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created May 16, 2019 00:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcrasi/5495924a7d07bb755ac2dc50d7abbe9a to your computer and use it in GitHub Desktop.
Save marcrasi/5495924a7d07bb755ac2dc50d7abbe9a to your computer and use it in GitHub Desktop.
import TensorFlow
/// A simple RNN cell.
///
/// At each time step the cell concatenates the time-step input with the previous
/// hidden state along axis 1, and computes the new hidden state as
/// `tanh(matmul(concatenated, weight) + bias)`. The new hidden state is also the
/// cell's output for that time step.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    /// Combined input-and-recurrent weight matrix, created with shape
    /// `[inputSize + hiddenSize, hiddenSize]`.
    public var weight: Tensor<Scalar>
    /// Bias added to the pre-activation, created with shape `[hiddenSize]`.
    public var bias: Tensor<Scalar>

    /// The shape of `zeroState`: `[1, hiddenSize]`.
    ///
    /// - Note: The leading dimension is fixed at 1. Because `call` concatenates the
    ///   input and state along axis 1, `zeroState` is only compatible with inputs
    ///   whose leading dimension is also 1.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// An all-zeros hidden state with shape `stateShape`.
    public var zeroState: State {
        return State(state: Tensor(zeros: stateShape))
    }

    /// The hidden state carried from one time step to the next.
    public struct State: Differentiable {
        /// The hidden-state tensor.
        ///
        /// Declared `var` rather than `let` so that the synthesized `Differentiable`
        /// conformance treats it as a differentiable stored property. Declared
        /// `public` so clients can read the state produced by the cell.
        public var state: Tensor<Scalar>

        /// Creates a state that wraps the given hidden-state tensor.
        ///
        /// The synthesized memberwise initializer would only be `internal`. This
        /// explicit initializer keeps the same `state:` label, so existing call
        /// sites are unchanged, and it lets code outside the module construct an
        /// initial state. It is marked `@differentiable` because `call` (itself
        /// `@differentiable`) constructs a `State` from a computed tensor.
        @differentiable
        public init(state: Tensor<Scalar>) {
            self.state = state
        }
    }

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// The weight is initialized with `Tensor(glorotUniform:)` and the bias with zeros.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // One weight matrix covers both the input and the recurrent connection,
        // because `call` multiplies the concatenation of input and state.
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the cell to the given input.
    ///
    /// - Parameter input: The current time-step input together with the previous
    ///   state. They are concatenated along axis 1, so their leading dimensions
    ///   must match.
    /// - Returns: An `Output` whose `output` and `state` are both the new hidden state.
    @differentiable
    public func call(_ input: Input) -> Output {
        let concatenatedInput = input.input.concatenated(with: input.state.state, alongAxis: 1)
        let newState = State(state: tanh(matmul(concatenatedInput, weight) + bias))
        return Output(output: newState, state: newState)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment