@lebrice
Last active September 1, 2023 16:02
Multi-task layers (that can be split into layers for each task)

from __future__ import annotations

import copy
import functools
import math
from collections import OrderedDict
from typing import Sequence

import torch
from torch import Tensor, nn

class MultiTaskLinear(nn.Module):
"""Acts like a list of nn.Linear layers, one per task, but using a fused forward pass.
Inputs should have shape [batch_size, n_tasks, in_features].
Outputs have shape [batch_size, n_tasks, out_features].
NOTE: If we wanted to have the same data being fed to every expert, the input would need to be
"repeated" along the task (1st) dimension.
"""
def __init__(
self,
n_tasks: int,
in_features: int,
out_features: int,
use_bias: bool = True,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.n_tasks = n_tasks
self.in_features = in_features
self.out_features = out_features
self.use_bias = use_bias
self.weight = nn.Parameter(torch.empty(n_tasks, in_features, out_features, dtype=dtype))
self.register_parameter("bias", None)
if self.use_bias:
self.bias = nn.Parameter(torch.empty(n_tasks, out_features, dtype=dtype))
self.reset_parameters()
def reset_parameters(self) -> None:
# Taken from torch.nn.Linear.reset_parameters
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
torch.nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor) -> Tensor:
# input has shape [B, n_tasks, in_features]
# Weight has shape [n_tasks, in_features, out_features]
assert input.shape[-2:] == (self.n_tasks, self.in_features)
out = torch.einsum("bti,tio->bto", input, self.weight)
# out = input @ self.weight
if self.use_bias:
return out + self.bias
return out
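
# A small usage sketch (toy sizes picked purely for illustration): each of the 3 tasks gets its
# own [4 -> 2] linear map, all applied in a single fused einsum.
_example_layer = MultiTaskLinear(n_tasks=3, in_features=4, out_features=2)
_example_out = _example_layer(torch.randn(8, 3, 4))
assert _example_out.shape == (8, 3, 2)
# To feed the *same* data to every task (see the NOTE in the docstring above), repeat the
# input along the task dimension:
_shared_out = _example_layer(torch.randn(8, 4).unsqueeze(1).expand(-1, 3, -1))
assert _shared_out.shape == (8, 3, 2)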


@functools.singledispatch
def split_for_each_task(multi_task_module: nn.Module, n_tasks: int) -> Sequence[nn.Module]:
    """Splits a multi-task model into one model for each task."""
    raise NotImplementedError(
        f"Don't know how to split modules of type {type(multi_task_module)} such that we get "
        f"{n_tasks} versions, one per task. Register a handler function for this type of module "
        f"(as is done below)."
    )


@split_for_each_task.register(nn.ReLU)
def use_same_class_for_each_task(multi_task_module: nn.Module, n_tasks: int) -> list[nn.Module]:
    return [copy.deepcopy(multi_task_module) for _ in range(n_tasks)]
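
# A possible extension: other parameter-free modules can reuse the same handler. The two
# registrations below are illustrative sketches, assuming those modules are task-agnostic.
split_for_each_task.register(nn.Tanh, use_same_class_for_each_task)
split_for_each_task.register(nn.Dropout, use_same_class_for_each_task)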


@split_for_each_task.register(nn.Sequential)
def split_sequential(multi_task_module: nn.Sequential, n_tasks: int) -> list[nn.Sequential]:
    splits: dict[str, Sequence[nn.Module]] = {
        key: split_for_each_task(module, n_tasks)
        for key, module in multi_task_module.named_children()
    }
    return [
        nn.Sequential(
            OrderedDict(
                {key: layers_for_each_task[i] for key, layers_for_each_task in splits.items()}
            )
        )
        for i in range(n_tasks)
    ]


@split_for_each_task.register(MultiTaskLinear)
def split_linear(
    multi_task_module: MultiTaskLinear, n_tasks: int | None = None
) -> list[nn.Linear]:
    layers: list[nn.Linear] = []
    # Weight has shape [n_tasks, in_features, out_features], while nn.Linear stores
    # [out_features, in_features], so transpose the last two dims before splitting:
    # [t, i, o] -> t * [o, i]
    weights_for_each_task = multi_task_module.weight.transpose(1, 2).unbind(0)
    if multi_task_module.use_bias:
        biases_for_each_task = multi_task_module.bias.unbind(0)
    else:
        biases_for_each_task = None
    for task_id in range(n_tasks or multi_task_module.n_tasks):
        layer = nn.Linear(
            in_features=multi_task_module.in_features,
            out_features=multi_task_module.out_features,
            bias=multi_task_module.use_bias,
        )
        layer.weight = nn.Parameter(weights_for_each_task[task_id].clone())
        if biases_for_each_task is not None:
            layer.bias = nn.Parameter(biases_for_each_task[task_id].clone())
        layers.append(layer)
    return layers
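
# Quick sketch of what the split produces (toy sizes, illustrative only): plain nn.Linear
# copies whose parameters no longer share storage with the fused layer, thanks to .clone().
_fused = MultiTaskLinear(n_tasks=2, in_features=3, out_features=4)
_split = split_for_each_task(_fused, 2)
assert isinstance(_split[0], nn.Linear) and _split[0].weight.shape == (4, 3)
assert _split[0].weight.data_ptr() != _fused.weight.data_ptr()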


# Demo: build a fused multi-task MLP, split it into per-task models, and check that the outputs match.
batch_size = 5
n_tasks = 10
x_dim = 123
n_out_features = 1
hidden_dims = 3
n_features = [x_dim, hidden_dims, n_out_features]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn((batch_size, n_tasks, x_dim), requires_grad=True, device=device)
multi_task_model = nn.Sequential(
    MultiTaskLinear(n_tasks=n_tasks, in_features=x_dim, out_features=hidden_dims),
    nn.ReLU(),
    MultiTaskLinear(n_tasks=n_tasks, in_features=hidden_dims, out_features=n_out_features),
).to(device)
print("Fused multi-task model:")
print(multi_task_model)
models_for_each_task = split_for_each_task(multi_task_model, n_tasks)
print("Model for a single task:")
print(models_for_each_task[0])
fused_task_outputs = multi_task_model(x)
inputs_for_each_task = x.unbind(1)
outputs_for_each_task = [
    model(task_input)
    for model, task_input in zip(models_for_each_task, inputs_for_each_task)
]
stacked_task_outputs = torch.stack(outputs_for_each_task, dim=1)
# The output of the fused layer is the same as the stacked outputs of the individual layers for
# each task.
torch.testing.assert_close(stacked_task_outputs, fused_task_outputs)
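
# Extra sketch: gradients also flow through the fused forward pass
# (x was created with requires_grad=True above).
fused_task_outputs.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape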