@blythed
Created May 20, 2022 14:44
import padl
import torch

# Pipeline hyperparameters, exposed as named PADL parameters.
hidden_size = padl.param('hidden_size', 512)
input_size = padl.param('input_size', 64)
n_tokens = padl.param('n_tokens', 16)

# Wrap torch.nn so its layers can be used directly as PADL transforms.
nn = padl.transform(torch.nn)


@padl.transform
class HiddenState(torch.nn.Module):
    """Keep only the output sequence of an RNN layer, dropping the hidden state."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return self.layer(x)[0]


rnn_layer = nn.GRU(input_size, hidden_size, 1)

# Compose the full model: embedding -> GRU -> linear head -> sigmoid.
_pd_main = (
    nn.Embedding(n_tokens, input_size)
    >> HiddenState(rnn_layer)
    >> nn.Linear(hidden_size, 1)
    >> nn.Sigmoid()
)
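For reference, a minimal usage sketch (not part of the original gist). It assumes PADL's standard transform API, where a pipeline is run on a single sample with infer_apply and serialized with padl.save / padl.load; the dummy input and the 'model' file name below are illustrative.

# Hypothetical usage of the pipeline defined above; treat the exact
# infer_apply / padl.save / padl.load calls as assumptions about the
# installed PADL version.
sample = torch.randint(0, 16, (10,))        # 10 random token ids in [0, n_tokens)
prediction = _pd_main.infer_apply(sample)   # run the pipeline on one sample
padl.save(_pd_main, 'model')                # serialize pipeline and params
reloaded = padl.load('model.padl')          # restore the saved pipeline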