@jfsantos
Created February 18, 2017 15:18
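import torch


def train_fn(model, optimizer, criterion, batch):
    # x, y: padded, time-major tensors of shape (seq_len, batch, features);
    # lengths: true (unpadded) length of each sequence in the batch.
    x, y, lengths = batch
    # Since PyTorch 0.4, tensors track gradients directly; Variable is not needed.
    x = x.cuda()
    y = y.cuda()
    # Boolean mask, True at padded time steps (t >= length). Built from y's
    # shape so it matches y_hat even if input and output feature sizes differ.
    mask = torch.ones(y.size(), dtype=torch.bool, device=y.device)
    for k, l in enumerate(lengths):
        mask[:l, k, :] = False
    hidden = model.init_hidden(x.size(1))
    y_hat = model(x, hidden)  # call the module, not .forward(), so hooks run
    # Apply mask: zero padded positions in both prediction and target.
    # Out-of-place masked_fill avoids in-place autograd errors on y_hat.
    y_hat = y_hat.masked_fill(mask, 0.0)
    y = y.masked_fill(mask, 0.0)
    loss = criterion(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # loss is a 0-dim tensor; loss.data[0] fails in recent PyTorch versions.
    return loss.item()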
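Zeroing padded steps removes their error, but a criterion that averages over every element (such as `nn.MSELoss` with its default `reduction='mean'`) still counts those zeros in the denominator, so batches with more padding report smaller losses. Below is a minimal sketch of one alternative, assuming time-major `(seq_len, batch, features)` tensors and a squared-error objective. `length_mask` and `masked_mse` are hypothetical helpers, not part of the original gist; the mask is built by broadcasting instead of a Python loop.

```python
import torch


def length_mask(lengths, max_len):
    """Hypothetical helper: (max_len, batch) bool mask, True at padded steps.

    Assumes time-major layout, as in train_fn above.
    """
    lengths = torch.as_tensor(lengths)
    steps = torch.arange(max_len).unsqueeze(1)  # shape (max_len, 1)
    return steps >= lengths.unsqueeze(0)  # broadcasts to (max_len, batch)


def masked_mse(y_hat, y, lengths):
    """Hypothetical helper: squared error averaged over valid elements only."""
    pad = length_mask(lengths, y.size(0)).to(y.device)  # (T, B)
    pad = pad.unsqueeze(-1).expand_as(y)  # (T, B, F)
    # Padded positions get zero error and therefore zero gradient.
    sq_err = (y_hat - y).pow(2).masked_fill(pad, 0.0)
    # Divide by the number of unpadded elements (at least 1 to avoid 0/0).
    n_valid = (~pad).sum().clamp(min=1)
    return sq_err.sum() / n_valid
```

In `train_fn`, the two `masked_fill` calls and the `criterion` call could then be replaced by `loss = masked_mse(y_hat, y, lengths)`, under the assumption that the criterion was meant to be MSE.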