Created
October 23, 2023 05:30
-
-
Save lzqlzzq/8708389a3daa5695e6ecc96aedd3b20b to your computer and use it in GitHub Desktop.
A simple multi-head classifier implementation (PyTorch).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
* A simple multi-head classifier implementation. | |
+-----------------+ | |
+--------------+ | logits_dict | | |
| hidden | +-----------------+ | |
+------+-------+ ^ | |
| | | |
v +--------------------+--------------------+ | |
+--------------+ | | | | |
| joint_layer | +---------------+ +---------------+ +---------------+ | |
+------+-------+ | output_head_1 | | output_head_2 | | output_head_n | | |
| +---------------+ +---------------+ +---------------+ | |
v ^ ^ ^ | |
+--------------+ | | | | |
| joint_hidden |--------------------+--------------------+--------------------+ | |
+--------------+ | |
""" | |
from collections import OrderedDict
from typing import Callable, Dict, Tuple

import torch
from torch import nn
from torch.nn import functional as F
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so it can sit inside ``nn.Sequential``.

    Args:
        *args: Positional arguments forwarded unchanged to ``Tensor.transpose``,
            i.e. the two dimensions to swap (e.g. ``Transpose(-1, -2)``).
    """

    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x):
        return x.transpose(*self.args)

    def extra_repr(self) -> str:
        # Without this, printed models show a bare ``Transpose()`` and hide
        # which dimensions are being swapped.
        return ', '.join(map(str, self.args))
class MultiHeadClassifier(nn.Module):
    """A shared joint projection followed by one independent classification head per class set.

    ``hidden`` passes through ``joint_layer`` (Linear -> activation -> BatchNorm1d
    -> Dropout) to give ``joint_hidden``. Every output head then maps
    ``joint_hidden`` to the logits of its own class set (Linear -> activation ->
    BatchNorm1d -> Linear). ``forward`` returns a dict keyed by class name.

    Args:
        hidden_size: Size of the last dimension of the input ``hidden``.
        joint_size: Width of the shared joint representation.
        classes: Ordered mapping ``class_name -> (class_num, middle_size)``.
            ``class_num`` is the number of labels for that head and
            ``middle_size`` is the width of the head's hidden layer. Heads are
            built, and their logits returned, in this order.
        act_fn: Zero-argument factory (e.g. ``nn.ReLU``). It is called once per
            activation layer so every layer gets its own module instance.
        bias: Whether the ``nn.Linear`` layers use a bias term.
        dropout: Dropout probability applied to the joint representation.
    """

    def __init__(self,
                 hidden_size: int,
                 joint_size: int,
                 classes: Dict[str, Tuple[int, int]],
                 act_fn: Callable[[], nn.Module] = nn.ReLU,
                 bias: bool = False,
                 dropout: float = .5):
        super().__init__()
        self.classes_names = list(classes.keys())
        # BatchNorm1d normalizes dim 1, so (batch, seq, feature) activations are
        # transposed to (batch, feature, seq) around it and back again afterwards.
        self.joint_layer = nn.Sequential(nn.Linear(hidden_size, joint_size, bias=bias),
                                         act_fn(),
                                         Transpose(-1, -2),
                                         nn.BatchNorm1d(joint_size),
                                         Transpose(-1, -2),
                                         nn.Dropout(dropout))
        self.output_heads = nn.ModuleList()
        for class_name, (class_num, middle_size) in classes.items():
            # Sub-modules are named after their class so that printed models and
            # state_dict keys show which head a parameter belongs to.
            self.output_heads.append(
                nn.Sequential(OrderedDict([
                    (f'{class_name}_fc', nn.Linear(joint_size, middle_size, bias=bias)),
                    (f'{class_name}_act', act_fn()),
                    (f'{class_name}_tp1', Transpose(-1, -2)),
                    (f'{class_name}_bn', nn.BatchNorm1d(middle_size)),
                    (f'{class_name}_tp2', Transpose(-1, -2)),
                    (f'{class_name}_fc2', nn.Linear(middle_size, class_num, bias=bias)),
                ]))
            )

    def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the logits of every head.

        Args:
            hidden: Tensor of shape ``(batch, ..., hidden_size)`` with at least two
                dimensions. ``(batch, hidden_size)`` and
                ``(batch, seq, hidden_size)`` are the common cases.

        Returns:
            Dict mapping each class name (in construction order) to logits of
            shape ``(batch, ..., class_num)``. The leading dims are the same as
            ``hidden``'s.

        Raises:
            ValueError: If ``hidden`` has fewer than two dimensions.
        """
        if hidden.dim() < 2:
            raise ValueError('hidden must have at least 2 dims (batch, ..., hidden_size), '
                             f'got shape {tuple(hidden.shape)}')
        lead_shape = hidden.shape[:-1]
        # The Transpose/BatchNorm1d pairs need a 3-D (batch, seq, feature) tensor.
        # On 2-D input they would normalize the wrong axis (or fail), so the
        # middle dims are merged into one seq dim (length 1 for 2-D input).
        # BatchNorm statistics span batch and seq alike, so this does not change
        # the result; 3-D input keeps its shape and is processed exactly as before.
        seq_len = hidden.shape[1:-1].numel()
        flat = hidden.reshape(hidden.shape[0], seq_len, hidden.shape[-1])
        joint_hidden = self.joint_layer(flat)
        logits = {}
        for name, head in zip(self.classes_names, self.output_heads):
            out = head(joint_hidden)
            logits[name] = out.reshape(*lead_shape, out.shape[-1])
        return logits

    def loss(self,
             logits: Dict[str, torch.Tensor],
             targets: Dict[str, torch.Tensor],
             reduction: str = 'mean',
             ignore_index: int = -100) -> Dict[str, torch.Tensor]:
        """Cross-entropy loss for every head, plus a combined ``'total_loss'`` entry.

        Args:
            logits: Output of ``forward``; each value has its class axis last.
            targets: Integer class indices per class name, shaped like the matching
                logits without the trailing class axis.
            reduction: Forwarded to ``F.cross_entropy`` (``'mean'``, ``'sum'`` or
                ``'none'``).
            ignore_index: Target value that does not contribute to the loss.

        Returns:
            Dict with one loss per class name, plus ``'total_loss'``:
            - ``'mean'``: the mean of the per-head losses.
            - ``'sum'``: their sum.
            - ``'none'``: their element-wise sum, with the same shape as the targets.
        """
        losses = {}
        for name in self.classes_names:
            # F.cross_entropy expects the class axis at dim 1: (N, C) or (N, C, d1, ...).
            # movedim(-1, 1) is a no-op for 2-D logits and equals transpose(-1, -2) for 3-D.
            losses[name] = F.cross_entropy(logits[name].movedim(-1, 1),
                                           targets[name],
                                           ignore_index=ignore_index,
                                           reduction=reduction)
        stacked = torch.stack(list(losses.values()))
        if reduction == 'mean':
            total = stacked.mean()
        elif reduction == 'none':
            # Keep the per-element layout instead of collapsing it to a scalar.
            total = stacked.sum(dim=0)
        else:
            total = stacked.sum()
        losses['total_loss'] = total
        return losses
if __name__ == '__main__':
    # Demo: a three-head classifier on a random (batch, seq, hidden) input,
    # followed by one forward pass, the loss, and a backward pass.
    # Each entry is class_name -> (class_num, middle_size).
    CLASSES = OrderedDict([
        ('class_a', (10, 16)),
        ('class_b', (100, 192)),
        ('class_c', (200, 512))])
    BATCH_SIZE = 2
    SAMPLE_NUM = 4
    HIDDEN_SIZE = 512

    classifier = MultiHeadClassifier(hidden_size=HIDDEN_SIZE,
                                     joint_size=1024,
                                     classes=CLASSES)
    dummy_hiddens = torch.randn((BATCH_SIZE, SAMPLE_NUM, HIDDEN_SIZE))
    logits = classifier(dummy_hiddens)
    print('logits:')
    for k, v in logits.items():
        print(' ', k, 'shape:', v.shape)

    # Random labels in [0, class_num), one per (batch, sample) position.
    dummy_targets = {name: torch.randint(0, cls_num, (BATCH_SIZE, SAMPLE_NUM))
                     for name, (cls_num, _) in CLASSES.items()}
    # Positions labelled 0 are excluded from the loss via ignore_index.
    losses = classifier.loss(logits, dummy_targets, ignore_index=0)
    print("losses:")
    for k, v in losses.items():
        print(' ', k + ':', v.item())
    # With the default 'mean' reduction total_loss is a scalar, so it can be
    # backpropagated directly.
    losses['total_loss'].backward()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment