Skip to content

Instantly share code, notes, and snippets.

@bastings
Created February 4, 2022 16:50
Show Gist options
  • Save bastings/710911bbf264690a91bc79923ad668bf to your computer and use it in GitHub Desktop.
Save bastings/710911bbf264690a91bc79923ad668bf to your computer and use it in GitHub Desktop.
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
class LSTM(nn.Module):
  """A simple unidirectional LSTM.

  Runs a single `MyLSTMCell` (defined elsewhere) over a sequence by lifting
  `__call__` with `nn.transforms.scan`:

  * `in_axes=1` / `out_axes=1`: the scan iterates over axis 1 of `x` and
    stacks the per-step outputs along axis 1.
    NOTE(review): this suggests a batch-major layout such as
    (batch, time, features) — confirm with callers.
  * `variable_broadcast='params'` together with `split_rngs={'params': False}`:
    one set of cell parameters is created and reused at every step, rather
    than separate parameters per step.
  """
  @functools.partial(
      nn.transforms.scan,
      variable_broadcast='params',
      in_axes=1, out_axes=1,
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    """Applies the LSTM cell for one step; scan lifts it over the sequence.

    Args:
      carry: recurrent state handed to `MyLSTMCell`. Scan threads it from one
        step to the next. Presumably obtained from `initialize_carry` — confirm
        with callers.
      x: input sequence. After lifting, scan slices it along axis 1 and feeds
        one slice per step.

    Returns:
      Per step, whatever `MyLSTMCell` returns. Scan requires that to be a
      `(new_carry, output)` pair; this is assumed here, so confirm it against
      MyLSTMCell. Under the scan lifting, the call as a whole returns the final
      carry and the per-step outputs stacked along axis 1.
    """
    # The explicit name fixes where the cell's parameters live in the variable
    # tree ('cell'). Renaming it changes the parameter path.
    return MyLSTMCell(name='cell')(carry, x)
  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    """Returns an initial carry for `__call__`, built by `MyLSTMCell`.

    Args:
      batch_dims: batch dimensions, forwarded to `MyLSTMCell.initialize_carry`.
      hidden_size: hidden size, forwarded to `MyLSTMCell.initialize_carry`.
    """
    # A fixed key is always passed, so no caller-supplied randomness reaches
    # the carry initializer.
    # NOTE(review): presumably the cell's carry initializer ignores the key
    # (e.g. zeros-initialized state). Confirm this in MyLSTMCell before relying
    # on a random initial state here.
    return MyLSTMCell.initialize_carry(
        jax.random.PRNGKey(0), batch_dims, hidden_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment