-
-
Save ehofesmann/15137c935472e59685d05b83f8b4e562 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def setup_optimizer(model, lr, lr_backbone, weight_decay):
    """Create an AdamW optimizer with a separate learning rate for the backbone.

    Trainable parameters are split by name: any parameter whose name contains
    the substring "backbone" goes into a group that uses ``lr_backbone``; all
    other trainable parameters use the optimizer-wide default ``lr``.
    Parameters with ``requires_grad`` set to False are left out entirely.

    Args:
        model: Object exposing ``named_parameters()``.
        lr: Default learning rate (used by the non-backbone group).
        lr_backbone: Learning rate for the backbone group.
        weight_decay: Weight decay passed to AdamW for all groups.

    Returns:
        The constructed ``torch.optim.AdamW`` instance.
    """
    head_params = []
    backbone_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        bucket = backbone_params if "backbone" in name else head_params
        bucket.append(param)

    # Group order matters to callers inspecting ``param_groups``:
    # non-backbone first, backbone second.
    groups = [
        {"params": head_params},
        {"params": backbone_params, "lr": lr_backbone},
    ]
    return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)
def process_batch(model, batch, batch_num, device, pbar, running_loss):
    """Run one forward pass and update the running-loss display.

    Args:
        model: Callable accepting ``pixel_values``, ``pixel_mask`` and
            ``labels`` keyword arguments and returning an object with a
            ``loss`` attribute.
        batch: Mapping with "pixel_values" and "pixel_mask" entries (each
            supporting ``.to(device)``) and a "labels" entry that is an
            iterable of dicts whose values support ``.to(device)``.
        batch_num: Zero-based index of this batch within the epoch.
        device: Target device forwarded to every ``.to`` call.
        pbar: Progress bar object; only ``set_description`` is used.
        running_loss: Sum of the losses of all previous batches.

    Returns:
        Tuple ``(loss, running_loss)``: the loss object from the model
        output and the running sum including this batch.
    """
    # Move the image inputs first, then the per-sample label dicts.
    inputs = {key: batch[key].to(device) for key in ("pixel_values", "pixel_mask")}
    inputs["labels"] = [
        {field: value.to(device) for field, value in target.items()}
        for target in batch["labels"]
    ]

    outputs = model(**inputs)
    batch_loss = outputs.loss

    # Logging the loss: show the mean over the batches seen so far.
    running_loss = running_loss + batch_loss.item()
    mean_loss = running_loss / (batch_num + 1)
    pbar.set_description("Average Loss %f" % mean_loss)

    return batch_loss, running_loss
def train_one_epoch(model, train_dataloader, device, optimizer, gradient_clip_val=0.1):
    """Train ``model`` for one pass over ``train_dataloader``.

    For every batch this runs ``process_batch`` (forward pass and progress
    display), back-propagates the loss, optionally clips the gradient norm,
    and steps the optimizer.

    Args:
        model: Model to train; switched to training mode via ``model.train()``.
        train_dataloader: Iterable of batches accepted by ``process_batch``.
        device: Device forwarded to ``process_batch``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        gradient_clip_val: Maximum gradient norm. ``None`` or a value that is
            not positive disables clipping.

    Returns:
        Mean loss over all batches of the epoch.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0
    num_batches = 0
    # ``None`` is treated the same as a non-positive value: no clipping.
    clip_enabled = gradient_clip_val is not None and gradient_clip_val > 0

    pbar = tqdm(train_dataloader, unit=" batches")
    for batch_num, batch in enumerate(pbar):
        loss, running_loss = process_batch(model, batch, batch_num, device, pbar, running_loss)

        # Clear gradients left from the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()
        if clip_enabled:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
        optimizer.step()
        num_batches = batch_num + 1

    # Without this guard an empty dataloader would fail with an opaque
    # UnboundLocalError on ``batch_num``.
    if num_batches == 0:
        raise ValueError("train_dataloader yielded no batches")
    return running_loss / num_batches
def validate_one_epoch(model, val_dataloader, device):
    """Compute the mean loss of ``model`` over ``val_dataloader``.

    The model is switched to evaluation mode and every batch is passed
    through ``process_batch``. No parameters are updated.

    Args:
        model: Model to evaluate; switched to eval mode via ``model.eval()``.
        val_dataloader: Iterable of batches accepted by ``process_batch``.
        device: Device forwarded to ``process_batch``.

    Returns:
        Mean loss over all validation batches.

    Raises:
        ValueError: If ``val_dataloader`` yields no batches.
    """
    model.eval()
    running_vloss = 0
    num_batches = 0

    pbar = tqdm(val_dataloader, unit=" batches")
    # Nothing is back-propagated during validation, so disable gradient
    # tracking to avoid building autograd graphs that are never used.
    with torch.no_grad():
        for batch_num, batch in enumerate(pbar):
            _, running_vloss = process_batch(model, batch, batch_num, device, pbar, running_vloss)
            num_batches = batch_num + 1

    # Without this guard an empty dataloader would fail with an opaque
    # UnboundLocalError on ``batch_num``.
    if num_batches == 0:
        raise ValueError("val_dataloader yielded no batches")
    return running_vloss / num_batches
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment