Skip to content

Instantly share code, notes, and snippets.

@gautham20
Created June 6, 2020 22:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gautham20/15e12644b2779b6414bb72acf0a7bdf5 to your computer and use it in GitHub Desktop.
Save gautham20/15e12644b2779b6414bb72acf0a7bdf5 to your computer and use it in GitHub Desktop.
class EncoderDecoderWrapper(nn.Module):
    """Run an encoder once, then unroll a decoder cell for ``output_size`` steps.

    The encoder's hidden state seeds an autoregressive decoder. At each step,
    the decoder receives the previous step's value: the model's own prediction
    or, under teacher forcing, the ground truth. When ``decoder_input`` is
    True, the step's decoder covariates are concatenated to that value.

    Args:
        encoder: callable ``encoder(input_seq, *extra) -> (output, hidden)``;
            only ``hidden`` is used.
        decoder_cell: callable ``decoder_cell(hidden, step_input) -> (output, hidden)``.
        output_size: number of decoding steps, i.e. columns of the result.
        teacher_forcing: probability, checked independently at every step after
            the first, of replacing the fed-back prediction with the
            ground-truth previous target. Only applies when ``yb`` is given.
        sequence_len: stored as ``self.sequence_length``; not read by this class.
        decoder_input: if True, ``xb`` is a sequence whose last element holds
            per-step decoder covariates, indexed as ``xb[-1][:, i]``.
        device: device on which the output tensor is allocated.
    """

    def __init__(self, encoder, decoder_cell, output_size=3, teacher_forcing=0.3,
                 sequence_len=336, decoder_input=True, device='cpu'):
        super().__init__()
        self.encoder = encoder
        self.decoder_cell = decoder_cell
        self.output_size = output_size
        self.teacher_forcing = teacher_forcing
        self.sequence_length = sequence_len
        self.decoder_input = decoder_input
        self.device = device

    def _encode(self, xb):
        """Run the encoder on ``xb``.

        Returns:
            ``(input_seq, covariates, hidden)``. ``covariates`` is ``xb[-1]``
            when ``self.decoder_input`` is True, else None.
        """
        if self.decoder_input:
            input_seq, covariates = xb[0], xb[-1]
            # Anything between the main sequence and the covariates is passed
            # to the encoder as extra positional arguments.
            extra = xb[1:-1] if len(xb) > 2 else ()
            _, hidden = self.encoder(input_seq, *extra)
            return input_seq, covariates, hidden
        if isinstance(xb, (list, tuple)):
            # encoder(*xb) also covers a length-1 sequence (unwraps xb[0]).
            input_seq = xb[0]
            _, hidden = self.encoder(*xb)
        else:
            input_seq = xb
            _, hidden = self.encoder(input_seq)
        return input_seq, None, hidden

    def forward(self, xb, yb=None):
        """Predict ``output_size`` steps ahead.

        Args:
            xb: input tensor, or sequence of tensors (see class docstring).
            yb: optional ground-truth targets indexed as ``yb[:, i]``; enables
                teacher forcing.

        Returns:
            Tensor of shape ``(batch, output_size)`` on ``self.device``.
        """
        input_seq, covariates, hidden = self._encode(xb)
        outputs = torch.zeros(input_seq.size(0), self.output_size, device=self.device)
        # Seed with the last observed value of feature 0. Assumes input_seq is
        # (batch, seq, features) with the target in column 0 -- TODO confirm.
        y_prev = input_seq[:, -1, 0].unsqueeze(1)
        for i in range(self.output_size):
            if (yb is not None) and (i > 0) and (torch.rand(1).item() < self.teacher_forcing):
                # Teacher forcing replaces the fed-back *previous* prediction
                # with the true previous target (i - 1). Using yb[:, i] would
                # hand the decoder the very label it must predict at this step.
                y_prev = yb[:, i - 1].unsqueeze(1)
            if covariates is None:
                step_input = y_prev
            else:
                step_input = torch.cat((y_prev, covariates[:, i]), dim=1)
            rnn_output, hidden = self.decoder_cell(hidden, step_input)
            y_prev = rnn_output
            outputs[:, i] = rnn_output.squeeze(1)
        return outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment