Comet ML + dataTap
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from comet_ml import Experiment | |
import torch | |
import torchvision | |
import torchvision.transforms as T | |
from datatap import Api | |
from datatap.torch import create_dataloader, torch_to_image_annotation | |
from datatap.utils.print_helpers import pprint | |
from datatap.metrics import ConfusionMatrix, PrecisionRecallCurve | |
from datatap.comet import init_experiment, log_validation_proposals | |
# Name passed to `database.get_repository(...)` to select the dataTap repository.
REPOSITORY = "_/aicrowd-food-recognition-challenge"

# Batch size used for both the training and the validation dataloader.
BATCH_SIZE = 2

# Placeholders forwarded to the Comet `Experiment` constructor; replace them
# with your own project and workspace names before running.
COMET_PROJECT = "your project name"
COMET_WORKSPACE = "your workspace name"
def main():
    """Train and validate a Faster R-CNN detector on a dataTap dataset.

    Creates a Comet experiment, fetches the "latest" dataset of
    ``REPOSITORY`` from the default dataTap database, and trains a
    torchvision ``fasterrcnn_resnet50_fpn`` model with SGD for a fixed
    number of epochs. After every epoch the model is run over the
    validation split, the predictions are converted to dataTap image
    annotations, fed into a confusion matrix and a precision/recall curve,
    and logged to Comet via ``log_validation_proposals``.
    """
    experiment = Experiment(
        project_name=COMET_PROJECT,
        workspace=COMET_WORKSPACE,
        auto_output_logging=False,
    )

    api = Api()
    database = api.get_default_database()
    repository = database.get_repository(REPOSITORY)
    dataset = repository.get_dataset("latest")
    init_experiment(experiment, dataset)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Index 0 is reserved for the background class, so real classes start
    # at 1 in both mappings.
    classes = list(dataset.template.classes.keys())
    classes_with_background = ["__background__"] + classes
    class_map = {cls: i + 1 for i, cls in enumerate(classes)}
    class_map_with_background = {
        cls: i for i, cls in enumerate(classes_with_background)
    }

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=True,
        num_classes=len(classes_with_background),
    ).to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=0.001,
        momentum=0.9,
        weight_decay=0.0005,
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=3,
        gamma=0.1,
    )

    image_transform = T.Compose([
        T.Resize(100),
        T.ToTensor(),
    ])

    # Both loaders produce CPU tensors; batches are moved to `device`
    # explicitly inside the loops below.
    loader_device = torch.device("cpu")
    training_data = create_dataloader(
        dataset,
        "training",
        batch_size=BATCH_SIZE,
        image_transform=image_transform,
        device=loader_device,
        class_mapping=class_map_with_background,
    )
    validation_data = create_dataloader(
        dataset,
        "validation",
        batch_size=BATCH_SIZE,
        image_transform=image_transform,
        device=loader_device,
        class_mapping=class_map_with_background,
    )

    # NOTE(review): these two metrics are created once and keep accumulating
    # across all epochs; nothing in this function reads them back.
    confusion_matrix = ConfusionMatrix(classes=classes)
    pr_curve = PrecisionRecallCurve()

    num_epochs = 4
    for i in range(num_epochs):
        # ---- training pass ----
        model.train()
        pprint("Starting epoch: {green}{0}", i)

        count = 0
        for batch in training_data:
            optimizer.zero_grad()

            targets = [
                {
                    "boxes": boxes.to(device),
                    "labels": labels.to(device),
                }
                for boxes, labels in zip(batch.boxes, batch.labels)
            ]
            images = [image.to(device) for image in batch.images]

            # In training mode the model is called with targets and its
            # result is used as a dict of named loss terms.
            losses = model(images, targets)
            total_loss = sum(losses.values())
            total_loss.backward()
            optimizer.step()

            count += 1
            if count % 10 == 0:
                pprint(
                    "Iter: {yellow}{0:4d}{clear}, Classifier loss: {yellow}{1:2.3f}{clear}, Total loss: {yellow}{2:2.3f}{clear}",
                    count,
                    losses["loss_classifier"].item(),
                    total_loss.item(),
                )

        # ---- validation pass ----
        model.eval()
        all_annotations = []
        # Gradients are not needed for inference; disabling autograd avoids
        # building graphs and keeps memory usage down.
        with torch.no_grad():
            for batch in validation_data:
                predictions = model([image.to(device) for image in batch.images])

                annotations = [
                    torch_to_image_annotation(
                        image,
                        class_map,
                        labels=prediction["labels"],
                        scores=prediction["scores"],
                        boxes=prediction["boxes"],
                        uid=ground_truth.uid,
                    )
                    for image, ground_truth, prediction in zip(
                        batch.images, batch.original_annotations, predictions
                    )
                ]
                all_annotations += annotations

                confusion_matrix.batch_add_annotation(
                    batch.original_annotations,
                    annotations,
                    iou_threshold=0.5,
                    confidence_threshold=0.5,
                )
                pr_curve.batch_add_annotation(
                    batch.original_annotations,
                    annotations,
                    iou_threshold=0.5,
                )

        # Log this epoch's predictions, then advance the LR schedule
        # (StepLR scales the learning rate by `gamma` every `step_size` steps).
        log_validation_proposals(experiment, all_annotations)
        lr_scheduler.step()
# Run the training script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment