Created
February 12, 2022 17:31
-
-
Save Dref360/4fa7d6a807901145570bb02238b674cf to your computer and use it in GitHub Desktop.
Detection draft
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional, Callable | |
import torch | |
from torch.optim import Adam | |
from torchvision.datasets.voc import VOCDetection | |
from torchvision.models.detection.ssd import ssd300_vgg16 | |
from torchvision.transforms import Compose, Resize, ToTensor | |
from baal import ModelWrapper | |
# The 20 object categories of the Pascal VOC detection benchmark, in the
# canonical order used to map annotation class names to integer label indices.
CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
class TargetTransformers:
    """Convert a VOC XML annotation dict to the ``{'boxes', 'labels'}`` format
    expected by torchvision detection models.

    Box coordinates are rescaled from the original image resolution to a
    square image of side ``size``, matching a ``Resize(size)`` applied to
    the input images.
    """

    def __init__(self, classes: List[str], size=224):
        self.size = size
        self.classes = classes

    def __call__(self, target):
        annotation = target['annotation']
        width = int(annotation['size']['width'])
        height = int(annotation['size']['height'])
        boxes = []
        labels = []
        # NOTE(review): assumes annotation['object'] is a list of dicts, as
        # produced by torchvision's VOC XML parser — confirm for single-object
        # images.
        for obj in annotation['object']:
            bndbox = obj['bndbox']
            boxes.append((
                int(bndbox['xmin']) / width * self.size,
                int(bndbox['ymin']) / height * self.size,
                int(bndbox['xmax']) / width * self.size,
                int(bndbox['ymax']) / height * self.size,
            ))
            labels.append(self.classes.index(obj['name']))
        return {'boxes': torch.FloatTensor(boxes), 'labels': torch.LongTensor(labels)}
# Build the VOC detection dataset. Images are resized to 224 and converted to
# tensors; the target transform rescales the annotation boxes to the same 224
# geometry so boxes stay aligned with the resized images.
# NOTE(review): download=False assumes the dataset already exists under
# /data/dataset — confirm before running.
ds = VOCDetection('/data/dataset', download=False, target_transform=TargetTransformers(classes=CLASSES, size=224),
                  transform=Compose([Resize(224), ToTensor()]))
# Pretrained SSD300-VGG16. In train() mode, torchvision detection models take
# targets in the forward call and return their internal loss dict.
model = ssd300_vgg16(pretrained=True)
model.train()
# Smoke test on a single sample: prints the model's loss dict.
x, y = ds[0]
print(model([x], [y]))
class MySSDCriterion:
    """Sum the two loss terms returned by torchvision's SSD in train mode.

    The SSD model computes its losses internally and returns them as the
    ``output`` dict; ``target`` is unused here since the losses already
    account for it inside the model's forward pass.
    """

    def __call__(self, output, target):
        regression_loss = output['bbox_regression']
        classification_loss = output['classification']
        return regression_loss + classification_loss
class LocalizationWrapper(ModelWrapper):
    """baal ModelWrapper variant for detection models whose forward pass needs
    the targets (torchvision detection models compute losses internally, so
    ``self.model`` must be called with both ``data`` and ``target``).
    """

    def train_on_batch(
        self, data, target, optimizer, cuda=False, regularizer: Optional[Callable] = None
    ):
        """Run one optimization step on a single batch and return the loss.

        Parameters mirror baal's ``ModelWrapper.train_on_batch``.
        NOTE(review): ``cuda`` is accepted for interface compatibility but not
        acted on here — move data/targets to the GPU before calling, or extend
        this method; TODO confirm against baal's base implementation.
        """
        # In train mode the detection model returns its internal loss dict.
        output = self.model(data, target)
        # Below should be as regular
        loss = self.criterion(output, target)
        # Honor the optional regularizer, as baal's base wrapper does.
        if regularizer:
            loss = loss + regularizer()
        # Bug fix: gradients must be cleared before backward(); otherwise they
        # accumulate across successive calls and corrupt every update after
        # the first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss
# Adam over all model parameters.
optim = Adam(params=model.parameters(), lr=0.001)
# NOTE(review): replicate_in_memory=False presumably avoids duplicating the
# model for MC iterations — confirm against baal's ModelWrapper docs.
wrapper = LocalizationWrapper(model=model, criterion=MySSDCriterion(), replicate_in_memory=False)
# One training step on the single sample prepared above; prints the loss.
print(wrapper.train_on_batch([x], [y], optimizer=optim, cuda=False))
# Switch to eval mode and run 2 stochastic forward passes on the same sample.
wrapper.eval()
print(wrapper.predict_on_batch([x], cuda=False, iterations=2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment