@SnowWalkerJ
Created December 13, 2017 09:56
Dilated LSTM module for PyTorch
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.autograd as autograd
class DilatedLSTM(nn.Module):
    """
    Dilated LSTM is inspired by dilated CNNs (see WaveNet). By updating each hidden state
    only every n steps (n > 1), it gains the ability to model longer-range temporal
    dependencies. When dilation=1 it behaves like an ordinary LSTMCell, except that the
    hidden states are managed internally rather than exposed.
    For dilation=n, n separate states are kept. At t=1, n+1, 2n+1, ... the first state is
    used; at t=2, n+2, 2n+2, ... the second state is used, and so on.
    """
    def __init__(self, input_size, hidden_size, batch_size, dilation=1, output="new", cuda=False):
        """
        Parameters
        ==========
        input_size: int
            The dimension of the input tensor
        hidden_size: int
            The dimension of the hidden state
        batch_size: int
            Since the hidden states are managed internally, batch_size is needed in order
            to initialize the states
        dilation: int
            The number of hidden states kept and rotated through; each state is updated
            once every `dilation` steps
        output: {'new', 'stack', 'concat'}
            If 'new', returns only the newest state; if 'stack', returns all the states
            stacked along a new last dimension; if 'concat', returns all the states
            concatenated along the last dimension
        cuda: bool
            If True, the managed states are allocated on the GPU
        """
        super(DilatedLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.initialize_weights()
        self.dilation = dilation
        self.use_cuda = cuda
        self.t = 0
        self.states = None
        self.output_format = output
        if output not in ('new', 'stack', 'concat'):
            raise ValueError("`output` must be one of {'new', 'stack', 'concat'}")
    def initialize_weights(self):
        cell = self.cell
        n = self.hidden_size
        # LSTMCell weights stack the four gates (input, forget, cell, output) along the
        # first dimension; initialize each n-row block orthogonally.
        for i in range(4):
            init.orthogonal(cell.weight_ih[i*n:(i+1)*n].data)
            init.orthogonal(cell.weight_hh[i*n:(i+1)*n].data)
        init.constant(cell.bias_ih.data, 0.0)
        init.constant(cell.bias_hh.data, 0.0)
        # Set the forget-gate bias (second block) to 1 so the cell remembers by default.
        init.constant(cell.bias_ih[n:2*n].data, 1.0)
        init.constant(cell.bias_hh[n:2*n].data, 1.0)
    def new_state(self):
        # A freshly zeroed (h, c) pair for one dilation slot
        h = torch.zeros(self.batch_size, self.hidden_size)
        if self.use_cuda:
            h = h.cuda()
        h = autograd.Variable(h)
        return h, h
    def reset_states(self):
        # Must be called before the first forward pass (and between sequences)
        self.states = [self.new_state() for _ in range(self.dilation)]
        self.t = 0
    def update_state(self, h1, c1):
        # Overwrite the current slot and advance to the next one cyclically
        self.states[self.t] = h1, c1
        self.t = (self.t + 1) % self.dilation
    def forward(self, x):
        """
        Parameters
        ==========
        x: Variable[float, batch x input_size]
            Input features for the current time step
        Returns
        =======
        if output == 'new':
            Variable[float, batch x hidden_size]
        elif output == 'concat':
            Variable[float, batch x hidden_size*dilation]
        elif output == 'stack':
            Variable[float, batch x hidden_size x dilation]
        """
        # Read the state slot assigned to the current step, step the LSTM cell,
        # then write the new state back and advance to the next slot.
        h, c = self.states[self.t]
        h1, c1 = self.cell(x, (h, c))
        self.update_state(h1, c1)
        if self.output_format == 'concat':
            # All dilation hidden states, concatenated along the feature dimension
            x = torch.cat([self.states[(self.t-i) % self.dilation][0] for i in range(self.dilation)], 1)
            return x
        elif self.output_format == 'stack':
            # All dilation hidden states, stacked along a new trailing dimension
            x = torch.stack([self.states[(self.t-i) % self.dilation][0] for i in range(self.dilation)], 2)
            return x
        else:
            return h1
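For reference, a minimal usage sketch (not part of the original gist): it unrolls the module over a toy sequence one step at a time. Because the states are managed internally, reset_states() must be called before the first forward pass. The sizes and the autograd.Variable wrapper are illustrative assumptions matching the PyTorch API the gist was written against.

# Minimal usage sketch (illustrative, not part of the original gist)
import torch
import torch.autograd as autograd

batch_size, input_size, hidden_size, dilation, seq_len = 4, 8, 16, 3, 10

rnn = DilatedLSTM(input_size, hidden_size, batch_size,
                  dilation=dilation, output="concat")
rnn.reset_states()  # states are managed internally, so initialize them first

inputs = autograd.Variable(torch.randn(seq_len, batch_size, input_size))
outputs = []
for t in range(seq_len):
    out = rnn(inputs[t])           # batch x (hidden_size * dilation) for output="concat"
    outputs.append(out)
outputs = torch.stack(outputs, 0)  # seq_len x batch x (hidden_size * dilation)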