Skip to content

Instantly share code, notes, and snippets.

@mhr
Last active August 14, 2018 19:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mhr/2b214985c1bafdbba971590878c83fe3 to your computer and use it in GitHub Desktop.
Save mhr/2b214985c1bafdbba971590878c83fe3 to your computer and use it in GitHub Desktop.
def find_indices(condition, list_):
    """Return the positions of the items in *list_* for which *condition* is truthy.

    Args:
        condition: one-argument predicate applied to every item.
        list_: any iterable; it is consumed once, in order.

    Returns:
        A list of 0-based indices in ascending order.
    """
    matches = []
    for position, item in enumerate(list_):
        if condition(item):
            matches.append(position)
    return matches
class Network(nn.Module):
    """Conditional single-layer LSTM decoder over a token vocabulary.

    At every time step the embedding of the current token is concatenated
    with a per-sample conditioning vector (``inputs``), fed through the LSTM,
    projected to ``vocab_dim`` scores and normalised with a softmax. Samples
    whose sequence has already ended (``t >= length``) are skipped, and their
    predictions are left at zero.

    Args:
        vocab_dim: number of token ids; size of the embedding table and of
            the output distribution.
        embed_dim: size of each token embedding.
        input_dim: size of the per-sample conditioning vector.
        hidden_dim: LSTM hidden/cell state size.
    """

    def __init__(self, vocab_dim, embed_dim, input_dim, hidden_dim=512):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise the first ``self.embedding = ...`` raises AttributeError.
        super().__init__()
        self.embedding = nn.Embedding(vocab_dim, embed_dim)
        self.rnn = nn.LSTM(embed_dim + input_dim,
                           hidden_dim,
                           batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_dim)
        self.act = nn.Softmax(dim=1)
        self.hidden_dim = hidden_dim
        self.vocab_dim = vocab_dim  # read by forward() to size the output

    def decode_step(self, embeddings_t, inputs_t, h_t, c_t):
        """Run one LSTM step for the currently active samples.

        Args:
            embeddings_t: token embeddings, ``(n_active, 1, embed_dim)``.
            inputs_t: conditioning vectors, ``(n_active, 1, input_dim)``;
                it must have the same rank as ``embeddings_t`` because the
                two are concatenated along the last axis.
            h_t, c_t: LSTM state, ``(1, n_active, hidden_dim)``.

        Returns:
            ``(probs, h_next, c_next)``. ``probs`` has shape
            ``(n_active, vocab_dim)`` and each row is a softmax distribution.
        """
        x_t = torch.cat((embeddings_t, inputs_t), dim=-1)
        lstm_outputs_t, (h_t, c_t) = self.rnn(x_t, (h_t, c_t))
        outputs_t = self.act(self.out(lstm_outputs_t.squeeze(1)))
        return outputs_t, h_t, c_t

    def forward(self, text, inputs, lengths):
        """Decode ``text`` step by step and return per-step token distributions.

        Args:
            text: token ids, ``(batch, seq_len)`` with ``seq_len >= max(lengths)``.
            inputs: per-sample conditioning vectors, either
                ``(batch, 1, input_dim)`` or ``(batch, input_dim)``.
            lengths: per-sample sequence lengths (list/sequence or 1-D tensor).

        Returns:
            A tensor of shape ``(batch, max(lengths), vocab_dim)``. Entry
            ``[b, t]`` holds the softmax output for sample ``b`` at step ``t``
            and is all zeros once ``t >= lengths[b]``.
        """
        batch_size = text.size(0)
        device = text.device
        lengths = torch.as_tensor(lengths, device=device)
        max_len = int(lengths.max())
        embeddings = self.embedding(text)
        if max_len <= 0:
            return embeddings.new_zeros((batch_size, 0, self.vocab_dim))
        if inputs.dim() == 2:
            # decode_step concatenates inputs with (n, 1, embed_dim)
            # embeddings along the last axis, so they need the same rank.
            inputs = inputs.unsqueeze(1)

        # Random initial LSTM state; created on the data's device rather
        # than the default (CPU) device to avoid a device mismatch.
        h = torch.randn(1, batch_size, self.hidden_dim, device=device)
        c = torch.randn(1, batch_size, self.hidden_dim, device=device)

        steps = []
        for t in range(max_len):
            # keep only samples whose <END> has not come yet
            idx = torch.nonzero(lengths > t, as_tuple=True)[0]
            probs_t, h_t, c_t = self.decode_step(embeddings[idx, t].unsqueeze(1),
                                                 inputs[idx],
                                                 h[:, idx],
                                                 c[:, idx])
            # Out-of-place scatter back into the full batch (index_copy, not
            # in-place index assignment), so autograd can backpropagate
            # through the whole recurrence without version-counter errors.
            h = h.index_copy(1, idx, h_t)
            c = c.index_copy(1, idx, c_t)
            full_t = probs_t.new_zeros((batch_size, self.vocab_dim))
            steps.append(full_t.index_copy(0, idx, probs_t))
        return torch.stack(steps, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment