@dblalock
Created August 1, 2023 02:01
block diagonal matmul
from typing import * # for convenience; also, this is valid for typing
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor # shorten type signatures
# def try_compile(f: Callable):
# def try_compile():
try:
    do_compile = torch.compile
except AttributeError:
    # do_compile = lambda x: x  # noqa
    do_compile = torch.jit.script
def _diagonal_flat_idxs(W: Tensor) -> Tensor:
    """Return the flat indices (into W.view(-1)) of the top-left diagonal entries."""
    min_dim = min(W.shape[-2], W.shape[-1])
    eye = torch.eye(min_dim)
    tmp = torch.zeros_like(W)
    tmp[..., :min_dim, :min_dim] = eye
    return torch.where(tmp.view(-1))[0]
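# Example (added for illustration): for a 2x3 weight, the identity is written
# into the top-left 2x2 block, so the flat indices of its diagonal come back:
#   >>> _diagonal_flat_idxs(torch.zeros(2, 3))
#   tensor([0, 4])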
class FlexibleMatmul(torch.autograd.Function):
    """Do {batched, single} matmuls {with, without} {bias, out, existing .grad}.
    Also lets you choose whether output should be rowmajor or colmajor."""

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float16)
    def forward(ctx,
                bias: Optional[Tensor],
                X: Tensor,
                W: Tensor,
                out: Optional[Tensor],
                colmajor_output: bool = False,
                # bias_grad: Optional[Tensor] = None,
                Xgrad: Optional[Tensor] = None,
                Wgrad: Optional[Tensor] = None,
                X_identity_idxs: Optional[Tensor] = None,
                W_identity_idxs: Optional[Tensor] = None,
                ) -> Tensor:
        # TODO just save bias shape, rather than the actual tensor
        # ctx.save_for_backward(bias.shape, bias_requires_grad, bias_grad, X, W, Wgrad)
        # Wgrad = Wgrad or (W.grad if hasattr(W, 'grad') else None)
        # ctx.save_for_backward(bias, X, W, Xgrad, Wgrad)
        ctx.save_for_backward(bias, X, W, Xgrad, Wgrad, X_identity_idxs, W_identity_idxs)
        # assert Wgrad is not None  # TODO rm
        # print("------------------------> Is wgrad None? ", Wgrad is None)
        # TODO throw informative errors instead
        assert X.ndim in (2, 3)
        assert W.ndim in (2, 3)
        assert out is None or out.ndim in (2, 3)

        # add in identity mat if requested
        if X_identity_idxs is not None:
            X.view(-1)[X_identity_idxs] += 1
        if W_identity_idxs is not None:
            W.view(-1)[W_identity_idxs] += 1

        # reduce regular matmuls to bmms with batch size 1
        orig_ndim = max(X.ndim, W.ndim)
        X = X.view(1, *X.shape) if X.ndim == 2 else X
        W = W.view(1, *W.shape) if W.ndim == 2 else W
        out = out.view(1, *out.shape) if out is not None and out.ndim == 2 else out

        # now we handle the different cases; this looks grosser than it is
        # because we can't transpose None and can't supply None as the bias
        # to baddbmm; but conceptually, we're always just doing either:
        #   ret = torch.baddbmm(bias, X, W, out=out)  # rowmajor
        # or
        #   ret = torch.baddbmm(bias.T, W.T, X.T, out=out).T  # colmajor
        if colmajor_output:
            Xt = X.transpose(-2, -1)
            Wt = W.transpose(-2, -1)
            if out is not None:
                out = out.reshape(X.shape[0], W.shape[-1], X.shape[-2])
            if bias is None:
                ret = torch.bmm(Wt, Xt, out=out).transpose(-2, -1)
            else:
                bias = torch.atleast_2d(bias).transpose(-2, -1)
                # print("bmm Xt shape", Xt.shape)
                # print("bmm Wt shape", Wt.shape)
                # print("bmm biasT shape", bias.shape)
                ret = torch.baddbmm(bias, Wt, Xt, out=out).transpose(-2, -1)
        else:  # rowmajor
            if bias is None:
                ret = torch.bmm(X, W, out=out)
            else:
                ret = torch.baddbmm(bias, X, W, out=out)
        if orig_ndim == 2:
            ret = ret.view(ret.shape[1:])

        # undo addition of identity mat
        if X_identity_idxs is not None:
            X.view(-1)[X_identity_idxs] -= 1
        if W_identity_idxs is not None:
            W.view(-1)[W_identity_idxs] -= 1
        return ret
    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dOut: Tensor) -> Tuple[
            Optional[Tensor], Optional[Tensor], Optional[Tensor],
            None, None, None, None, None, None]:
        bias, X, W, Xgrad, Wgrad, X_identity_idxs, W_identity_idxs = ctx.saved_tensors

        dX = None
        if ctx.needs_input_grad[1]:  # dgrad; X is forward input 1
            # add in identity mat if requested; we do this here instead of
            # just passing the idxs into the apply call to avoid having to
            # reason about how these idxs get transposed
            if W_identity_idxs is not None:
                W.view(-1)[W_identity_idxs] += 1
            colmajor_out = X.stride()[-2] < X.stride()[-1]
            dX = FlexibleMatmul.apply(Xgrad, dOut, W.transpose(-2, -1), Xgrad, colmajor_out)
            # return None since we already wrote to Xgrad
            dX = dX if Xgrad is None else None
            # undo addition of identity mat
            if W_identity_idxs is not None:
                W.view(-1)[W_identity_idxs] -= 1

        dW = None
        if ctx.needs_input_grad[2]:  # wgrad; W is forward input 2
            # add in identity mat if requested
            if X_identity_idxs is not None:
                X.view(-1)[X_identity_idxs] += 1
            Xt = X.transpose(-2, -1)
            dW_strides = Wgrad.stride() if Wgrad is not None else W.stride()
            colmajor_out = dW_strides[-2] < dW_strides[-1]
            dW = FlexibleMatmul.apply(Wgrad, Xt, dOut, Wgrad, colmajor_out)
            # return None since we already wrote to Wgrad
            dW = dW if Wgrad is None else None
            # undo addition of identity mat
            if X_identity_idxs is not None:
                X.view(-1)[X_identity_idxs] -= 1

        dBias = None
        if bias is not None and ctx.needs_input_grad[0]:
            if dOut.ndim == 2:  # always treat as baddbmm
                dOut = dOut.reshape(1, *dOut.shape)
            # figure out which dims we need to sum over; we just treat
            # bias as a 3d tensor, prepending 1s as needed, and then check
            # which dims are 1.
            shape_as_3d = (1, 1, 1)[:-bias.ndim] + tuple(bias.shape)
            contract_dims = tuple([i for i, dim in enumerate(shape_as_3d) if dim == 1])
            dBias = dOut.sum(dim=contract_dims) if contract_dims else dOut
            dBias = dBias.view(bias.shape)
        return dBias, dX, dW, None, None, None, None, None, None
def flexible_gemm(X: Tensor,
                  W: Tensor,
                  bias: Optional[Tensor] = None,
                  out: Optional[Tensor] = None,
                  colmajor_output: bool = False,
                  Xgrad: Optional[Tensor] = None,
                  Wgrad: Optional[Tensor] = None,
                  X_identity_idxs: Optional[Tensor] = None,
                  W_identity_idxs: Optional[Tensor] = None,
                  f_act: Optional[str] = None,
                  ) -> Tensor:
    # if Wgrad passed, backwards accumulates into it directly
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # suppress autograd non-leaf .grad warning
        if Xgrad is None and hasattr(X, 'grad'):
            Xgrad = X.grad
        if Wgrad is None and hasattr(W, 'grad'):
            Wgrad = W.grad

    # we can't just forward args because "torch.jit.frontend.NotSupportedError:
    # Compiled functions can't take variable number of arguments or use
    # keyword-only arguments with defaults"
    # @do_compile  # can't wrap autograd func apply(), so we're SOL for now
    def _body(bias, X, W, out, colmajor_output, Xgrad, Wgrad,
              X_identity_idxs, W_identity_idxs):
        ret = FlexibleMatmul.apply(bias, X, W, out, colmajor_output,
                                   Xgrad, Wgrad,
                                   X_identity_idxs, W_identity_idxs)
        if f_act == 'pow2':  # TODO other options
            ret = 2 ** ret
        return ret

    return _body(bias, X, W, out, colmajor_output, Xgrad, Wgrad,
                 X_identity_idxs, W_identity_idxs)
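# Illustrative usage (added; a hedged sketch with placeholder tensors, not part
# of the original gist): flexible_gemm handles single and batched matmuls, with
# or without a bias or preallocated output, e.g.
#   Y = flexible_gemm(torch.randn(4, 6), torch.randn(6, 8))        # plain matmul
#   Y = flexible_gemm(torch.randn(2, 4, 6), torch.randn(2, 6, 8))  # batched (bmm)
#   Y = flexible_gemm(X, W, bias=b, colmajor_output=True)          # fused add, col-major out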
def block_diag_addmm(bias: Optional[Tensor],
                     X: Tensor,
                     W: Tensor,
                     out: Optional[Tensor] = None,
                     X_identity_idxs: Optional[Tensor] = None,
                     W_identity_idxs: Optional[Tensor] = None,
                     f_act: Optional[str] = None,
                     ) -> Tensor:
    # num_subspaces = W.shape[0]
    # if (X.ndim != 2) or (X.stride()[0] < X.stride()[1]):
    #     raise NotImplementedError("Only rowmajor X supported")
    # print("input X shape", X.shape)
    # print("input W shape", W.shape)
    # print("input bias shape", bias.shape)
    M, K = X.shape
    num_subspaces = W.shape[0]
    assert K == num_subspaces * W.shape[1]
    N = num_subspaces * W.shape[2]
    if bias is not None:
        # bias is now (nrows or 1) x nsubspaces x out_subspace_len
        bias = bias.reshape(-1, num_subspaces, N // num_subspaces)
        bias = bias.transpose(0, 1)  # nsubspaces x (nrows or 1) x out_subspace_len
        # TODO support col vector as bias, not just row vec and full mat

    # view X as rowmajor (nrows x num_subspaces x subspace_len); strides
    # are still descending
    X = X.reshape(M, num_subspaces, K // num_subspaces)
    # leading dim needs to be bmm batch dim, so turn X into batch of strided
    # rowmajor mats of shape (num_subspaces x nrows x subspace_len)
    X = X.transpose(0, 1)  # strides now (middle, biggest, smallest)

    # output has strides (biggest, smallest, middle)
    # output has shape (num_subspaces, nrows, subspace_len)
    ret = flexible_gemm(X, W, bias=bias, out=out,
                        colmajor_output=True,
                        X_identity_idxs=X_identity_idxs,
                        W_identity_idxs=W_identity_idxs,
                        f_act=f_act)
    return ret.transpose(1, 0).view(M, N)  # nrows x ncols_out
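# Conceptually (an added note, not from the original gist), block_diag_addmm
# computes the same result as a dense addmm against the materialized
# block-diagonal weight:
#   bias + X @ torch.block_diag(*W.unbind(0))
# but without ever building the zero-padded (K x N) matrix.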
class FlexiLinear(nn.Module):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 add_identity: bool = False,
                 colmajor_output: bool = False,
                 # num_subspaces: int = -1):
                 num_subspaces: int = -1,
                 f_act: Optional[str] = None):
        super().__init__()
        self.add_identity = add_identity
        self.colmajor_output = colmajor_output
        self.num_subspaces = num_subspaces  # for block diag linear
        self.f_act = f_act
        if num_subspaces > 0:  # block diag matmul
            if in_features % num_subspaces != 0:
                raise ValueError(f'Subspace count {num_subspaces} does not ' +
                                 f'evenly divide in_features {in_features}')
            if out_features % num_subspaces != 0:
                raise ValueError(f'Subspace count {num_subspaces} does not ' +
                                 f'evenly divide out_features {out_features}')
            self.weight = torch.nn.Parameter(torch.empty(
                num_subspaces,
                in_features // num_subspaces,
                out_features // num_subspaces,
                dtype=dtype,
                device=device))
        else:
            # NOTE: dims are transpose of vanilla linear
            self.weight = torch.nn.Parameter(torch.empty(
                in_features, out_features, dtype=dtype, device=device))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(
                out_features, dtype=dtype, device=device))
        else:
            self.bias = None
        if add_identity:
            # these are integer indices into weight.view(-1); keep them
            # integer-typed and only move them to the target device
            idxs = _diagonal_flat_idxs(self.weight).to(device=device)
            self.register_buffer('identity_idxs', idxs)
        else:
            self.identity_idxs = None
    def forward(self, X: Tensor, accum: Optional[Tensor] = None) -> Tensor:
        if self.bias is not None and accum is not None:
            bias = accum + self.bias
        elif self.bias is not None and accum is None:
            bias = self.bias
        elif self.bias is None and accum is not None:
            bias = accum
        else:  # no bias and no accum
            bias = None
        if self.num_subspaces < 1:
            return flexible_gemm(X, self.weight, bias=bias,
                                 colmajor_output=self.colmajor_output,
                                 W_identity_idxs=self.identity_idxs,
                                 f_act=self.f_act)
        # XXX blockdiag output is (necessarily) always colmajor, and thus
        # ignores our colmajor_output arg
        return block_diag_addmm(X=X, W=self.weight, bias=bias,
                                W_identity_idxs=self.identity_idxs,
                                f_act=self.f_act)
if __name__ == '__main__':
    M, K, N = 4, 6, 8
    num_subspaces = 2
    X = torch.randn(M, K, requires_grad=True)
    W = torch.randn(num_subspaces, K // num_subspaces, N // num_subspaces,
                    requires_grad=True)
    accum = torch.randn(M, N, requires_grad=True)
    Y = block_diag_addmm(accum, X, W)
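    # Added sanity check (a hedged sketch, not part of the original gist): the
    # block-diagonal matmul should agree with a dense addmm against the
    # explicit block-diagonal weight built via torch.block_diag.
    W_dense = torch.block_diag(*W.unbind(0))  # (K, N), zero off-diagonal blocks
    Y_ref = accum + X @ W_dense
    print('block_diag_addmm matches dense reference:',
          torch.allclose(Y, Y_ref, atol=1e-5))

    # Added FlexiLinear usage sketch; weights are created uninitialized
    # (torch.empty), so fill them manually for this demo.
    layer = FlexiLinear(K, N, bias=True, num_subspaces=num_subspaces)
    with torch.no_grad():
        layer.weight.normal_()
        layer.bias.normal_()
    Y2 = layer(X)
    print('FlexiLinear output shape:', tuple(Y2.shape))  # expect (M, N) == (4, 8)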