Multi-task layers (that can be split into layers for each task)
from __future__ import annotations

import copy
import functools
import math
from collections import OrderedDict
from typing import Sequence

import torch
from torch import Tensor, nn


class MultiTaskLinear(nn.Module):
    """Acts like a list of nn.Linear layers, one per task, but using a fused forward pass.

    Inputs should have shape [batch_size, n_tasks, in_features].
    Outputs have shape [batch_size, n_tasks, out_features].

    NOTE: If we wanted the same data to be fed to every task's layer, the input would
    need to be "repeated" along the task dimension (dim=1).
    """

    def __init__(
        self,
        n_tasks: int,
        in_features: int,
        out_features: int,
        use_bias: bool = True,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.n_tasks = n_tasks
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.empty(n_tasks, in_features, out_features, dtype=dtype))
        self.register_parameter("bias", None)
        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(n_tasks, out_features, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Adapted from torch.nn.Linear.reset_parameters.
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        # Note: since this weight is 3D, fan_in here is in_features * out_features,
        # not in_features as it would be for a plain nn.Linear.
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        # input has shape [B, n_tasks, in_features].
        # weight has shape [n_tasks, in_features, out_features].
        assert input.shape[-2:] == (self.n_tasks, self.in_features)
        # One matmul per task, fused into a single einsum. (A plain `input @ self.weight`
        # would try to broadcast the batch and task dimensions against each other,
        # hence the explicit einsum.)
        out = torch.einsum("bti,tio->bto", input, self.weight)
        if self.use_bias:
            return out + self.bias
        return out
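

# --- Illustrative sanity check (an added sketch, not part of the original gist) ---
# The fused einsum in MultiTaskLinear.forward should match applying each task's
# weight matrix separately; the sizes below are arbitrary.
_layer = MultiTaskLinear(n_tasks=2, in_features=4, out_features=3)
_x = torch.randn(8, 2, 4)
_per_task = torch.stack(
    [_x[:, t] @ _layer.weight[t] + _layer.bias[t] for t in range(2)], dim=1
)
torch.testing.assert_close(_layer(_x), _per_task)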


@functools.singledispatch
def split_for_each_task(multi_task_module: nn.Module, n_tasks: int) -> Sequence[nn.Module]:
    """Splits a multi-task model into one model for each task."""
    raise NotImplementedError(
        f"Don't know how to split modules of type {type(multi_task_module)} such that we get "
        f"{n_tasks} versions, one per task. Register a handler function for this type of module "
        f"(as is done below)."
    )


@split_for_each_task.register(nn.ReLU)
def use_same_class_for_each_task(multi_task_module: nn.Module, n_tasks: int) -> list[nn.Module]:
    # Modules without task-specific parameters are simply copied once per task.
    return [copy.deepcopy(multi_task_module) for _ in range(n_tasks)]
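

# Other task-agnostic modules could reuse the same handler. These extra
# registrations are illustrative assumptions, not part of the original gist;
# register whichever parameter-free modules your own networks actually use.
split_for_each_task.register(nn.Tanh)(use_same_class_for_each_task)
split_for_each_task.register(nn.Dropout)(use_same_class_for_each_task)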


@split_for_each_task.register(nn.Sequential)
def split_sequential(multi_task_module: nn.Sequential, n_tasks: int) -> list[nn.Sequential]:
    # Split each child module, then regroup the i-th piece of every child into the
    # nn.Sequential for task i.
    splits: dict[str, Sequence[nn.Module]] = {
        key: split_for_each_task(module, n_tasks)
        for key, module in multi_task_module.named_children()
    }
    return [
        nn.Sequential(
            OrderedDict(
                {key: layers_for_each_task[i] for key, layers_for_each_task in splits.items()}
            )
        )
        for i in range(n_tasks)
    ]


@split_for_each_task.register(MultiTaskLinear)
def split_linear(
    multi_task_module: MultiTaskLinear, n_tasks: int | None = None
) -> list[nn.Linear]:
    layers: list[nn.Linear] = []
    # [t, i, o] -> t tensors of shape [o, i] (the weight layout nn.Linear expects).
    weights_for_each_task = multi_task_module.weight.transpose(1, 2).unbind(0)
    if multi_task_module.use_bias:
        biases_for_each_task = multi_task_module.bias.unbind(0)
    else:
        biases_for_each_task = None
    for task_id in range(n_tasks or multi_task_module.n_tasks):
        layer = nn.Linear(
            in_features=multi_task_module.in_features,
            out_features=multi_task_module.out_features,
            bias=multi_task_module.use_bias,
        )
        layer.weight = nn.Parameter(weights_for_each_task[task_id].clone())
        if biases_for_each_task is not None:
            layer.bias = nn.Parameter(biases_for_each_task[task_id].clone())
        layers.append(layer)
    return layers


batch_size = 5
n_tasks = 10
x_dim = 123
n_out_features = 1
hidden_dims = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn((batch_size, n_tasks, x_dim), requires_grad=True, device=device)

multi_task_model = nn.Sequential(
    MultiTaskLinear(n_tasks=n_tasks, in_features=x_dim, out_features=hidden_dims),
    nn.ReLU(),
    MultiTaskLinear(n_tasks=n_tasks, in_features=hidden_dims, out_features=n_out_features),
).to(device)
print("Fused multi-task model:")
print(multi_task_model)

models_for_each_task = split_for_each_task(multi_task_model, n_tasks)
print("Model for a single task:")
print(models_for_each_task[0])

fused_task_outputs = multi_task_model(x)

inputs_for_each_task = x.unbind(1)
outputs_for_each_task = [
    model(task_input) for model, task_input in zip(models_for_each_task, inputs_for_each_task)
]
stacked_task_outputs = torch.stack(outputs_for_each_task, dim=1)

# The output of the fused model is the same as the stacked outputs of the individual
# models for each task.
torch.testing.assert_close(stacked_task_outputs, fused_task_outputs)
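
# Extra check (an added sketch, not in the original gist): since `x` was created
# with requires_grad=True, gradients flow through the fused forward pass as well.
fused_task_outputs.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape
print("Gradient w.r.t. x:", x.grad.shape)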