Skip to content

Instantly share code, notes, and snippets.

@michael-iuzzolino
Created April 28, 2019 03:13
Show Gist options
  • Save michael-iuzzolino/9b4560c90d32deb99d9e62aa9033aa59 to your computer and use it in GitHub Desktop.
Save michael-iuzzolino/9b4560c90d32deb99d9e62aa9033aa59 to your computer and use it in GitHub Desktop.
Hello RNN driver
class HelloRNN(nn.Module):
    """Sequence model that runs a single recurrent cell over an input sequence.

    The concrete cell class is looked up by name in ``HelloRNN.cells`` and
    constructed as ``cell_cls(num_chars, num_hidden)``. The cell must provide
    an ``init()`` method (called at the start of every ``forward``). It must
    also be callable on one element of the input sequence at a time.
    """

    # Registry mapping a cell name to its class (classes defined elsewhere).
    cells = {
        "LSTM": LSTMCell,
        "GRU": GRUCell,
        "vanilla": VanillaCell,
    }

    def __init__(self, num_chars, num_hidden=10, cell_type='LSTM'):
        """Build the model and initialise the cell's weights.

        Args:
            num_chars: First constructor argument passed to the cell
                (presumably the alphabet/vocabulary size -- confirm).
            num_hidden: Second constructor argument passed to the cell
                (presumably the hidden-state size -- confirm).
            cell_type: Name of the cell to use; must be a key of
                ``HelloRNN.cells``.

        Raises:
            KeyError: If ``cell_type`` is not a registered cell name.
        """
        super().__init__()
        # Validate before any side effects so a typo does not first print a
        # misleading "Creating RNN" message. KeyError is kept for callers
        # that already catch it.
        try:
            cell_cls = HelloRNN.cells[cell_type]
        except KeyError:
            raise KeyError(
                f"Unknown cell_type {cell_type!r}; "
                f"expected one of {sorted(HelloRNN.cells)}"
            ) from None
        self.cell_type = cell_type
        print(f"Creating RNN with cell: {cell_type}")
        self.cell = cell_cls(num_chars, num_hidden)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every parameter of the cell in place.

        Parameters with two or more dimensions receive Xavier-uniform
        initialisation; lower-rank (1-D / 0-D) parameters are zeroed.
        """
        for param in self.cell.parameters():
            # Forces every cell parameter to be trainable, even if the cell
            # constructor had frozen some of them.
            param.requires_grad_(True)
            # nn.init.* functions already run under torch.no_grad(), so they
            # can be applied to the Parameter directly rather than `.data`.
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

    def forward(self, X):
        """Run the cell over ``X`` one step at a time.

        The cell's state is reset with ``self.cell.init()``. Each element of
        ``X`` along its first dimension is then fed to the cell in order, and
        step ``i``'s output is written into slot ``i`` of the result.

        Args:
            X: Input sequence, iterated along its first dimension.

        Returns:
            A tensor created with ``torch.zeros_like(X)`` (same shape, dtype
            and device as ``X``) holding the per-step cell outputs.
        """
        # The container mirrors X, so each step's cell output must be
        # assignable into X[i]'s slot (presumably both are (..., num_chars)
        # -- TODO confirm against the cell implementations).
        outputs = torch.zeros_like(X)
        self.cell.init()
        for step, x in enumerate(X):
            outputs[step] = self.cell(x)
        return outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment