@yulkang
Last active February 7, 2024 13:49
Kronecker Product in PyTorch - with batch dimensions broadcast
"""A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk"""
import torch


def kron(a, b):
    """
    Kronecker product of matrices a and b with leading batch dimensions.
    Batch dimensions are broadcast.
    For a of shape (..., m, n) and b of shape (..., p, q),
    the result has shape (..., m * p, n * q).
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    # Size of the last two dimensions of the result: (m * p, n * q)
    siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
    # Broadcast product with trailing dimensions (m, p, n, q)
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    # Broadcast batch dimensions
    siz0 = res.shape[:-4]
    return res.reshape(siz0 + siz1)
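
To see why the final reshape in `kron` yields a Kronecker product, it can help to write out the index mapping by hand. The sketch below is a plain-Python illustration, not part of pylabyk: `kron_reference` is a hypothetical helper that works on nested lists without batch dimensions. It walks the indices in the same (m, p, n, q) order as the broadcast product, so entry `a[i][j] * b[k][l]` lands at row `i * p + k` and column `j * q + l`.

```python
def kron_reference(a, b):
    """Loop-based Kronecker product of two matrices given as nested lists.

    Hypothetical helper for illustration only; no batch dimensions.
    """
    m, n = len(a), len(a[0])
    p, q = len(b), len(b[0])
    out = [[0] * (n * q) for _ in range(m * p)]
    for i in range(m):
        for k in range(p):
            for j in range(n):
                for l in range(q):
                    # Same ordering as reshaping (m, p, n, q) to (m * p, n * q).
                    out[i * p + k][j * q + l] = a[i][j] * b[k][l]
    return out


a = [[1, 2], [3, 4]]
b = [[0, 1], [1, 0]]
result = kron_reference(a, b)
for row in result:
    print(row)
```

Each 2x2 block of `result` is `b` scaled by one entry of `a`, which is the defining property of the Kronecker product.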

yulkang commented Jun 17, 2019

In response to: https://discuss.pytorch.org/t/kronecker-product/3919/11
Please let me know if you find any issues!

yulkang commented Jun 17, 2019

Note: this function is included in numpytorch.py, which is part of the pylabyk library that I maintain.

@RaulPPelaez

As written, this function is not compatible with torch.compile, but a simple modification fixes it:

    siz1 = torch.Size([a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]])
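
For reference, here is a sketch of the whole function with that one line swapped in. The name `kron_compile_friendly` is made up for this example, and the rest of the body is copied from the gist. The check below only compares the eager result against `torch.kron` applied to each batch element; it does not call `torch.compile` itself, so compatibility there rests on the comment above rather than on this snippet.

```python
import torch


def kron_compile_friendly(a, b):
    """Batched Kronecker product; hypothetical name, gist body with the suggested siz1 line."""
    # Build the trailing size from plain shape entries instead of going through torch.tensor.
    siz1 = torch.Size([a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]])
    # Broadcast product with trailing dimensions (m, p, n, q).
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    siz0 = res.shape[:-4]
    return res.reshape(siz0 + siz1)


a = torch.arange(12.0).reshape(2, 2, 3)  # batch of 2 matrices, each 2x3
b = torch.arange(8.0).reshape(1, 4, 2)   # batch of 1, broadcast against a
out = kron_compile_friendly(a, b)

# Reference: torch.kron on each 2-D pair, stacked along the batch dimension.
expected = torch.stack([torch.kron(a[i], b[0]) for i in range(a.shape[0])])
print(out.shape, torch.equal(out, expected))
```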
