-
-
Save ehofesmann/d3deae192825c9b2dd9e1d686ca9291c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_dataloaders(train_view, val_view, processor):
    """Build the training and validation DataLoaders.

    Wraps the given FiftyOne views in ``FiftyOneCocoDetection`` datasets
    and pairs both loaders with a collate function derived from
    *processor*.

    Args:
        train_view: FiftyOne view containing the training samples
        val_view: FiftyOne view containing the validation samples
        processor: DETR image processor used for encoding/collation

    Returns:
        ``(train_dataloader, val_dataloader)`` tuple
    """
    train_ds = FiftyOneCocoDetection(train_view, "train", processor)
    val_ds = FiftyOneCocoDetection(val_view, "val", processor)
    collate = collate_fn_generator(processor)
    return (
        DataLoader(train_ds, collate_fn=collate, batch_size=4, shuffle=True),
        DataLoader(val_ds, collate_fn=collate, batch_size=2),
    )
def get_run_id(dataset):
    """Return the next unique run id, persisting the counter on *dataset*.

    The counter is stored under ``dataset.info["max_prediction_field_id"]``
    and defaults to 0 when absent; each call increments it by one, writes
    it back, and returns the new value.
    """
    next_id = 1 + dataset.info.get("max_prediction_field_id", 0)
    dataset.info["max_prediction_field_id"] = next_id
    return next_id
def init_wandb_run(dataset, custom_id):
    """Start a W&B run for this sweep experiment.

    The run name embeds *custom_id*, and the run config records the fixed
    training hyperparameters plus the FiftyOne dataset name and the
    train/val split tags so the run can be traced back to its data.
    """
    run_config = {
        "epochs": 20,
        "architecture": "DETR",
        "fiftyone_dataset": dataset.name,
        "fiftyone_train_split_tag": "train",
        "fiftyone_val_split_tag": "val",
        "field_id": custom_id,
    }
    wandb.init(name=f"fiftyone_sweep_experiment_{custom_id}", config=run_config)
def save_view(dataset, pred_field, gt_field, eval_key, custom_id):
    """Save a named view of the validation split restricted to run fields.

    The view keeps only the prediction field, the ground-truth field, and
    the per-sample TP/FN/FP count fields produced by the evaluation run,
    and is saved on *dataset* as ``wandb_run_<custom_id>``.

    Bug fix: the original body referenced ``val_view``, which is neither a
    parameter nor a module-level name (it is a local of ``sweep_main``),
    so every call raised NameError. The validation view is rebuilt here
    from *dataset* using the same "val" tag the rest of the file uses.
    """
    val_view = dataset.match_tags("val")
    view = val_view.select_fields(
        [pred_field, gt_field, eval_key + "_tp", eval_key + "_fn", eval_key + "_fp"]
    )
    view_name = f"wandb_run_{custom_id}"
    dataset.save_view(view_name, view, overwrite=True)
def log_artifact(output_folder, custom_id):
    """Package the final model files as a W&B artifact and upload it.

    Collects ``config.json`` and ``pytorch_model.bin`` from
    ``<output_folder>/final`` into a model-type artifact named
    ``food-detector-<custom_id>``.

    Bug fix: the original created the artifact and added the files but
    never called ``wandb.log_artifact``, so nothing was ever uploaded to
    the active run.
    """
    art = wandb.Artifact(f'food-detector-{custom_id}', type="model")
    final_dir = os.path.join(output_folder, "final")
    art.add_file(os.path.join(final_dir, "config.json"), "config.json")
    art.add_file(os.path.join(final_dir, "pytorch_model.bin"), "pytorch_model.bin")
    wandb.log_artifact(art)
def sweep_main():
    """Run one end-to-end sweep experiment.

    Loads the "food" FiftyOne dataset, fine-tunes a DETR model on its
    train/val splits, logs the run and final model to W&B, runs inference
    on the validation split, and evaluates the detections in FiftyOne.
    """
    dataset = fo.load_dataset("food")
    train_view = dataset.match_tags("train")
    val_view = dataset.match_tags("val")
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    train_dataloader, val_dataloader = create_dataloaders(train_view, val_view, processor)
    # Classification head is sized to this dataset's label set, so size
    # mismatches against the pretrained checkpoint are expected and ignored.
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50",
        revision="no_timm",
        num_labels=len(dataset.distinct("ground_truth.detections.label")),
        ignore_mismatched_sizes=True,
    )
    # Monotonic id persisted on the dataset so each run writes to distinct
    # prediction/eval field names.
    custom_id = get_run_id(dataset)
    init_wandb_run(dataset, custom_id)
    pred_field = f"predictions_{custom_id}"
    gt_field = "ground_truth"
    integrate_wandb_fo(dataset, pred_field, gt_field)
    output_folder = f"sweep_results/{custom_id}"
    train(model, train_dataloader, val_dataloader, output_folder=output_folder)
    log_artifact(output_folder, custom_id)
    # Run inference on the validation split, storing detections in pred_field.
    add_detections(model, processor, val_view, pred_field)
    eval_key = f"eval_{custom_id}"
    # NOTE(review): FiftyOne's detection evaluation is usually invoked as
    # val_view.evaluate_detections(pred_field, ...); confirm that
    # fo.evaluate_detections exists in the installed FiftyOne version.
    results = fo.evaluate_detections(
        val_view,
        pred_field,
        eval_key=eval_key,
        compute_mAP=True,
    )
    save_view(dataset, pred_field, gt_field, eval_key, custom_id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment