Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
PyTorch model for multitask learning
# following
class LightningResNetMultiLabel(pl.LightningModule):
    """Multi-task classifier: one shared backbone feeding two linear heads.

    The backbone is every child module of ``net`` except the last one; the
    flattened backbone features are fed to a "period" head and an "artist"
    head, and the per-head losses are summed into a single training signal.

    NOTE(review): this class was reconstructed from a source whose
    indentation and several tokens were lost. The ``self.net`` attribute, the
    ``net.children()`` slice and the frozen-parameter loop are the most
    plausible readings of the damaged lines -- confirm against the original.
    """

    def __init__(
        self,
        net,
        n_period,
        n_artists,
        criterion=F.cross_entropy,
        optimizer=None,
        scheduler=None,
        dropout_p=0.,
        lr=0.001,
        freeze_net=False,
    ):
        """Build the shared feature extractor and the two classification heads.

        Args:
            net: Backbone network. Must expose ``children()`` and an ``fc``
                attribute with ``in_features`` (presumably a torchvision
                ResNet -- confirm with the caller).
            n_period: Number of output classes of the period head.
            n_artists: Number of output classes of the artist head.
            criterion: Loss callable invoked as ``criterion(logits, labels)``
                once per head.
            optimizer: Stored as ``self.optimizer``; not used in this block.
            scheduler: Stored as ``self.scheduler``; not used in this block.
            dropout_p: Dropout probability applied to the flattened features
                before both heads. The default ``0.`` is a no-op.
            lr: Stored as ``self.learning_rate``.
            freeze_net: When true, the feature extractor's parameters get
                ``requires_grad = False`` so only the heads are trainable.
        """
        super().__init__()
        self.net = net
        # Drop the last child (the original classification layer) and keep
        # the rest as the shared feature extractor.
        self.feature_extractor = nn.Sequential(*list(net.children())[:-1])
        if freeze_net:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        num_ftrs = net.fc.in_features
        # nn.Dropout has no parameters or buffers, so adding it does not
        # change the state-dict layout of the heads below.
        self.dropout = nn.Dropout(p=dropout_p)
        self.period_fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=n_period),
        )
        self.artist_fc = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=n_artists),
        )
        self.loss_func = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.learning_rate = lr

    def criterion(self, loss_func, outputs, inputs):
        """Return the sum of ``loss_func`` over every head in ``outputs``.

        Args:
            loss_func: Callable invoked as ``loss_func(output, target)``.
            outputs: Mapping from head name to that head's output.
            inputs: Mapping that holds the target of head ``key`` under
                ``f'{key}_label'``.

        Returns:
            The summed loss; the integer ``0`` when ``outputs`` is empty.
        """
        return sum(
            loss_func(output, inputs[f'{key}_label'])
            for key, output in outputs.items()
        )

    def accuracy(self, logits, labels):
        """Return the fraction of rows whose arg-max along dim 1 equals ``labels``.

        Assumes ``logits`` is ``(batch, n_classes)`` and ``labels`` holds
        class indices of shape ``(batch,)`` -- TODO confirm with the data
        pipeline.
        """
        predictions = torch.argmax(logits, dim=1)
        return (predictions == labels).float().mean()

    def forward(self, x):
        """Run the backbone on ``x`` and return both heads' outputs.

        Returns:
            A dict with keys ``'period'`` and ``'artist'``.
        """
        x = self.feature_extractor(x)
        # Keep dim 0 and collapse the remaining dims for the linear heads.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return {
            'period': self.period_fc(x),
            'artist': self.artist_fc(x),
        }

    def _shared_eval_step(self, batch, batch_idx):
        """Compute the summed loss and per-head accuracy for one batch.

        Args:
            batch: Mapping with keys ``"image"``, ``"period_label"`` and
                ``"artist_label"``.
            batch_idx: Index of the batch; unused here.

        Returns:
            Tuple ``(loss, period_accu, artist_accu)``.
        """
        images = batch["image"]
        period_labels = batch["period_label"]
        artist_labels = batch["artist_label"]
        out = self(images)
        loss = self.criterion(self.loss_func, out, batch)
        period_accu = self.accuracy(out["period"], period_labels)
        artist_accu = self.accuracy(out["artist"], artist_labels)
        return loss, period_accu, artist_accu
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment