Skip to content

Instantly share code, notes, and snippets.

@IvanaGyro
Created August 16, 2023 01:29
Show Gist options
  • Save IvanaGyro/c3a217b992dd856a4222762d3a94a557 to your computer and use it in GitHub Desktop.
Save IvanaGyro/c3a217b992dd856a4222762d3a94a557 to your computer and use it in GitHub Desktop.
Decompose a dense state tensor into a matrix product state (MPS) via successive SVDs.
from functools import reduce
import numpy as np
def mps_decompose(state, max_bond=None):
    """Decompose a dense tensor into an open-boundary matrix product state.

    The tensor is split site by site with successive SVDs, sweeping left to
    right. After each cut the singular values are absorbed into the
    right-hand remainder, so every tensor except the last is built from an
    SVD ``U`` factor (orthonormal columns).

    Args:
        state: array of rank ``n >= 1`` with shape ``(d_0, ..., d_{n-1})``.
        max_bond: if given, keep at most this many of the largest singular
            values at every cut. ``None`` (default) keeps all of them, so the
            decomposition is exact up to floating-point rounding; a finite
            value generally yields an approximation with fewer parameters.

    Returns:
        A list of ``n`` tensors: the first has shape ``(d_0, chi_1)``, the
        middle ones ``(chi_k, d_k, chi_{k+1})`` and the last
        ``(chi_{n-1}, d_{n-1})``. Chaining
        ``np.tensordot(a, b, axes=(-1, 0))`` over the list (see
        ``mps_contract``) rebuilds the tensor. For ``n == 1`` the list holds
        a copy of ``state``.

    Raises:
        ValueError: if ``state`` is 0-dimensional or ``max_bond < 1``.
    """
    state = np.asarray(state)
    if state.ndim == 0:
        raise ValueError("state must have at least one site (ndim >= 1)")
    if max_bond is not None and max_bond < 1:
        raise ValueError("max_bond must be a positive integer or None")
    if state.ndim == 1:
        return [state.copy()]

    tensors = []
    bond = 1  # bond dimension to the left of the current site
    remainder = state
    for site, dim in enumerate(state.shape[:-1]):
        # Rows = (left bond, physical index), columns = all remaining sites;
        # C-order reshape keeps the index grouping consistent across cuts.
        u, e, vh = np.linalg.svd(
            remainder.reshape(bond * dim, -1), full_matrices=False
        )
        if max_bond is not None:
            # np.linalg.svd returns singular values in descending order, so
            # slicing keeps the largest ones.
            u, e, vh = u[:, :max_bond], e[:max_bond], vh[:max_bond]
        new_bond = e.size
        tensors.append(u if site == 0 else u.reshape(bond, dim, new_bond))
        # Absorb singular values to the right; equivalent to np.diag(e) @ vh.
        remainder = e[:, None] * vh
        bond = new_bond
    # After the last cut the remainder has shape (chi_{n-1}, d_{n-1}).
    tensors.append(remainder)
    return tensors


def mps_contract(tensors):
    """Contract a list of MPS tensors (as returned by ``mps_decompose``).

    Adjacent tensors are joined by contracting the last axis of the running
    result with the first axis of the next tensor.

    Raises:
        TypeError: if ``tensors`` is empty (``functools.reduce`` needs at
            least one item).
    """
    return reduce(lambda a, b: np.tensordot(a, b, axes=(-1, 0)), tensors)


# Demo: the 3- and 4-qubit examples, decomposed exactly and rebuilt.
for n_sites in (3, 4):
    s = np.random.randn(2 ** n_sites).reshape((2,) * n_sites)
    mps = mps_decompose(s)
    compose_s = mps_contract(mps)
    print('total parameters: ', sum(a.size for a in mps))
    print(np.allclose(compose_s, s))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment