-
-
Save SannaPersson/b6d8c261d30a1fb299a92766b66c13c2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main():
    """Train YOLOv3 and periodically evaluate it on the test split.

    All settings are read from the project ``config`` module. Each epoch runs
    ``train_fn`` over the training loader. When ``config.SAVE_MODEL`` is set,
    it also writes a checkpoint. Whenever the epoch index is a positive
    multiple of 10, it prints class accuracy on the test loader and the mean
    average precision (mAP) of the model's predicted boxes.
    """
    model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    # Gradient scaler for CUDA automatic mixed precision. Presumably train_fn
    # pairs it with an autocast context — confirm in train_fn.
    scaler = torch.cuda.amp.GradScaler()

    # get_loaders also returns a train-eval loader; this function never uses it.
    train_loader, test_loader, _ = get_loaders(
        train_csv_path=config.DATASET + "/train.csv",
        test_csv_path=config.DATASET + "/test.csv",
    )

    if config.LOAD_MODEL:
        # NOTE(review): LEARNING_RATE is passed in, presumably so that
        # load_checkpoint can override the rate stored in the checkpoint.
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    # Scale anchors to each prediction scale. The S tensor is expanded to
    # shape (len(S), 3, 2), so every anchor of scale i is multiplied by S[i].
    # Assumes config.ANCHORS holds 3 (w, h) pairs per scale, expressed
    # relative to the image — TODO confirm against config.
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    for epoch in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        if config.SAVE_MODEL:
            # NOTE(review): this always writes "checkpoint.pth.tar", while
            # loading reads config.CHECKPOINT_FILE. Confirm the two match if
            # training is meant to be resumable from this file.
            save_checkpoint(model, optimizer, filename="checkpoint.pth.tar")

        # Evaluate every 10 epochs, skipping epoch 0.
        if epoch % 10 == 0 and epoch > 0:
            print("On Test loader:")
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)

            # Run the model on the test set and convert its outputs to bounding
            # boxes relative to the image. Unlike train_fn, this receives the
            # unscaled config.ANCHORS. Presumably get_evaluation_bboxes scales
            # them per prediction scale itself — verify in that helper.
            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )

            # Compute mean average precision across config.NUM_CLASSES classes.
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            # .item() assumes mean_average_precision returns a one-element tensor.
            print(f"MAP: {mapval.item()}")
# Script entry point: start training only when run directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.