Riddhiman Dasgupta (dasguptar)

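```python
import torch
from torch.autograd import Variable  # deprecated since PyTorch 0.4; Variable(t) now returns a Tensor

def train_fn(model, optimizer, criterion, batch):
    # x is time-major (time, batch, features), judging by the mask indexing below.
    x, y, lengths = batch
    x = Variable(x.cuda())
    y = Variable(y.cuda(), requires_grad=False)
    # Start from all ones, then zero the first `l` time steps of each sequence k,
    # so 1 marks padded positions and 0 marks real data.
    mask = Variable(torch.ByteTensor(x.size()).fill_(1).cuda(),
                    requires_grad=False)
    for k, l in enumerate(lengths):
        mask[:l, k, :] = 0
    # ...
```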
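The loop above builds the padding mask one sequence at a time. A minimal vectorized sketch of the same idea follows, assuming the time-major `(seq_len, batch, ...)` layout that the `mask[:l, k, :]` indexing implies. `padding_mask` and `masked_mean` are hypothetical helpers, not part of the original gist, and because the gist is cut off before the mask is used, applying it to average a per-step loss over real positions only is an assumption. It uses a bool mask (True = padding) instead of a `ByteTensor`, since recent PyTorch versions deprecate uint8 masks in favor of bool for `masked_fill`.

```python
import torch


def padding_mask(lengths, max_len):
    """Hypothetical helper: (max_len, batch) bool mask, True at padded time steps.

    Assumes a time-major layout: dimension 0 is time, dimension 1 is the batch.
    """
    lengths = torch.as_tensor(lengths)
    steps = torch.arange(max_len).unsqueeze(1)  # shape (max_len, 1)
    return steps >= lengths.unsqueeze(0)        # broadcasts to (max_len, batch)


def masked_mean(values, mask):
    """Hypothetical helper: average `values` over non-padded positions only."""
    keep = ~mask
    return values.masked_fill(mask, 0.0).sum() / keep.sum()


# A batch of 3 sequences with lengths 4, 2 and 1, padded to length 4.
lengths = [4, 2, 1]
mask = padding_mask(lengths, max_len=4)
print(mask.tolist())
# [[False, False, False], [False, False, True], [False, True, True], [False, True, True]]

# Illustrative per-step losses of 1.0; the padded entries do not affect the mean.
losses = torch.ones(4, 3)
print(masked_mean(losses, mask).item())  # 1.0
```

To match the 3-D shape of the original mask, the 2-D result can be broadcast over the feature dimension with `mask.unsqueeze(2).expand_as(x)`.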