Skip to content

Instantly share code, notes, and snippets.

@Dref360
Created February 12, 2022 17:31
Show Gist options
  • Save Dref360/4fa7d6a807901145570bb02238b674cf to your computer and use it in GitHub Desktop.
Save Dref360/4fa7d6a807901145570bb02238b674cf to your computer and use it in GitHub Desktop.
Detection draft
from typing import List, Optional, Callable
import torch
from torch.optim import Adam
from torchvision.datasets.voc import VOCDetection
from torchvision.models.detection.ssd import ssd300_vgg16
from torchvision.transforms import Compose, Resize, ToTensor
from baal import ModelWrapper
# Object category names used with VOCDetection. The position of a name in
# this list is what TargetTransformers turns into its integer label.
CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
class TargetTransformers:
    """Convert a parsed VOC annotation into a torchvision detection target.

    Boxes are rescaled from the original image resolution onto a square
    ``size`` x ``size`` canvas, so the paired image transform must resize
    images to exactly ``(size, size)`` for the boxes to line up.

    Args:
        classes: Category names; a name's position in this list determines
            its label.
        size: Side length of the square canvas the boxes are mapped onto.
        label_offset: Added to every class index. Defaults to 1 because
            torchvision detection models reserve label 0 for the background
            class; with an offset of 0 the first category would be treated
            as background.
    """

    def __init__(self, classes: List[str], size=224, label_offset: int = 1):
        self.size = size
        self.classes = classes
        self.label_offset = label_offset

    def __call__(self, target):
        """Return ``{'boxes': float32 (N, 4), 'labels': int64 (N,)}`` for one image.

        ``target`` is expected to look like
        ``{'annotation': {'size': {'width', 'height'}, 'object': [...]}}``
        with string values. A single object given as a dict rather than a
        list, and a missing or empty object list, are both accepted.

        Raises:
            ValueError: if an object's ``name`` is not in ``self.classes``.
        """
        annotation = target['annotation']
        w = int(annotation['size']['width'])
        h = int(annotation['size']['height'])

        objects = annotation.get('object', [])
        if isinstance(objects, dict):
            # Some parsers return a lone object as a dict instead of a list.
            objects = [objects]

        # float() also accepts decimal coordinate strings; for integer
        # strings the result is identical to int().
        boxes = [
            (
                float(obj['bndbox']['xmin']) / w * self.size,
                float(obj['bndbox']['ymin']) / h * self.size,
                float(obj['bndbox']['xmax']) / w * self.size,
                float(obj['bndbox']['ymax']) / h * self.size,
            )
            for obj in objects
        ]
        labels = [self.classes.index(obj['name']) + self.label_offset for obj in objects]

        return {
            # reshape keeps the (N, 4) layout even when there are no objects.
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64),
        }
# TargetTransformers maps every box onto a square IMAGE_SIZE x IMAGE_SIZE
# canvas, so the image must be resized to exactly that square as well.
# Resize(int) would only match the shorter edge and keep the aspect ratio,
# leaving boxes misaligned on non-square images.
IMAGE_SIZE = 224
ds = VOCDetection(
    '/data/dataset',
    download=False,
    target_transform=TargetTransformers(classes=CLASSES, size=IMAGE_SIZE),
    transform=Compose([Resize((IMAGE_SIZE, IMAGE_SIZE)), ToTensor()]),
)
model = ssd300_vgg16(pretrained=True)
# Training mode: the model is called with targets below.
model.train()
x, y = ds[0]
print(model([x], [y]))
class MySSDCriterion:
    """Combine the two SSD loss terms into a single value to back-propagate.

    ``output`` must provide the ``'bbox_regression'`` and ``'classification'``
    entries; ``target`` is accepted only to match the usual
    ``criterion(output, target)`` call shape and is not used.
    """

    def __call__(self, output, target):
        box_loss = output['bbox_regression']
        cls_loss = output['classification']
        return box_loss + cls_loss
class LocalizationWrapper(ModelWrapper):
    """``ModelWrapper`` whose training step also feeds the targets to the model.

    The wrapped model is called as ``model(data, target)`` and its raw output
    is handed to the criterion (in this script, a dict of losses that the
    criterion sums).
    """

    @staticmethod
    def _to_cuda(obj):
        """Recursively move tensors inside lists/tuples/dicts to the GPU."""
        if isinstance(obj, torch.Tensor):
            return obj.cuda()
        if isinstance(obj, dict):
            return {key: LocalizationWrapper._to_cuda(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(LocalizationWrapper._to_cuda(value) for value in obj)
        return obj

    def train_on_batch(
        self, data, target, optimizer, cuda=False, regularizer: Optional[Callable] = None
    ):
        """Run one optimisation step on a batch.

        Args:
            data: Model inputs, passed as the model's first argument
                (a list of image tensors in this script).
            target: Targets, passed as the model's second argument and to the
                criterion (a list of dicts of tensors in this script).
            optimizer: Optimizer whose gradients are reset and then stepped.
            cuda: If True, tensors in ``data`` and ``target`` are moved to the
                GPU first. The model itself is not moved here.
            regularizer: Optional callable whose result is added to the loss
                before back-propagation. Assumed to take no arguments, as
                in baal's base implementation (TODO confirm for your
                regularizers).

        Returns:
            The criterion loss, without the regularization term.
        """
        if cuda:
            data, target = self._to_cuda(data), self._to_cuda(target)
        # Reset gradients from the previous step; without this they would
        # accumulate across calls.
        optimizer.zero_grad()
        output = self.model(data, target)
        loss = self.criterion(output, target)
        to_backprop = loss + regularizer() if regularizer is not None else loss
        to_backprop.backward()
        optimizer.step()
        return loss
# Smoke test: one optimisation step through the wrapper, then a
# two-iteration prediction on the same image.
optim = Adam(params=model.parameters(), lr=0.001)
wrapper = LocalizationWrapper(
    model=model,
    criterion=MySSDCriterion(),
    replicate_in_memory=False,
)
train_loss = wrapper.train_on_batch([x], [y], optimizer=optim, cuda=False)
print(train_loss)

wrapper.eval()
# NOTE(review): confirm predict_on_batch can stack detection outputs
# (lists of dicts) across iterations.
predictions = wrapper.predict_on_batch([x], cuda=False, iterations=2)
print(predictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment