Skip to content

Instantly share code, notes, and snippets.

@ehofesmann
Last active April 20, 2023 22:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ehofesmann/d3deae192825c9b2dd9e1d686ca9291c to your computer and use it in GitHub Desktop.
def create_dataloaders(train_view, val_view, processor, *, train_batch_size=4, val_batch_size=2):
    """Build the training and validation DataLoaders for DETR fine-tuning.

    Args:
        train_view: FiftyOne view holding the training samples.
        val_view: FiftyOne view holding the validation samples.
        processor: image processor handed to each dataset wrapper and to
            the collate-function generator.
        train_batch_size: batch size of the (shuffled) training loader.
        val_batch_size: batch size of the (unshuffled) validation loader.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.
    """
    train_dataset = FiftyOneCocoDetection(train_view, 'train', processor)
    val_dataset = FiftyOneCocoDetection(val_view, 'val', processor)
    # A single collate function, built from the processor, is shared by both loaders.
    collate_fn = collate_fn_generator(processor)
    train_dataloader = DataLoader(
        train_dataset, collate_fn=collate_fn, batch_size=train_batch_size, shuffle=True
    )
    val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=val_batch_size)
    return train_dataloader, val_dataloader
def get_run_id(dataset):
    """Allocate the next run id for *dataset* and persist the counter.

    The counter is stored in ``dataset.info["max_prediction_field_id"]``
    (treated as 0 when absent). It is incremented, written back, and the
    dataset is saved so that subsequent runs observe the new value.

    Args:
        dataset: FiftyOne dataset whose ``info`` dict holds the counter.

    Returns:
        The newly allocated integer id.
    """
    custom_id = dataset.info.get("max_prediction_field_id", 0) + 1
    dataset.info["max_prediction_field_id"] = custom_id
    # Changing dataset.info only affects the in-memory object; without save()
    # the next run would allocate the same id again and overwrite this run's
    # prediction field and saved view.
    dataset.save()
    return custom_id
def init_wandb_run(dataset, custom_id):
    """Start a Weights & Biases run for one sweep experiment.

    The run is named after *custom_id*. Its config records the epoch count,
    architecture, FiftyOne dataset name, the train/val split tags and the
    run id itself.
    """
    run_config = dict(
        epochs=20,
        architecture="DETR",
        fiftyone_dataset=dataset.name,
        fiftyone_train_split_tag="train",
        fiftyone_val_split_tag="val",
        field_id=custom_id,
    )
    run_name = f"fiftyone_sweep_experiment_{custom_id}"
    wandb.init(name=run_name, config=run_config)
def save_view(dataset, pred_field, gt_field, eval_key, custom_id, val_view=None):
    """Save a named view of the evaluated validation samples on *dataset*.

    The view keeps only the prediction field, the ground-truth field and the
    per-sample ``<eval_key>_tp`` / ``_fn`` / ``_fp`` counts. It is saved as
    ``wandb_run_<custom_id>``, replacing any existing view with that name.

    Args:
        dataset: FiftyOne dataset to save the view on.
        pred_field: name of the predictions field.
        gt_field: name of the ground-truth field.
        eval_key: evaluation key used when evaluating the predictions.
        custom_id: run id used in the saved view's name.
        val_view: validation view to restrict. Defaults to the samples of
            *dataset* tagged ``"val"``.
    """
    # The original referenced an undefined name ``val_view`` here (it only
    # exists as a local of sweep_main); derive it from the dataset instead.
    if val_view is None:
        val_view = dataset.match_tags("val")
    fields = [
        pred_field,
        gt_field,
        f"{eval_key}_tp",
        f"{eval_key}_fn",
        f"{eval_key}_fp",
    ]
    view = val_view.select_fields(fields)
    dataset.save_view(f"wandb_run_{custom_id}", view, overwrite=True)
def log_artifact(output_folder, custom_id):
    """Upload the final model checkpoint to the active W&B run.

    Packages ``<output_folder>/final/config.json`` and
    ``<output_folder>/final/pytorch_model.bin`` into a ``model``-type artifact
    named ``food-detector-<custom_id>``, then logs it to the current run.

    Args:
        output_folder: training output directory containing a ``final`` subdir.
        custom_id: run id used in the artifact name.

    Returns:
        The logged ``wandb.Artifact``.
    """
    final_dir = os.path.join(output_folder, "final")
    art = wandb.Artifact(f'food-detector-{custom_id}', type="model")
    # NOTE(review): newer transformers versions may save model.safetensors
    # instead of pytorch_model.bin — confirm against train()'s output.
    art.add_file(os.path.join(final_dir, "config.json"), "config.json")
    art.add_file(os.path.join(final_dir, "pytorch_model.bin"), "pytorch_model.bin")
    # Building the artifact only stages files locally; it must be logged to
    # the run (started by wandb.init) to actually be uploaded.
    wandb.log_artifact(art)
    return art
def sweep_main():
    """Run one sweep trial: fine-tune DETR on the "food" dataset and evaluate it.

    Steps:
    1. Load the dataset and its ``train`` / ``val`` tagged splits.
    2. Build the dataloaders and the DETR model.
    3. Allocate a run id and start a W&B run.
    4. Train, then log the model artifact.
    5. Write predictions to the val split and evaluate them.
    6. Log the resulting mAP to W&B and save an evaluation view on the dataset.
    """
    dataset = fo.load_dataset("food")
    train_view = dataset.match_tags("train")
    val_view = dataset.match_tags("val")

    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    train_dataloader, val_dataloader = create_dataloaders(train_view, val_view, processor)
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50",
        revision="no_timm",
        # The classification head is sized to this dataset's distinct labels,
        # which differs from the pretrained head, hence ignore_mismatched_sizes.
        num_labels=len(dataset.distinct("ground_truth.detections.label")),
        ignore_mismatched_sizes=True,
    )

    custom_id = get_run_id(dataset)
    init_wandb_run(dataset, custom_id)
    pred_field = f"predictions_{custom_id}"
    gt_field = "ground_truth"
    integrate_wandb_fo(dataset, pred_field, gt_field)

    output_folder = f"sweep_results/{custom_id}"
    train(model, train_dataloader, val_dataloader, output_folder=output_folder)
    log_artifact(output_folder, custom_id)

    add_detections(model, processor, val_view, pred_field)
    eval_key = f"eval_{custom_id}"
    results = fo.evaluate_detections(
        val_view,
        pred_field,
        eval_key=eval_key,
        compute_mAP=True,
    )
    # The evaluation results were previously discarded; record the headline
    # metric on the W&B run so sweep trials can be compared.
    wandb.log({"val_mAP": results.mAP()})
    save_view(dataset, pred_field, gt_field, eval_key, custom_id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment