import torch
from tqdm import tqdm

def setup_optimizer(model, lr, lr_backbone, weight_decay):
    # Use a separate (typically lower) learning rate for the pretrained
    # backbone than for the rest of the model, as is common when fine-tuning
    param_dicts = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=lr, weight_decay=weight_decay)
    return optimizer
def process_batch(model, batch, batch_num, device, pbar, running_loss):
    # Move the batch to the target device; "labels" is a list of dicts
    # (one per image), so each tensor inside it is moved individually
    pixel_values = batch["pixel_values"].to(device)
    pixel_mask = batch["pixel_mask"].to(device)
    labels = [{k: v.to(device) for k, v in t.items()} for t in batch["labels"]]

    # Passing labels makes the model compute and return its loss
    outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    loss = outputs.loss

    # Log the running average loss on the progress bar
    running_loss += loss.item()
    avg_loss = running_loss / (batch_num + 1)
    pbar.set_description("Average Loss %f" % avg_loss)

    return loss, running_loss
def train_one_epoch(model, train_dataloader, device, optimizer, gradient_clip_val=0.1):
    model.train()
    running_loss = 0
    pbar = tqdm(train_dataloader, unit=" batches")
    for batch_num, batch in enumerate(pbar):
        loss, running_loss = process_batch(
            model, batch, batch_num, device, pbar, running_loss
        )
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients to stabilize training
        if gradient_clip_val > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
        optimizer.step()

    # Average training loss over the epoch
    return running_loss / (batch_num + 1)
def validate_one_epoch(model, val_dataloader, device):
    model.eval()
    running_vloss = 0
    pbar = tqdm(val_dataloader, unit=" batches")

    # Disable gradient tracking during validation to save memory and compute
    with torch.no_grad():
        for batch_num, batch in enumerate(pbar):
            loss, running_vloss = process_batch(
                model, batch, batch_num, device, pbar, running_vloss
            )

    # Average validation loss over the epoch
    return running_vloss / (batch_num + 1)
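

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of how these helpers could be wired into a
# training loop. Everything below is an assumption, not part of the original
# gist: it assumes a Hugging Face DETR checkpoint (the batch keys
# "pixel_values"/"pixel_mask"/"labels" used by process_batch() match the
# DETR image-processor conventions), and it assumes that train_dataloader
# and val_dataloader are torch.utils.data.DataLoader instances defined
# elsewhere with a compatible collate function. Hyperparameter values are
# illustrative only.
# ---------------------------------------------------------------------------
from transformers import DetrForObjectDetection

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)

optimizer = setup_optimizer(model, lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4)

num_epochs = 10  # illustrative
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_dataloader, device, optimizer)
    val_loss = validate_one_epoch(model, val_dataloader, device)
    print("Epoch %d: train loss %f, val loss %f" % (epoch, train_loss, val_loss))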