Skip to content

Instantly share code, notes, and snippets.

@bastings
Last active February 11, 2022 14:40
Show Gist options
  • Save bastings/f172fde8da08a9326966e25fe896b45f to your computer and use it in GitHub Desktop.
Save bastings/f172fde8da08a9326966e25fe896b45f to your computer and use it in GitHub Desktop.
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
class BiLSTMClassifier(nn.Module):
  """Token classifier: embedding -> bidirectional LSTM -> linear read-out.

  Attributes:
    hidden_size: Hidden size passed to the `BiLSTM` encoder.
    embedding_size: Width of each token embedding vector.
    vocab_size: Number of rows in the embedding table.
    output_size: Number of output logits produced per example.
  """
  hidden_size: int
  embedding_size: int
  vocab_size: int
  output_size: int

  @nn.compact
  def __call__(self, inputs, lengths):
    """Embeds the inputs, encodes them with a BiLSTM, and returns logits.

    Args:
      inputs: Integer token indices looked up in the embedding table.
      lengths: Forwarded unchanged to `BiLSTM`; presumably per-example
        sequence lengths (TODO confirm against the `BiLSTM` definition).

    Returns:
      Logits: the sum of two bias-free linear projections, one of the final
      forward state and one of the final backward state.
    """
    embedder = nn.Embed(
        self.vocab_size, features=self.embedding_size, name='embedder')
    token_vectors = embedder(inputs)

    # The per-step outputs are discarded; only the final states of each
    # direction feed the classifier head.
    encoder = BiLSTM(self.hidden_size, name='bilstm')
    _, final_states = encoder(token_vectors, lengths)
    fwd_state, bwd_state = final_states

    # Two bias-free projections summed together are equivalent to one
    # bias-free projection of the concatenated [fwd; bwd] state.
    fwd_head = nn.Dense(
        self.output_size, use_bias=False, name='output_layer_fwd')
    fwd_logits = fwd_head(fwd_state)
    bwd_head = nn.Dense(
        self.output_size, use_bias=False, name='output_layer_bwd')
    bwd_logits = bwd_head(bwd_state)

    logits = fwd_logits + bwd_logits
    return logits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment