Skip to content

Instantly share code, notes, and snippets.

@singhrahuldps
Created June 1, 2019 05:02
Show Gist options
  • Save singhrahuldps/b8f3904aff5c4e436c8fca13a99b59d6 to your computer and use it in GitHub Desktop.
Save singhrahuldps/b8f3904aff5c4e436c8fca13a99b59d6 to your computer and use it in GitHub Desktop.
def train(epochs = 10, bs = 64):
    """Train the module-level ``model`` and report per-epoch losses.

    Each epoch runs one optimisation pass over ``data[0]`` (training split)
    followed by one gradient-free pass over ``data[1]`` (validation split),
    then prints the mean per-batch loss of both passes.

    Relies on names defined elsewhere in the module: ``data``, ``get_batch``,
    ``user2idx``, ``item2idx``, ``model``, ``loss_function`` and
    ``optimizer``. Tensors are moved to the GPU with ``.cuda()``, so a CUDA
    device is required.

    Args:
        epochs: Number of full passes over the training split.
        bs: Batch size, i.e. the step between consecutive batch start
            offsets passed to ``get_batch``.
    """
    for epoch in range(epochs):
        # training the model
        # The validation pass below switches the model to eval mode, so
        # training mode has to be restored at the start of every epoch.
        model.train()
        total_loss = 0.0
        ct = 0
        for start in range(0, len(data[0]), bs):
            # presumably get_batch returns (users, items, targets) for the
            # half-open slice [start, start + bs) -- confirm against helper
            x1, x2, y = get_batch(data[0], start, start + bs)
            ct += 1
            user_ids = torch.LongTensor([user2idx[u] for u in x1]).cuda()
            item_ids = torch.LongTensor([item2idx[b] for b in x2]).cuda()
            y = torch.Tensor(y).cuda()
            # disregard/zero the gradients from previous computation
            model.zero_grad()
            preds = model(user_ids, item_ids)
            loss = loss_function(preds, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # An empty split yields no batches; report NaN instead of raising
        # ZeroDivisionError.
        total_loss = total_loss / ct if ct else float('nan')

        # getting the loss on validation set
        total_val_loss = 0.0
        cv = 0
        model.eval()  # setting the model to evaluation mode
        # No gradients are needed for validation; skipping autograd
        # bookkeeping saves memory and time.
        with torch.no_grad():
            for start in range(0, len(data[1]), bs):
                x11, x21, y1 = get_batch(data[1], start, start + bs)
                cv += 1
                user_ids = torch.LongTensor([user2idx[u] for u in x11]).cuda()
                item_ids = torch.LongTensor([item2idx[b] for b in x21]).cuda()
                y1 = torch.Tensor(y1).cuda()
                preds = model(user_ids, item_ids)
                loss = loss_function(preds, y1)
                total_val_loss += loss.item()
        total_val_loss = total_val_loss / cv if cv else float('nan')

        print('epoch', epoch+1, ' train loss', "%.3f" % total_loss,
              ' valid loss', "%.3f" % total_val_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment