@yulkang
Last active September 18, 2022 12:19
Block diagonal matrix in PyTorch - vectorized
"""A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk"""
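import torch


def block_diag(m):
    """
    Make a block diagonal matrix along dim=-3.

    EXAMPLE:
    block_diag(torch.ones(4,3,2))
    should give a 12 x 8 matrix with blocks of 3 x 2 ones.
    Prepend batch dimensions if needed.
    You can also give a list of matrices (all of the same shape).

    :type m: torch.Tensor, list
    :rtype: torch.Tensor
    """
    if isinstance(m, (list, tuple)):
        # Stack the matrices along a new dim=-3; they must share one shape.
        m = torch.cat([m1.unsqueeze(-3) for m1 in m], -3)

    d = m.dim()
    n = m.shape[-3]      # number of blocks
    siz0 = m.shape[:-3]  # batch dimensions, if any
    siz1 = m.shape[-2:]  # (rows, columns) of each block

    # [..., n, rows, 1, cols] * [..., n, 1, n, 1] -> [..., n, rows, n, cols]
    m2 = m.unsqueeze(-2)
    # Build the identity with the dtype and device of m so that the product
    # also works for non-float32 and GPU tensors.
    eye = attach_dim(
        torch.eye(n, dtype=m.dtype, device=m.device).unsqueeze(-2), d - 3, 1)
    return (m2 * eye).reshape(
        siz0 + torch.Size([siz1[0] * n, siz1[1] * n])
    )


def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """Reshape v with singleton dimensions prepended and/or appended."""
    return v.reshape(
        torch.Size([1] * n_dim_to_prepend)
        + v.shape
        + torch.Size([1] * n_dim_to_append))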
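
The broadcasting trick in block_diag can be easier to follow with the intermediate shapes written out. The sketch below is not part of the gist: it repeats the same steps inline for the unbatched 4 x 3 x 2 case and compares the result with the built-in torch.block_diag, assuming a PyTorch version recent enough to provide that function.

```python
import torch

# Four 3 x 2 blocks with distinct values, so misplaced entries would show up.
m = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2)
n, rows, cols = m.shape

m2 = m.unsqueeze(-2)                     # [4, 3, 1, 2]
eye = torch.eye(n).reshape(n, 1, n, 1)   # [4, 1, 4, 1]

# prod[i, a, j, b] = m[i, a, b] if i == j else 0
prod = m2 * eye                          # [4, 3, 4, 2]

# Merging (i, a) into rows and (j, b) into columns puts block i on the diagonal.
out = prod.reshape(n * rows, n * cols)   # [12, 8]

# Cross-check against the built-in, which takes the blocks as separate arguments.
reference = torch.block_diag(*m)
matches = torch.equal(out, reference)
print(out.shape, matches)
```

The built-in accepts blocks of different shapes but no batch dimensions, whereas the approach in this gist requires equally sized blocks and carries leading batch dimensions through.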

yulkang commented Jun 18, 2019

For an up-to-date version, check numpytorch.py in my pylabyk library: https://github.com/yulkang/pylabyk


yulkang commented Jun 18, 2019

Example:

>>> block_diag(torch.ones(4,3,2))
tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])
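
The docstring also mentions batch dimensions and list input, which the example above does not exercise. Here is a minimal sketch of both cases. To keep the snippet runnable on its own it uses a compact stand-in for block_diag, with torch.stack in place of the unsqueeze-and-cat step; this is an illustrative rewrite, not the gist's exact code.

```python
import torch


def block_diag(m):
    # Compact stand-in for the gist's function (illustrative, same idea).
    if isinstance(m, (list, tuple)):
        m = torch.stack(list(m), -3)
    n, rows, cols = m.shape[-3:]
    eye = torch.eye(n, dtype=m.dtype, device=m.device).reshape(n, 1, n, 1)
    return (m.unsqueeze(-2) * eye).reshape(*m.shape[:-3], n * rows, n * cols)


# Leading dimensions are treated as batch dimensions:
# 5 batches, each with 4 blocks of 3 x 2, give 5 matrices of 12 x 8.
batched = block_diag(torch.ones(5, 4, 3, 2))
print(batched.shape)

# A list of equally shaped matrices is stacked first:
# two 3 x 2 blocks give a 6 x 4 matrix.
from_list = block_diag([torch.ones(3, 2), 2 * torch.ones(3, 2)])
print(from_list)
```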
