Skip to content

Instantly share code, notes, and snippets.

@Tony363
Last active May 31, 2022 06:12
Show Gist options
  • Save Tony363/c7b421d3a302bc23b007b9778c6a24d6 to your computer and use it in GitHub Desktop.
Save Tony363/c7b421d3a302bc23b007b9778c6a24d6 to your computer and use it in GitHub Desktop.
def _mean_over_batches(total, n_batches):
    """Return ``total / n_batches``, or ``total`` unchanged when there were no batches."""
    return total / n_batches if n_batches else total


def torch_training(
    ds_train: Streamer,
    ds_val: Streamer,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epochs: int = 1000,
    criterion=nn.BCELoss(),
    f1_torch=f1_score,  # alternative: F1_Loss().cuda()
) -> nn.Module:
    """Train ``model`` on ``ds_train`` and evaluate it on ``ds_val`` every epoch.

    Each epoch runs one optimisation pass over ``ds_train`` and one no-grad
    pass over ``ds_val``, then logs the per-batch mean loss / F1 of both
    passes. Every 10th epoch (0, 10, 20, ...) it calls ``save_checkpoint``
    and ``save_metrics``. Finally it shuffles the module-level
    ``embedding_list_train`` in place and feeds it back into ``ds_train``.

    Relies on module-level names defined elsewhere: ``device``,
    ``save_checkpoint``, ``save_metrics`` and ``embedding_list_train``.

    Args:
        ds_train: Iterable of ``(x, y)`` training batches. It must also
            expose ``new_embeddings(list)``.
        ds_val: Iterable of ``(x, y)`` validation batches.
        model: Network to train. It is switched to eval mode for validation
            and back to train mode afterwards.
        optimizer: Optimizer stepping ``model``'s parameters.
        epochs: Number of passes over ``ds_train``.
        criterion: Loss, called as ``criterion(y_pred, y)``.
        f1_torch: Metric, called as ``f1_torch(pred_labels, true_labels)``
            with CPU int tensors. Predictions are thresholded at 0.5.

    Returns:
        The same ``model`` object, trained in place.
    """
    f1_callback, loss_callback, val_f1_callback, val_loss_callback = (), (), (), ()
    # The caller may hand over a model in eval mode. Make sure the very first
    # epoch trains in train mode; later epochs are restored at the end of the loop.
    model.train()
    for epoch in range(epochs):
        # ---- training pass -------------------------------------------------
        train_loss, train_f1, n_train = 0.0, 0.0, 0
        for x, y in ds_train:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            # Named batch_f1 (not f1_score) so the metric function is not shadowed.
            batch_f1 = f1_torch((y_pred.detach() > 0.5).int().cpu(), y.cpu().int())
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_f1 += batch_f1
            n_train += 1  # true batch count (enumerate's last index is count - 1)

        # NOTE(review): the epoch F1 is halved after averaging, exactly as in
        # the original code. The reason is not visible here; confirm it is intended.
        loss_callback += (_mean_over_batches(train_loss, n_train),)
        f1_callback += (_mean_over_batches(train_f1, n_train) / 2,)

        # ---- validation pass -----------------------------------------------
        model.eval()
        val_loss, val_f1, n_val = 0.0, 0.0, 0  # n_val bound even if ds_val is empty
        with torch.no_grad():
            for x, y in ds_val:
                x = x.to(device)
                y = y.to(device)
                y_pred = model(x)
                val_loss += criterion(y_pred, y).item()
                val_f1 += f1_torch((y_pred > 0.5).int().cpu(), y.cpu().int())
                n_val += 1
        val_loss_callback += (_mean_over_batches(val_loss, n_val),)
        # Uses the validation accumulator (the original fell back to the training one).
        val_f1_callback += (_mean_over_batches(val_f1, n_val) / 2,)

        logging.info(
            "Epoch: {}/{}, Loss - {:.3f},f1 - {:.3f}, val_loss - {:.3f}, val_f1 - {:.3f}".format(
                epoch,
                epochs,
                loss_callback[-1],
                f1_callback[-1],
                val_loss_callback[-1],
                val_f1_callback[-1],
            )
        )
        # Persist a checkpoint and the full metric history every 10 epochs.
        if epoch % 10 == 0:
            save_checkpoint(
                'model.pth',
                model,
                optimizer,
                loss_callback[-1],
                f1_callback[-1],
                val_loss_callback[-1],
                val_f1_callback[-1],
            )
            save_metrics(
                'metrics.pth',
                loss_callback,
                f1_callback,
                val_loss_callback,
                val_f1_callback,
            )

        # Back to train mode for the next epoch, and reshuffle the training
        # embeddings so the next pass sees them in a new order.
        model.train()
        random.shuffle(embedding_list_train)
        ds_train.new_embeddings(embedding_list_train)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment