Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/f496588355a689a781e955bf373bc31f to your computer and use it in GitHub Desktop.
Save SannaPersson/f496588355a689a781e955bf373bc31f to your computer and use it in GitHub Desktop.
def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """Run one pass over ``train_loader`` and update ``model`` on every batch.

    For each batch the input and the three targets are moved to
    ``config.DEVICE``, the forward pass and the loss are evaluated under
    ``torch.cuda.amp.autocast()``, and the summed loss is back-propagated
    through ``scaler`` before the optimizer step. The running mean of the
    batch losses is shown as the ``loss`` postfix of the tqdm progress bar.

    Args:
        train_loader: Iterable yielding ``(x, y)`` batches. ``x`` must
            support ``.to(device)``; ``y`` must be indexable at 0, 1 and 2
            with entries that support ``.to(device)``.
        model: Callable invoked as ``model(x)``; its result is indexed at
            0, 1 and 2.
        optimizer: Object providing ``zero_grad()``; it is stepped through
            ``scaler.step(optimizer)``.
        loss_fn: Callable invoked as ``loss_fn(prediction, target, anchors)``
            once per index; the three results are added together.
        scaler: Object providing ``scale(loss)``, ``step(optimizer)`` and
            ``update()`` (presumably a ``torch.cuda.amp.GradScaler`` --
            confirm against the caller).
        scaled_anchors: Indexable at 0, 1 and 2; entry ``i`` is passed to
            ``loss_fn`` together with prediction ``i`` and target ``i``.

    Returns:
        None. Progress is reported only through the progress bar.
    """
    loop = tqdm(train_loader, leave=True)

    # Keep a running total instead of a list of every batch loss: the mean
    # for the progress bar then costs O(1) per batch rather than re-summing
    # the whole history each time.
    running_loss = 0.0

    for num_batches, (x, y) in enumerate(loop, start=1):
        x = x.to(config.DEVICE)
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        # Only the forward pass and the loss run under autocast; the
        # backward pass and optimizer step stay outside the context.
        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        running_loss += loss.item()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update the progress bar with the mean loss over the batches so far.
        loop.set_postfix(loss=running_loss / num_batches)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment