Skip to content

Instantly share code, notes, and snippets.

@adamoudad
Created March 20, 2021 21:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save adamoudad/50c0089b9cebb735a666a348e6934b14 to your computer and use it in GitHub Desktop.
from tqdm import tqdm, trange

# Running per-batch metrics, used to show a running mean on the progress bar.
# NOTE(review): the original snippet read batch_history without initializing
# it anywhere visible; initialize it here so the loop is self-contained.
batch_history = {"loss": [], "accuracy": []}

# Standard supervised training loop.
# Assumes (defined elsewhere): model, train_loader, criterion, optimizer,
# device, epochs, torch. Indentation was reconstructed from the control-flow
# syntax (the paste had flattened it).
for epoch in trange(epochs, unit="epoch", desc="Train"):
    # Put the model in training mode once per epoch (enables dropout/BN
    # updates). The original also called model.train() per batch — redundant.
    model.train()
    with tqdm(train_loader, desc="Train") as tbatch:
        for samples, targets in tbatch:
            samples = samples.to(device).long()
            targets = targets.to(device)

            model.zero_grad()
            # Model expects (seq_len, batch)-ordered input, hence the
            # transpose — TODO confirm against the model definition.
            predictions, _ = model(samples.transpose(0, 1))
            loss = criterion(predictions.squeeze(), targets.float())

            # Score against the ACTUAL batch size so a smaller final batch
            # is not under-counted (the original divided by the configured
            # batch_size).
            correct = (predictions.round().squeeze() == targets).sum().item()
            acc = correct / targets.size(0)

            loss.backward()
            # Clip total gradient L2 norm at 5 to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            batch_history["loss"].append(loss.item())
            batch_history["accuracy"].append(acc)
            # Progress bar shows the running mean over all batches seen so far.
            tbatch.set_postfix(
                loss=sum(batch_history["loss"]) / len(batch_history["loss"]),
                acc=sum(batch_history["accuracy"]) / len(batch_history["accuracy"]),
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment