Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Created April 9, 2022 16:42
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save krsnewwave/fe6e2d87dcaf244eea36c90fc1485ec6 to your computer and use it in GitHub Desktop.
PyTorch model for multitask learning
# Adapted from: https://towardsdatascience.com/multilabel-classification-with-pytorch-in-5-minutes-a4fa8993cbc7
class LightningResNetMultiLabel(pl.LightningModule):
    """Multitask classifier: a shared backbone with one linear head per task.

    The backbone ``net`` is reused with its last child module removed. The
    flattened features feed two independent heads, which predict the
    ``period`` and ``artist`` labels.

    Args:
        net: Backbone network. It must expose ``net.fc.in_features``. Its last
            entry in ``net.children()`` is dropped to build the feature
            extractor. NOTE(review): this assumes a torchvision-style ResNet
            whose children are listed in forward order and end with the
            classifier ``fc``; confirm before using other backbones.
        n_period: Number of classes for the period head.
        n_artists: Number of classes for the artist head.
        criterion: Per-task loss ``loss(logits, labels)``; defaults to
            ``F.cross_entropy``. Stored as ``self.loss_func``.
        optimizer: Stored as ``self.optimizer``; not used in this visible code.
        scheduler: Stored as ``self.scheduler``; not used in this visible code.
        dropout_p: Dropout probability applied before each head's Linear layer.
        lr: Stored as ``self.learning_rate``.
        freeze_net: If True, set ``requires_grad = False`` on every backbone
            parameter, so only the two heads receive gradients.
    """

    def __init__(self, net, n_period, n_artists, criterion=F.cross_entropy,
                 optimizer=None, scheduler=None, dropout_p=0., lr=0.001,
                 freeze_net=False):
        super().__init__()
        self.net = net
        # All children except the last (presumably the original ``fc`` head).
        # These are the *same* module objects held by ``self.net``, so the
        # weights are shared, not copied. ``self.net.fc`` stays registered on
        # the model but is never called in ``forward``.
        self.feature_extractor = nn.Sequential(*list(self.net.children())[:-1])
        if freeze_net:
            # This also freezes ``feature_extractor`` because its modules are
            # shared. The heads are created below, so they stay trainable.
            for param in self.net.parameters():
                param.requires_grad = False
        num_ftrs = net.fc.in_features
        self.period_fc = self._make_head(num_ftrs, n_period, dropout_p)
        self.artist_fc = self._make_head(num_ftrs, n_artists, dropout_p)
        self.loss_func = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.learning_rate = lr

    @staticmethod
    def _make_head(in_features, out_features, dropout_p):
        """Build one task head: Dropout followed by a Linear layer."""
        # Keep the Sequential layout (index 0 = Dropout, 1 = Linear) so that
        # state_dict keys such as ``period_fc.1.weight`` stay stable.
        return nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(in_features=in_features, out_features=out_features),
        )

    def criterion(self, loss_func, outputs, inputs):
        """Return the unweighted sum of the per-task losses.

        Args:
            loss_func: Callable ``loss_func(logits, labels)``.
            outputs: Mapping from task name to logits, e.g. the dict returned
                by ``forward``.
            inputs: Batch mapping. The labels for task ``key`` are read from
                ``inputs[f"{key}_label"]``.

        Returns:
            The sum of ``loss_func`` over every key in ``outputs``. If
            ``outputs`` is empty, this is the int ``0``.
        """
        return sum(
            loss_func(outputs[key], inputs[f'{key}_label']) for key in outputs
        )

    def forward(self, x):
        """Compute per-task logits for a batch.

        Args:
            x: Backbone input. Presumably an image batch of shape
                (batch, channels, height, width); TODO confirm.

        Returns:
            dict with keys ``'period'`` and ``'artist'``, holding logits with
            last dimension ``n_period`` and ``n_artists`` respectively.
        """
        x = self.feature_extractor(x)
        # Keep the batch dimension and flatten the rest into the
        # ``num_ftrs``-wide vector the heads expect.
        x = torch.flatten(x, 1)
        return {
            'period': self.period_fc(x),
            'artist': self.artist_fc(x),
        }

    def _shared_eval_step(self, batch, batch_idx):
        """Compute the summed loss and per-task accuracies for one batch.

        ``batch`` must provide the keys ``"image"``, ``"period_label"`` and
        ``"artist_label"``.

        NOTE(review): ``self.accuracy`` is not defined in the visible code.
        The loss and accuracies computed here are neither returned nor
        logged, so this method appears truncated. TODO: confirm the intended
        return value/logging before relying on it.
        """
        images = batch["image"]
        period_labels = batch["period_label"]
        artist_labels = batch["artist_label"]
        out = self(images)
        out_period = out["period"]
        out_artist = out["artist"]
        loss = self.criterion(self.loss_func, out, batch)
        period_accu = self.accuracy(out_period, period_labels)
        artist_accu = self.accuracy(out_artist, artist_labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment