Skip to content

Instantly share code, notes, and snippets.

Last active April 17, 2024 21:33
Show Gist options
  • Save crowsonkb/e62e9f685da9c185233f66de754f05ca to your computer and use it in GitHub Desktop.
Save crowsonkb/e62e9f685da9c185233f66de754f05ca to your computer and use it in GitHub Desktop.
Grouped linear layer using grouped_gemm
"""Grouped linear layer using"""
from dataclasses import dataclass
import warnings
import torch
from torch import nn
import grouped_gemm
_gmm_kernel = torch.compiler.disable(grouped_gemm.ops.gmm)
except ImportError:
warnings.warn("grouped_gemm not available, falling back to PyTorch implementation.")
_gmm_kernel = None
def gmm_pytorch(a, b, batch_sizes, trans_b=False):
    """Grouped matrix multiplication using PyTorch.

    The rows of ``a`` are split into consecutive groups whose sizes are given by
    ``batch_sizes``. Group ``i`` is multiplied by ``b[i]`` (or by its transpose
    when ``trans_b`` is true), and the per-group results are concatenated along
    dim 0.

    Args:
        a: 2D tensor whose rows are already ordered by group.
        b: 3D tensor holding one matrix per group along dim 0.
        batch_sizes: 1D tensor with the number of rows of ``a`` in each group.
        trans_b: If true, multiply by the transpose of each ``b[i]``.

    Returns:
        A 2D tensor with the same number of rows as ``a``.

    Raises:
        ValueError: If an input has the wrong rank, if ``b`` and ``batch_sizes``
            disagree on the number of groups, or if ``batch_sizes`` does not sum
            to the number of rows of ``a``.
    """
    if a.ndim != 2:
        raise ValueError("a must be a 2D tensor")
    if b.ndim != 3:
        raise ValueError("b must be a 3D tensor")
    if batch_sizes.ndim != 1:
        raise ValueError("batch_sizes must be a 1D tensor")
    if b.shape[0] != batch_sizes.shape[0]:
        raise ValueError("b and batch_sizes must have the same number of groups")
    sizes = batch_sizes.tolist()
    # Check up front so a mismatch gives a clear ValueError, like the checks above.
    if sum(sizes) != a.shape[0]:
        raise ValueError("batch_sizes must sum to the number of rows of a")
    a_split = torch.split(a, sizes)
    b_split = torch.unbind(b.mT if trans_b else b)
    c = [a_part @ b_part for a_part, b_part in zip(a_split, b_split)]
    return torch.cat(c)
def gmm(a, b, batch_sizes, trans_b=False):
    """Grouped matrix multiplication.

    Uses the grouped_gemm kernel when all of the following hold:

    - the kernel was imported successfully;
    - both ``a`` and ``b`` are CUDA tensors;
    - either both operands are already bfloat16, or autocast is enabled with
      bfloat16 as its CUDA dtype.

    In that case the operands are cast to bfloat16 and ``batch_sizes`` is moved
    to the CPU before the call. Otherwise this falls back to ``gmm_pytorch``.

    Args:
        a: 2D tensor whose rows are ordered by group.
        b: 3D tensor holding one matrix per group along dim 0.
        batch_sizes: 1D tensor with the number of rows of ``a`` in each group.
        trans_b: If true, multiply by the transpose of each ``b[i]``.

    Returns:
        The concatenated per-group products.
    """
    device_ok = a.device.type == "cuda" and b.device.type == "cuda"
    can_cast = False
    if torch.is_autocast_enabled():
        # torch.get_autocast_gpu_dtype() is deprecated in newer PyTorch releases
        # in favor of torch.get_autocast_dtype("cuda"). Prefer the new API when
        # it exists, to avoid a deprecation warning on every call, and keep the
        # old one for older releases.
        get_autocast_dtype = getattr(torch, "get_autocast_dtype", None)
        if get_autocast_dtype is not None:
            autocast_dtype = get_autocast_dtype("cuda")
        else:
            autocast_dtype = torch.get_autocast_gpu_dtype()
        can_cast = autocast_dtype == torch.bfloat16
    cast_not_needed = a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    if _gmm_kernel is not None and device_ok and (can_cast or cast_not_needed):
        return _gmm_kernel(a.bfloat16(), b.bfloat16(), batch_sizes.cpu(), trans_b)
    return gmm_pytorch(a, b, batch_sizes, trans_b)
@dataclass
class GroupInfo:
    """Group information produced by ``group`` and consumed by ``ungroup``."""

    shape: torch.Size  # shape of ``ids`` (the input's leading dims) before flattening
    ids_sorted: torch.Tensor  # flattened group IDs in ascending order
    ids_indices: torch.Tensor  # permutation that sorts the flattened group IDs
    batch_sizes: torch.Tensor  # number of rows in each group, on the CPU


def group(x, ids, n_groups):
    """Group a tensor by group IDs.

    Flattens all leading dimensions of ``x`` and reorders its rows so that rows
    with the same group ID are contiguous, in ascending group-ID order. The sort
    is stable, so rows keep their original relative order within each group.

    Args:
        x: The input tensor; ``x.shape[:-1]`` must equal ``ids.shape``.
        ids: The integer group IDs, expected to lie in ``[0, n_groups)``.
        n_groups: The number of groups.

    Returns:
        x: The grouped tensor, with rows sorted by group ID.
        info: The group information needed by ``ungroup``.

    Raises:
        ValueError: If ``x.shape[:-1]`` does not match ``ids.shape``, or if any
            ID is ``>= n_groups``.
    """
    if x.shape[:-1] != ids.shape:
        raise ValueError(
            f"shape mismatch: x.shape[:-1] is {tuple(x.shape[:-1])}, ids.shape is {tuple(ids.shape)}"
        )
    shape = ids.shape
    x = x.flatten(0, -2)
    ids = ids.flatten()
    ids_sorted, ids_indices = torch.sort(ids, stable=True)
    batch_sizes = torch.bincount(ids_sorted, minlength=n_groups).cpu()
    # bincount only returns more than minlength counts when some ID is
    # >= n_groups. Such IDs would not line up with the per-group weights.
    if batch_sizes.shape[0] != n_groups:
        raise ValueError(f"group IDs must be less than n_groups ({n_groups})")
    return x[ids_indices], GroupInfo(shape, ids_sorted, ids_indices, batch_sizes)


def ungroup(x, info):
    """Ungroup a tensor.

    Inverts the row permutation applied by ``group`` and restores the original
    leading dimensions. The last dimension of ``x`` is kept as-is, so it may
    differ from that of the tensor that was grouped.

    Args:
        x: The grouped tensor.
        info: The group information.

    Returns:
        The ungrouped tensor, with shape ``(*info.shape, x.shape[-1])``.
    """
    out = torch.empty_like(x)
    # Row i of the grouped tensor originally lived at position ids_indices[i].
    out.index_put_((info.ids_indices,), x)
    return out.view(*info.shape, x.shape[-1])
class GroupedLinear(nn.Module):
    """Grouped linear layer.

    Holds one ``(out_features, in_features)`` weight matrix per group, and
    optionally one ``out_features`` bias vector per group. Each input row is
    multiplied by the matrix of its group. The input is expected to be already
    grouped (rows sorted by group ID), with ``info`` describing that grouping.
    """

    def __init__(self, in_features, out_features, n_groups, bias=True):
        # nn.Module.__init__ must run before any nn.Parameter is assigned.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_groups = n_groups
        self.weight = nn.Parameter(torch.empty(n_groups, out_features, in_features))
        self.bias = nn.Parameter(torch.empty(n_groups, out_features)) if bias else None
        # Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)].
        bound = in_features**-0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if bias:
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, n_groups={self.n_groups}, bias={self.bias is not None}"

    def forward(self, x, info):
        """Apply each group's linear map to that group's rows.

        Args:
            x: The grouped input, rows sorted by group ID.
            info: Group information for ``x`` (uses ``batch_sizes`` and
                ``ids_sorted``).

        Returns:
            The grouped output, one row per input row.
        """
        x = gmm(x, self.weight, info.batch_sizes, trans_b=True)
        if self.bias is not None:
            # ids_sorted[i] is the group of row i, so this gathers one bias per row.
            x = x + self.bias[info.ids_sorted]
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment