Dilated LSTM module for PyTorch

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.autograd as autograd


class DilatedLSTM(nn.Module):
    """
    Dilated LSTM, inspired by dilated CNNs (see WaveNet). By updating each hidden
    state only every n steps (n > 1), it gains the ability to model longer-range
    temporal dependencies.
    When dilation=1, it behaves like an ordinary LSTMCell, except that the hidden
    states are managed internally rather than exposed.
    For dilation=n, n separate states are kept. At t=1, n+1, 2n+1, ..., the first
    state is used; at t=2, n+2, 2n+2, ..., the second state is used, and so on.
    """
    def __init__(self, input_size, hidden_size, batch_size, dilation=1, output="new", cuda=False):
        """
        Parameters
        ==========
        input_size: int
            The dimension of the input tensor
        hidden_size: int
            The dimension of the hidden state
        batch_size: int
            Since the hidden states are managed internally, the batch size is
            needed in order to initialize them
        dilation: int
            The number of steps between two updates of any single hidden state
        output: {'new', 'stack', 'concat'}
            If 'new', returns only the newest state; if 'stack', returns all the
            states stacked along a new last dimension; if 'concat', returns all
            the states concatenated along the last dimension
        cuda: bool
            If True, the hidden states are allocated on the GPU
        """
        super(DilatedLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.initialize_weights()
        self.dilation = dilation
        self.use_cuda = cuda
        self.t = 0
        self.states = None
        self.output_format = output
        if output not in ('new', 'stack', 'concat'):
            raise ValueError("`output` must be one of {'new', 'stack', 'concat'}")

    def initialize_weights(self):
        cell = self.cell
        n = self.hidden_size
        # Orthogonal initialization for each of the four gate blocks
        # (input, forget, cell, output) of the LSTMCell weights.
        for i in range(4):
            init.orthogonal(cell.weight_ih[i * n:(i + 1) * n].data)
            init.orthogonal(cell.weight_hh[i * n:(i + 1) * n].data)
        init.constant(cell.bias_ih.data, 0.0)
        init.constant(cell.bias_hh.data, 0.0)
        # Set the forget-gate biases to 1 so the cell remembers by default.
        init.constant(cell.bias_ih[n:2 * n].data, 1.0)
        init.constant(cell.bias_hh[n:2 * n].data, 1.0)

    def new_state(self):
        # Allocate independent zero tensors for h and c instead of aliasing
        # a single tensor for both.
        h = torch.zeros(self.batch_size, self.hidden_size)
        c = torch.zeros(self.batch_size, self.hidden_size)
        if self.use_cuda:
            h, c = h.cuda(), c.cuda()
        return autograd.Variable(h), autograd.Variable(c)

    def reset_states(self):
        # Must be called before the first forward pass and at sequence boundaries.
        self.states = [self.new_state() for _ in range(self.dilation)]
        self.t = 0

    def update_state(self, h1, c1):
        self.states[self.t] = h1, c1
        self.t = (self.t + 1) % self.dilation

    def forward(self, x):
        """
        Parameters
        ==========
        x: Variable[float, batch x input_size]
            input features

        Returns
        =======
        if output == 'new':
            Variable[float, batch x hidden_size]
        elif output == 'concat':
            Variable[float, batch x hidden_size*dilation]
        elif output == 'stack':
            Variable[float, batch x hidden_size x dilation]
        """
        h, c = self.states[self.t]
        h1, c1 = self.cell(x, (h, c))
        self.update_state(h1, c1)
        # After update_state, self.t points at the oldest slot, so
        # (self.t - 1 - i) walks the states from newest to oldest.
        if self.output_format == 'concat':
            return torch.cat([self.states[(self.t - 1 - i) % self.dilation][0] for i in range(self.dilation)], 1)
        elif self.output_format == 'stack':
            return torch.stack([self.states[(self.t - 1 - i) % self.dilation][0] for i in range(self.dilation)], 2)
        else:
            return h1
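
A minimal usage sketch (not part of the original gist, and assuming the 0.3-era PyTorch API used above): call reset_states() once, then step the module through a sequence one timestep at a time.

import torch
import torch.autograd as autograd

batch_size, input_size, hidden_size = 4, 8, 16
rnn = DilatedLSTM(input_size, hidden_size, batch_size, dilation=3, output='concat')
rnn.reset_states()  # required before the first forward pass

# A random sequence of 10 timesteps, shaped time x batch x input_size.
seq = autograd.Variable(torch.randn(10, batch_size, input_size))
for t in range(seq.size(0)):
    out = rnn(seq[t])
print(out.size())  # torch.Size([4, 48]), i.e. batch x hidden_size*dilation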
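
On PyTorch 0.4 and later, the init functions used in initialize_weights were renamed with trailing underscores and Variable was merged into Tensor. A sketch of the equivalent initialization under that API, assuming the same (input, forget, cell, output) gate layout:

import torch
import torch.nn as nn
import torch.nn.init as init

cell = nn.LSTMCell(8, 16)
n = cell.hidden_size
with torch.no_grad():
    for i in range(4):
        init.orthogonal_(cell.weight_ih[i * n:(i + 1) * n])
        init.orthogonal_(cell.weight_hh[i * n:(i + 1) * n])
    init.constant_(cell.bias_ih, 0.0)
    init.constant_(cell.bias_hh, 0.0)
    cell.bias_ih[n:2 * n].fill_(1.0)  # forget-gate bias = 1
    cell.bias_hh[n:2 * n].fill_(1.0)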