-
-
Save rajy4683/c23474531a80c0fff223f105246abc76 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Training loop
"""
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean per-batch loss.

    Each batch must expose ``src`` and ``trg`` attributes. The decoder is
    teacher-forced: it is fed ``trg[:, :-1]`` (target without its last
    token) and scored against ``trg[:, 1:]`` (target without its first
    token).

    Args:
        model: module called as ``model(src, trg_input)``; the first element
            of its returned pair is used as the logits.
        iterator: iterable of batches. ``__len__`` is not required, because
            batches are counted while iterating.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits_2d, targets_1d)``.
        clip: max gradient norm passed to
            ``torch.nn.utils.clip_grad_norm_``.

    Returns:
        Mean of ``loss.item()`` over the batches seen, or ``float("nan")``
        if ``iterator`` yielded no batches (instead of raising
        ``ZeroDivisionError`` or reporting a misleading 0.0).
    """
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    for batch in iterator:
        src = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        output, _ = model(src, trg[:, :-1])
        # output = [batch size, trg len - 1, output dim]
        # trg = [batch size, trg len]
        # Flatten to 2-D logits / 1-D targets; reshape copies if needed,
        # matching the former .contiguous().view(...).
        output = output.reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)
        # output = [batch size * (trg len - 1), output dim]
        # trg = [batch size * (trg len - 1)]
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        # Empty epoch: surface "no data" explicitly rather than dividing by zero.
        return float("nan")
    return epoch_loss / n_batches
"""
Evaluation loop
"""
def evaluate(model, iterator, criterion):
    """Compute the mean per-batch loss over ``iterator`` without training.

    Puts ``model`` in eval mode and runs under ``torch.no_grad()``. As in
    training, the decoder is fed ``trg[:, :-1]`` and scored against
    ``trg[:, 1:]``.

    Args:
        model: module called as ``model(src, trg_input)``; the first element
            of its returned pair is used as the logits.
        iterator: iterable of batches with ``src`` and ``trg`` attributes.
            ``__len__`` is not required, because batches are counted while
            iterating.
        criterion: loss called as ``criterion(logits_2d, targets_1d)``.

    Returns:
        Mean of ``loss.item()`` over the batches seen, or ``float("nan")``
        if ``iterator`` yielded no batches (instead of raising
        ``ZeroDivisionError`` or reporting a misleading 0.0).
    """
    model.eval()
    epoch_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in iterator:
            src = batch.src
            trg = batch.trg
            output, _ = model(src, trg[:, :-1])
            # output = [batch size, trg len - 1, output dim]
            # trg = [batch size, trg len]
            output = output.reshape(-1, output.shape[-1])
            trg = trg[:, 1:].reshape(-1)
            # output = [batch size * (trg len - 1), output dim]
            # trg = [batch size * (trg len - 1)]
            loss = criterion(output, trg)
            epoch_loss += loss.item()
            n_batches += 1
    if n_batches == 0:
        # Empty evaluation set: surface "no data" explicitly rather than dividing by zero.
        return float("nan")
    return epoch_loss / n_batches
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment