@ahwillia
Last active April 28, 2024 18:14
Efficient computation of a Kronecker-vector product (with multiple matrices).
import numpy as np
import numpy.random as npr
from functools import reduce

# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v

# ==== HELPER FUNCTIONS ==== #

def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    Parameters
    ----------
    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    -------
    matrix : ndarray, shape (dims[mode], product of all dims except dims[mode])
    """
    if mode == 0:
        return tens.reshape(dims[0], -1)
    else:
        return np.moveaxis(tens, mode, 0).reshape(dims[mode], -1)


def refold(vec, mode, dims):
    """
    Refolds vector into tensor.

    Parameters
    ----------
    vec : ndarray, tensor with len == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    -------
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    else:
        # Reshape and then move dims[mode] back to its
        # appropriate spot (undoing the `unfold` operation).
        tens = vec.reshape(
            [dims[mode]] +
            [d for m, d in enumerate(dims) if m != mode]
        )
        return np.moveaxis(tens, 0, mode)

# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #

def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full Kronecker product.
    """
    # Assumes every factor is square, so shape[0] == shape[1].
    dims = [A.shape[0] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # Apply A along axis i of the tensor, leaving the other axes untouched.
        vt = refold(A @ unfold(vt, i, dims), i, dims)
    return vt.ravel()


def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full Kronecker product).
    """
    return reduce(np.kron, As) @ v


# Quick demonstration.
if __name__ == "__main__":

    # Create random problem.
    _dims = [3, 3, 3, 3, 3, 3, 3, 3]
    As = [npr.randn(d, d) for d in _dims]
    v = npr.randn(np.prod(_dims))

    # Test accuracy.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
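Why one matrix multiply per factor is enough is easiest to see with two factors. With NumPy's default row-major (C-order) reshape, `np.kron(A, B) @ v` equals `(A @ V @ B.T).ravel()` where `V = v.reshape(A.shape[1], B.shape[1])`: A acts on axis 0 and B acts on axis 1. The snippet below is a quick numerical check of that identity, not part of the gist; the names `A`, `B`, `V` and the sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((4, 4))
v = rng.standard_normal(3 * 4)

# Brute force: form the 12 x 12 Kronecker product explicitly.
brute = np.kron(A, B) @ v

# Structured: reshape v row-major into a 3 x 4 matrix, then apply A to
# axis 0 (left multiply) and B to axis 1 (right multiply by B.T).
V = v.reshape(A.shape[1], B.shape[1])
fast = (A @ V @ B.T).ravel()

print(np.allclose(brute, fast))  # True
```

With more factors, `kron_vec_prod` repeats the same idea for every axis in turn: `unfold` brings axis i to the front, `A @ ...` applies the factor, and `refold` puts the axis back.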

ahwillia (author) commented Nov 6, 2019

Speed comparison for dims = [3, 3, 3, 3, 3, 3, 3, 3].

%timeit kron_brute_force(As, v)
946 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit kron_vec_prod(As, v)
299 µs ± 2.86 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Speed comparison for dims = [20, 20, 20].

%timeit kron_brute_force(As, v)
1.07 s ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit kron_vec_prod(As, v)
105 µs ± 845 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
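The timings above come from IPython's `%timeit`. Outside IPython, a rough equivalent can be sketched with the standard-library `timeit` module. This sketch re-declares a condensed copy of the square-matrix `kron_vec_prod` (unfold, multiply and refold inlined) so it runs on its own, and it uses a deliberately small problem, `dims = [4, 4, 4, 4]`, so the 256 x 256 brute-force matrix is cheap; timings depend on the machine and BLAS build, so no numbers are claimed here.

```python
import timeit
from functools import reduce

import numpy as np


def kron_vec_prod(As, v):
    # Condensed copy of the gist's square-matrix version: for each factor,
    # move axis i to the front, multiply, reshape, and move the axis back.
    dims = [A.shape[0] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        mat = np.moveaxis(vt, i, 0).reshape(dims[i], -1)
        rest = [d for m, d in enumerate(dims) if m != i]
        vt = np.moveaxis((A @ mat).reshape([dims[i]] + rest), 0, i)
    return vt.ravel()


def kron_brute_force(As, v):
    # Forms the full Kronecker product, then multiplies.
    return reduce(np.kron, As) @ v


rng = np.random.default_rng(0)
dims = [4, 4, 4, 4]  # small, so brute force stays cheap in time and memory
As = [rng.standard_normal((d, d)) for d in dims]
v = rng.standard_normal(int(np.prod(dims)))

for fn in (kron_brute_force, kron_vec_prod):
    # Best of 5 repeats, each timing 100 calls; report time per call.
    best = min(timeit.repeat(lambda: fn(As, v), number=100, repeat=5)) / 100
    print(f"{fn.__name__}: {best * 1e6:.1f} µs per call")
```

Taking the minimum over repeats is a common way to reduce noise from other processes; `%timeit` instead reports mean and standard deviation.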

ahwillia (author) commented Nov 6, 2019

Some back-of-the-envelope calculations and comments on computational complexity. Suppose M square matrices, each N x N, are Kronecker-multiplied. The brute-force approach constructs an N^M x N^M matrix and multiplies it by a vector with N^M entries. In total, this results in O(N^{2M}) flops.

The code above implements a faster approach, which involves M sequential matrix multiplies. Each matrix multiply is between an N x N matrix and an N x N^{M-1} matrix, which (assuming no Strassen-type algorithms for matrix multiplication) takes O(N^{M+1}) flops. Since this is repeated M times, the total computational cost is O(M * N^{M+1}). Compared with O(N^{2M}), this is smaller by a factor of roughly N^{M-1} / M, so it is often substantially faster: the exponent grows like M instead of 2M.
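The two estimates above can be evaluated directly for the sizes used in the timing comparison. This is a minimal sketch that counts multiply-adds and drops constant factors; the function names are hypothetical, and the count ignores memory traffic and the cost of building the Kronecker product inside `kron_brute_force`.

```python
def brute_force_cost(N, M):
    # One N^M x N^M matrix times a length-N^M vector: (N^M)^2 multiply-adds.
    return (N ** M) ** 2


def kron_vec_cost(N, M):
    # M products of an N x N matrix with an N x N^(M-1) matrix,
    # each costing N * N * N^(M-1) = N^(M+1) multiply-adds.
    return M * N ** (M + 1)


for N, M in [(3, 8), (20, 3)]:
    b, f = brute_force_cost(N, M), kron_vec_cost(N, M)
    print(f"N={N}, M={M}: brute {b:,}  structured {f:,}  ratio {b / f:.1f}")
```

The ratio simplifies to N^(M-1) / M, which in this idealized count is about 273 for dims = [3] * 8 and about 133 for dims = [20] * 3.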

weiT1993 commented Sep 3, 2021

Very helpful code! But does this only work for square matrices?

ahwillia (author) commented Sep 3, 2021

Yes, I think this only works for square matrices. I'm not immediately sure how to extend it to the non-square case.

renatomello commented Nov 10, 2021

What if one just wanted the kron result without the subsequent vector product?

sambroy commented Aug 14, 2023

For extending to the non-square case, all we need is a minor change to the kron_vec_prod method:

def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    dims_in = dims
    for i, A in enumerate(As):
        # change the ith entry of dims to A.shape[0]
        dims_fin = np.copy(dims_in)
        dims_fin[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims_in), i, dims_fin)
        dims_in = np.copy(dims_fin)
    return vt.ravel()
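The change is purely shape bookkeeping: the input vector is reshaped using each factor's column count (`A.shape[1]`), and after the i-th multiply, axis i has length `A.shape[0]`. The hypothetical helper `trace_dims` below (not part of the comment) records the tensor shape after each step, here for the factor shapes used in the demo, `(1, 2)`, `(2, 3)` and `(1, 4)`.

```python
def trace_dims(shapes):
    # shapes: list of (rows, cols) for each Kronecker factor.
    dims = [cols for _, cols in shapes]  # input tensor shape
    history = [tuple(dims)]
    for i, (rows, _) in enumerate(shapes):
        dims[i] = rows  # axis i now has A.shape[0] entries
        history.append(tuple(dims))
    return history


for step in trace_dims([(1, 2), (2, 3), (1, 4)]):
    print(step)
```

The shape goes from (2, 3, 4) to (1, 2, 1), so the output length is the product of the row counts, matching the number of rows of the full Kronecker product.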

The modified code is as follows:

import numpy as np
import numpy.random as npr
from functools import reduce

# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v

# ==== HELPER FUNCTIONS ==== #

def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    Parameters
    ----------
    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    -------
    matrix : ndarray, shape (dims[mode], product of all dims except dims[mode])
    """
    if mode == 0:
        return tens.reshape(dims[0], -1)
    else:
        return np.moveaxis(tens, mode, 0).reshape(dims[mode], -1)


def refold(vec, mode, dims):
    """
    Refolds vector into tensor.

    Parameters
    ----------
    vec : ndarray, tensor with len == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    -------
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    else:
        # Reshape and then move dims[mode] back to its
        # appropriate spot (undoing the `unfold` operation).
        tens = vec.reshape(
            [dims[mode]] +
            [d for m, d in enumerate(dims) if m != mode]
        )
        return np.moveaxis(tens, 0, mode)

# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #

def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    dims_in = dims
    for i, A in enumerate(As):
        # change the ith entry of dims to A.shape[0]
        dims_fin = np.copy(dims_in)
        dims_fin[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims_in), i, dims_fin)
        dims_in = np.copy(dims_fin)
    return vt.ravel()


def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).
    """
    return reduce(np.kron, As) @ v


# Quick demonstration.
if __name__ == "__main__":

    # Create random problem.
    _yaxes = [2, 3, 4]
    _xaxes = [1, 2, 1]
    # As = [np.ones((x,y)) for (x, y) in zip(_xaxes, _yaxes)]
    As = [np.random.rand(x, y) for (x, y) in zip(_xaxes, _yaxes)]

    v = np.ones((np.prod(_yaxes), ))

    # Test accuracy.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
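An alternative sketch, not from either comment, that avoids tracking `dims_in` and `dims_fin` by hand: `np.tensordot` contracts A's columns against axis i directly and returns A's rows as the leading axis, which `np.moveaxis` puts back at position i. It assumes the same row-major convention as the code above; the name `kron_vec_prod_tensordot` is hypothetical.

```python
from functools import reduce

import numpy as np


def kron_vec_prod_tensordot(As, v):
    # Reshape v using each factor's column count (row-major order).
    vt = v.reshape([A.shape[1] for A in As])
    for i, A in enumerate(As):
        # Contract A's columns with axis i of vt; the result has A's rows
        # as axis 0 followed by vt's remaining axes, so move it back to i.
        vt = np.moveaxis(np.tensordot(A, vt, axes=([1], [i])), 0, i)
    return vt.ravel()


rng = np.random.default_rng(0)
shapes = [(1, 2), (2, 3), (1, 4)]  # same factor shapes as the demo above
As = [rng.random(s) for s in shapes]
v = np.ones(2 * 3 * 4)

expected = reduce(np.kron, As) @ v
print(np.linalg.norm(kron_vec_prod_tensordot(As, v) - expected))
```

The arithmetic is the same sequence of per-axis multiplies; the difference is only that `tensordot` handles the reshaping that `unfold` and `refold` do explicitly.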
