

@tchaton
Last active December 20, 2021 12:29
baal_model.py
from functools import partial

import torch
from torch import nn

from flash.core.classification import LogitsOutput
from flash.image import ImageClassifier
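# `dm` is the DataModule; it must be defined before this snippet runs.
# The Dropout layers make Monte Carlo Dropout possible when kept active at inference.
head = nn.Sequential(
    nn.Linear(512, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(512, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(512, dm.num_classes),
)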
model = ImageClassifier(
    num_classes=dm.num_classes,
    head=head,
    backbone="vgg16",
    pretrained=True,
    loss_fn=nn.CrossEntropyLoss(),
    # SGD with its hyperparameters bound up front via functools.partial.
    optimizer=partial(torch.optim.SGD, momentum=0.9, weight_decay=5e-4, lr=0.001),
    # Use `LogitsOutput` so the model returns raw logits,
    # which are needed to estimate uncertainty.
    output=LogitsOutput(),
)
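Why the head carries `nn.Dropout` layers: under Monte Carlo Dropout, dropout stays active at prediction time, so repeated forward passes over the same input return different logits, and the spread across those passes is what an uncertainty heuristic consumes. The sketch below illustrates the idea with the standard library only. The toy weights, the bias-free layer, and the helper names (`dropout`, `linear`, `mc_dropout_logits`) are hypothetical and are not part of Flash or Baal. It uses inverted dropout with `p=0.5`, the `nn.Dropout()` default.

```python
import random
from typing import List, Sequence


def dropout(x: Sequence[float], p: float, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = 1.0 - p
    return [xi / keep if rng.random() < keep else 0.0 for xi in x]


def linear(x: Sequence[float], weights: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical bias-free linear layer: one output logit per row of `weights`."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def mc_dropout_logits(
    x: Sequence[float],
    weights: Sequence[Sequence[float]],
    n_samples: int = 20,
    p: float = 0.5,
    training: bool = True,
    seed: int = 0,
) -> List[List[float]]:
    """Run `n_samples` forward passes of dropout -> linear on the same input.

    With training=True each pass draws a fresh dropout mask (MC Dropout);
    with training=False dropout is the identity, so every pass is identical.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        h = dropout(x, p, rng) if training else list(x)
        samples.append(linear(h, weights))
    return samples
```

For example, `mc_dropout_logits([1.0, 2.0, 3.0, 4.0], [[0.5, -0.2, 0.1, 0.3], [-0.4, 0.6, 0.2, -0.1]])` returns 20 two-class logit vectors that generally differ from one another, while the same call with `training=False` returns 20 copies of a single vector.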
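Why the model outputs raw logits: an acquisition heuristic typically converts each Monte Carlo Dropout sample's logits into probabilities and then compares the samples. Baal ships heuristics such as BALD and entropy; the standard-library sketch below is an independent illustration of the BALD score, not Baal's implementation. It computes the entropy of the averaged probabilities minus the average per-sample entropy, i.e. BALD = H[(1/T) Σ_t p_t] − (1/T) Σ_t H[p_t]. The input format (a list of `T` logit vectors for a single example) and the helper names are assumptions made for illustration.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert one logit vector to probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; zero-probability classes contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def bald_score(mc_logits: Sequence[Sequence[float]]) -> float:
    """BALD for one example, given T MC Dropout samples of its logits."""
    probs = [softmax(sample) for sample in mc_logits]
    t = len(probs)
    n_classes = len(probs[0])
    # Average the per-sample class probabilities.
    mean_probs = [sum(p[c] for p in probs) / t for c in range(n_classes)]
    predictive_entropy = entropy(mean_probs)
    expected_entropy = sum(entropy(p) for p in probs) / t
    return predictive_entropy - expected_entropy
```

A high score flags inputs whose dropout samples are each confident but disagree with one another (for instance `[[5.0, -5.0], [-5.0, 5.0]]` scores close to ln 2), while identical samples score 0.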