Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/b6d8c261d30a1fb299a92766b66c13c2 to your computer and use it in GitHub Desktop.
Save SannaPersson/b6d8c261d30a1fb299a92766b66c13c2 to your computer and use it in GitHub Desktop.
def main():
    """Train YOLOv3 and periodically evaluate it on the test loader.

    Builds the model, Adam optimizer, loss function and AMP gradient scaler,
    optionally restores a checkpoint, then runs ``config.NUM_EPOCHS`` training
    epochs. After each epoch a checkpoint is written when
    ``config.SAVE_MODEL`` is set. On every 10th epoch (epoch 0 excluded) the
    class accuracy check is run and the mean average precision on the test
    loader is printed.
    """
    model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
    )
    loss_fn = YoloLoss()
    # NOTE(review): recent PyTorch releases deprecate this spelling in favour
    # of torch.amp.GradScaler("cuda"); kept unchanged so older versions work.
    scaler = torch.cuda.amp.GradScaler()

    # get_loaders returns three loaders; the third one is not used here.
    train_loader, test_loader, _train_eval_loader = get_loaders(
        train_csv_path=config.DATASET + "/train.csv",
        test_csv_path=config.DATASET + "/test.csv",
    )

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    # Scale anchors to each prediction scale: every entry of config.S is
    # expanded to a (3, 2) slab and multiplied element-wise with the anchors.
    # Assumes config.S is a flat list of grid sizes and config.ANCHORS has
    # shape (len(config.S), 3, 2) -- TODO confirm against config.
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    for epoch in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        if config.SAVE_MODEL:
            # NOTE(review): resuming reads config.CHECKPOINT_FILE but saving
            # uses this fixed name -- confirm the two are meant to match.
            save_checkpoint(model, optimizer, filename="checkpoint.pth.tar")

        # Evaluate only every 10th epoch, and never on the very first one.
        if epoch % 10 == 0 and epoch > 0:
            print("On Test loader:")
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)

            # Run model on test set and convert outputs to bounding boxes
            # relative to image. The unscaled config.ANCHORS are passed here,
            # not scaled_anchors.
            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )

            # Compute mean average precision
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment