@lzqlzzq
Created October 23, 2023 05:30
A simple multi-head classifier implementation (PyTorch).
"""
* A simple multi-head classifier implementation.
+-----------------+
+--------------+ | logits_dict |
| hidden | +-----------------+
+------+-------+ ^
| |
v +--------------------+--------------------+
+--------------+ | | |
| joint_layer | +---------------+ +---------------+ +---------------+
+------+-------+ | output_head_1 | | output_head_2 | | output_head_n |
| +---------------+ +---------------+ +---------------+
v ^ ^ ^
+--------------+ | | |
| joint_hidden |--------------------+--------------------+--------------------+
+--------------+
"""
from collections import OrderedDict
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

class Transpose(nn.Module):
    """Swaps two tensor dimensions, so transposition can be used inside nn.Sequential."""
    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x):
        return x.transpose(*self.args)
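
# A minimal sanity check for Transpose (illustrative; the shapes below are
# assumptions, not part of the original gist). nn.BatchNorm1d normalizes
# dim 1, so (batch, seq, features) activations are transposed before and
# after normalization:
#
#   x = torch.randn(2, 4, 8)                        # (batch, seq, features)
#   assert Transpose(-1, -2)(x).shape == (2, 8, 4)  # features now in dim 1
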

class MultiHeadClassifier(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 joint_size: int,
                 classes: 'OrderedDict[str, Tuple[int, int]]',  # {'class_name': (class_num, middle_size)}
                 act_fn: 'type[nn.Module]' = nn.ReLU,
                 bias: bool = False,
                 dropout: float = .5):
        super().__init__()
        self.classes_names = list(classes.keys())
        # Shared trunk: projects hidden states into a joint representation.
        self.joint_layer = nn.Sequential(nn.Linear(hidden_size, joint_size, bias=bias),
                                         act_fn(),
                                         Transpose(-1, -2),
                                         nn.BatchNorm1d(joint_size),
                                         Transpose(-1, -2),
                                         nn.Dropout(dropout))
        # One two-layer MLP head per class group.
        self.output_heads = nn.ModuleList()
        for class_name, (class_num, middle_size) in classes.items():
            self.output_heads.append(
                nn.Sequential(OrderedDict([
                    (f'{class_name}_fc', nn.Linear(joint_size,
                                                   middle_size, bias=bias)),
                    (f'{class_name}_act', act_fn()),
                    (f'{class_name}_tp1', Transpose(-1, -2)),
                    (f'{class_name}_bn', nn.BatchNorm1d(middle_size)),
                    (f'{class_name}_tp2', Transpose(-1, -2)),
                    (f'{class_name}_fc2', nn.Linear(middle_size,
                                                    class_num, bias=bias)),
                ]))
            )

    def forward(self, hidden):
        joint_hidden = self.joint_layer(hidden)
        logits = [h(joint_hidden) for h in self.output_heads]
        return dict(zip(self.classes_names, logits))

    def loss(self,
             logits,
             targets,
             reduction: str = 'mean',
             ignore_index: int = -100):
        # F.cross_entropy expects class logits in dim 1, hence the transpose.
        losses = {c: F.cross_entropy(logits[c].transpose(-1, -2),
                                     targets[c],
                                     ignore_index=ignore_index,
                                     reduction=reduction) for c in self.classes_names}
        losses['total_loss'] = torch.stack(list(losses.values()))
        losses['total_loss'] = losses['total_loss'].mean() if reduction == 'mean' else losses['total_loss'].sum()
        return losses
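
# A minimal training-step sketch (illustrative; the optimizer choice and
# learning rate are assumptions, not part of the original gist):
#
#   optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
#   logits = classifier(hidden)                # hidden: (batch, seq, hidden_size)
#   losses = classifier.loss(logits, targets)  # targets: dict of (batch, seq) int tensors
#   optimizer.zero_grad()
#   losses['total_loss'].backward()
#   optimizer.step()
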

if __name__ == '__main__':
    CLASSES = OrderedDict([
        ('class_a', (10, 16)),
        ('class_b', (100, 192)),
        ('class_c', (200, 512))])
    BATCH_SIZE = 2
    SAMPLE_NUM = 4
    HIDDEN_SIZE = 512

    classifier = MultiHeadClassifier(hidden_size=HIDDEN_SIZE,
                                     joint_size=1024,
                                     classes=CLASSES)
    dummy_hiddens = torch.randn((BATCH_SIZE, SAMPLE_NUM, HIDDEN_SIZE))
    logits = classifier(dummy_hiddens)
    print('logits:')
    for k, v in logits.items():
        print(' ', k, 'shape:', v.shape)

    dummy_targets = {name: torch.randint(0, cls_num, (BATCH_SIZE, SAMPLE_NUM))
                     for (name, (cls_num, mid_size)) in CLASSES.items()}
    losses = classifier.loss(logits, dummy_targets, ignore_index=0)
    print("losses:")
    for k, v in losses.items():
        print(' ', k + ':', v.item())

    # You can just backward the total_loss!
    losses['total_loss'].backward()
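
# Expected logits shapes from the demo above (loss values vary run to run):
#   class_a shape: torch.Size([2, 4, 10])
#   class_b shape: torch.Size([2, 4, 100])
#   class_c shape: torch.Size([2, 4, 200])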