This file has been truncated, but you can view the full file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
, | |
. | |
the | |
and | |
to | |
of | |
a | |
in | |
" | |
: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
from typing import Any, Callable, Tuple | |
from jax import numpy as jnp | |
from flax.linen import recurrent | |
# Loose type aliases for annotations. Everything except Shape is `Any`, so
# these names document intent only and are not checked.
PRNGKey = Any
# A shape may have any number of dimensions. `Tuple[int]` would mean a tuple
# of exactly one int, so the variadic form is used instead.
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
class LSTM(nn.Module): | |
"""A simple unidirectional LSTM.""" | |
@functools.partial( | |
nn.transforms.scan, | |
variable_broadcast='params', | |
in_axes=1, out_axes=1, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
class BiLSTM(nn.Module): | |
"""A simple bi-directional LSTM.""" | |
hidden_size: int | |
@nn.compact | |
def __call__(self, inputs, lengths): | |
batch_size = inputs.shape[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
class BiLSTMClassifier(nn.Module): | |
hidden_size: int | |
embedding_size: int | |
vocab_size: int | |
output_size: int | |
@nn.compact |