"""Grouped linear layer using https://github.com/tgale96/grouped_gemm."""
from dataclasses import dataclass
import warnings
import torch
from torch import nn
try:
import grouped_gemm
_gmm_kernel = torch.compiler.disable(grouped_gemm.ops.gmm)
except ImportError:
warnings.warn("grouped_gemm not available, falling back to PyTorch implementation.")
_gmm_kernel = None


@torch.compiler.disable
def gmm_pytorch(a, b, batch_sizes, trans_b=False):
    """Grouped matrix multiplication using PyTorch."""
    if a.ndim != 2:
        raise ValueError("a must be a 2D tensor")
    if b.ndim != 3:
        raise ValueError("b must be a 3D tensor")
    if batch_sizes.ndim != 1:
        raise ValueError("batch_sizes must be a 1D tensor")
    if b.shape[0] != batch_sizes.shape[0]:
        raise ValueError("b and batch_sizes must have the same number of groups")
    a_split = torch.split(a, batch_sizes.tolist())
    b_split = torch.unbind(b.mT if trans_b else b)
    c = [a_part @ b_part for a_part, b_part in zip(a_split, b_split)]
    return torch.cat(c)
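
# Example (illustrative, not part of the original gist): with batch_sizes = [2, 1],
# rows a[0:2] are multiplied by b[0] and row a[2:3] by b[1], and the per-group
# results are concatenated along dim 0.
#
#     a = torch.randn(3, 4)
#     b = torch.randn(2, 4, 5)
#     c = gmm_pytorch(a, b, torch.tensor([2, 1]))  # c.shape == (3, 5)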


def gmm(a, b, batch_sizes, trans_b=False):
    """Grouped matrix multiplication."""
    device_ok = a.device.type == "cuda" and b.device.type == "cuda"
    can_cast = torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16
    cast_not_needed = a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    if _gmm_kernel is not None and device_ok and (can_cast or cast_not_needed):
        return _gmm_kernel(a.bfloat16(), b.bfloat16(), batch_sizes.cpu(), trans_b)
    return gmm_pytorch(a, b, batch_sizes, trans_b)
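
# The fused kernel is only used when both operands are on CUDA and are (or will be
# autocast to) bfloat16; otherwise gmm() silently falls back to the pure PyTorch
# loop above. Illustrative example (assumes a CUDA device and grouped_gemm installed):
#
#     with torch.autocast("cuda", dtype=torch.bfloat16):
#         c = gmm(a.cuda(), b.cuda(), batch_sizes)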


@dataclass
class GroupInfo:
    """Group information."""

    shape: torch.Size
    ids_sorted: torch.Tensor
    ids_indices: torch.Tensor
    batch_sizes: torch.Tensor


def group(x, ids, n_groups):
    """Group a tensor by group IDs.

    Args:
        x: The input tensor.
        ids: The group IDs.
        n_groups: The number of groups.

    Returns:
        x: The grouped tensor.
        info: The group information.
    """
    if x.shape[:-1] != ids.shape:
        raise ValueError(
            f"shape mismatch: x.shape[:-1] is {tuple(x.shape[:-1])}, ids.shape is {tuple(ids.shape)}"
        )
    shape = ids.shape
    x = x.flatten(0, -2)
    ids = ids.flatten()
    ids_sorted, ids_indices = torch.sort(ids, stable=True)
    batch_sizes = torch.bincount(ids_sorted, minlength=n_groups).cpu()
    return x[ids_indices], GroupInfo(shape, ids_sorted, ids_indices, batch_sizes)
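
# Example (illustrative, not from the original gist): ids select a group per row
# of x; group() flattens the leading dims, sorts rows by group id (stable sort),
# and records how many rows went to each group.
#
#     x = torch.randn(2, 3, 8)
#     ids = torch.tensor([[0, 2, 1], [1, 0, 0]])
#     x_grouped, info = group(x, ids, n_groups=4)
#     # x_grouped.shape == (6, 8); info.batch_sizes == tensor([3, 2, 1, 0])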


def ungroup(x, info):
    """Ungroup a tensor.

    Args:
        x: The grouped tensor.
        info: The group information.

    Returns:
        The ungrouped tensor.
    """
    return torch.empty_like(x).index_put_((info.ids_indices,), x).view(*info.shape, x.shape[-1])
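
# Note (illustrative): ungroup() inverts group(): ungroup(x_grouped, info) recovers
# x with its original shape and row order. The same info can ungroup any tensor that
# shares the grouped row order, e.g. the output of GroupedLinear below, even when its
# feature dimension differs from the input's.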


class GroupedLinear(nn.Module):
    """Grouped linear layer."""

    def __init__(self, in_features, out_features, n_groups, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_groups = n_groups
        self.weight = nn.Parameter(torch.empty(n_groups, out_features, in_features))
        self.bias = nn.Parameter(torch.empty(n_groups, out_features)) if bias else None
        bound = in_features**-0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if bias:
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"n_groups={self.n_groups}, bias={self.bias is not None}"
        )

    def forward(self, x, info):
        x = gmm(x, self.weight, info.batch_sizes, trans_b=True)
        if self.bias is not None:
            x = x + self.bias.to(x.dtype)[info.ids_sorted]
        return x
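

# A minimal usage sketch (added for illustration; not part of the original gist).
# Tokens are routed to n_groups "experts": group() sorts them by expert id,
# GroupedLinear applies a per-expert weight, and ungroup() restores the original
# token order.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_groups, in_features, out_features = 4, 8, 16
    x = torch.randn(2, 5, in_features)      # (batch, tokens, features)
    ids = torch.randint(n_groups, (2, 5))   # expert id per token
    layer = GroupedLinear(in_features, out_features, n_groups)
    x_grouped, info = group(x, ids, n_groups)   # (batch * tokens, in_features), sorted by id
    y_grouped = layer(x_grouped, info)          # (batch * tokens, out_features)
    y = ungroup(y_grouped, info)                # (2, 5, out_features), original order
    print(y.shape)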