Skip to content

Instantly share code, notes, and snippets.

View bastings's full-sized avatar

Jasmijn Bastings bastings

View GitHub Profile
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
class BiLSTMClassifier(nn.Module):
hidden_size: int
embedding_size: int
vocab_size: int
output_size: int
@nn.compact
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
class BiLSTM(nn.Module):
"""A simple bi-directional LSTM."""
hidden_size: int
@nn.compact
def __call__(self, inputs, lengths):
batch_size = inputs.shape[0]
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
class LSTM(nn.Module):
"""A simple unidirectional LSTM."""
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Tuple
from jax import numpy as jnp
from flax.linen import recurrent
# Type aliases used in annotations in this module. Most are left as `Any`
# (documentation-only aliases); only `Shape` carries structure.
PRNGKey = Any  # alias for random-number-generator keys; intentionally untyped
# An array shape is a tuple of ints of any length (rank). `Tuple[int]` would
# mean a tuple of exactly ONE int, so the variadic form is required here.
Shape = Tuple[int, ...]
Dtype = Any  # alias for array dtypes; intentionally untyped
Array = Any  # alias for array values; intentionally untyped
This file has been truncated, but you can view the full file.
,
.
the
and
to
of
a
in
"
: