Skip to content

Instantly share code, notes, and snippets.

@v0dro
Last active May 17, 2019 13:02
Show Gist options
  • Save v0dro/2acb30181bf0d860d6190bd7a3c68985 to your computer and use it in GitHub Desktop.
Save v0dro/2acb30181bf0d860d6190bd7a3c68985 to your computer and use it in GitHub Desktop.
Low-rank matrix truncation algorithm as described by Grasedyck.
import numpy as np
# Compact console output for the matrix dumps below: two decimals, wide lines.
np.set_printoptions(linewidth=300, precision=2)
def lr(full, rank):
    """Return a rank-``rank`` truncated SVD of ``full`` as ``[u, s, v]``.

    Args:
        full: dense (m, n) matrix to approximate.
        rank: number of leading singular triplets to keep. Values larger
            than min(m, n) keep all of them.

    Returns:
        A list ``[u, s, v]`` with ``u`` of shape (m, r), ``s`` an (r, r)
        diagonal matrix of singular values and ``v`` of shape (r, n), where
        r = min(rank, m, n). ``u @ s @ v`` approximates ``full``.
    """
    # Thin SVD: u is (m, k) and vh is (k, n) with k = min(m, n). With the
    # default full_matrices=True, u would have m columns, so u[:, :rank]
    # could outgrow s and v for non-square input when rank > k.
    u, sigma, vh = np.linalg.svd(full, full_matrices=False)
    # Slice the singular values before diagonalising instead of building
    # the full k x k diagonal matrix and cropping it.
    return [u[:, :rank], np.diag(sigma[:rank]), vh[:rank, :]]
def mergeU(u1, u2):
    """Join two column bases side by side, giving ``[u1 | u2]``.

    Both inputs need the same number of rows. The result's column count is
    the sum of theirs.
    """
    column_axis = 1
    return np.concatenate([u1, u2], axis=column_axis)
def mergeS(s1, s2):
    """Return the block-diagonal matrix ``[[s1, 0], [0, s2]]``.

    This is the middle factor matching ``mergeU(u1, u2)`` (columns side by
    side) and ``mergeV(v1, v2)`` (rows stacked), so that
    ``[u1 | u2] @ mergeS(s1, s2) @ [v1; v2] == u1 @ s1 @ v1 + u2 @ s2 @ v2``.

    Off-diagonal entries of ``s1``/``s2`` are preserved. This matters because
    ``lr_product`` returns a dense middle factor. Non-square blocks are
    supported as well. For diagonal square inputs the result is simply the
    diagonal matrix of both diagonals concatenated.

    Args:
        s1: (r1, c1) matrix.
        s2: (r2, c2) matrix.

    Returns:
        (r1 + r2, c1 + c2) array with dtype ``np.result_type(s1, s2)``.
    """
    s1 = np.asarray(s1)
    s2 = np.asarray(s2)
    r1, c1 = s1.shape
    r2, c2 = s2.shape
    merged = np.zeros((r1 + r2, c1 + c2), dtype=np.result_type(s1, s2))
    merged[:r1, :c1] = s1
    merged[r1:, c1:] = s2
    return merged
def mergeV(s1, s2):
    """Stack two row bases vertically, giving ``[s1; s2]``.

    Despite their names, the arguments are V-type factors (the driver below
    passes ``v0`` and ``vp``). The parameter names are kept for caller
    compatibility. Both inputs need the same number of columns.
    """
    row_axis = 0
    return np.concatenate([s1, s2], axis=row_axis)
def lr_product(u1, s1, v1, u2, s2, v2):
    """Multiply two low-rank matrices without forming either one densely.

    ``(u1 @ s1 @ v1) @ (u2 @ s2 @ v2) == u1 @ core @ v2`` with
    ``core = s1 @ v1 @ u2 @ s2``. The outer bases ``u1`` and ``v2`` are
    returned unchanged, and only the small inner factors are multiplied.

    Returns:
        ``[u1, core, v2]``. ``core`` is in general dense, not diagonal.
    """
    # Left-to-right accumulation: ((s1 @ v1) @ u2) @ s2.
    core = s1
    for factor in (v1, u2, s2):
        core = np.matmul(core, factor)
    return [u1, core, v2]
# Demo driver: build two low-rank (LR) approximations, multiply them, add
# the product to a zero LR matrix, and recompress the sum with Grasedyck's
# QR + SVD truncation. Uses lr, lr_product, mergeU and mergeV from this file.
r1 = 12
r2 = 12
target_rank = 16
arr1 = np.full((16, 16), 2.0)
arr2 = np.full((16, 16), 1.0)
u1, s1, v1 = lr(arr1, r1)
u2, s2, v2 = lr(arr2, r2)

# Multiply LR matrices.
up, sp, vp = lr_product(u1, s1, v1, u2, s2, v2)
product = np.matmul(np.matmul(up, sp), vp)
print("PRODUCT:")
print(product)

# Zero LR matrix, sized from the operands rather than hard-coded so the
# demo still runs if arr1/arr2/r1 change.
u0 = np.zeros(shape=(arr1.shape[0], r1))
s0 = np.zeros(shape=(r1, r1))
v0 = np.zeros(shape=(r1, arr2.shape[1]))

# Add LR matrices as A @ B.T with A = [u0 s0 | up sp] and B.T = [v0; vp],
# so A @ B.T = u0 s0 v0 + up sp vp. Each summand's middle factor must be
# folded into its left factor; passing bare u0 would drop s0 from the sum.
A = mergeU(np.matmul(u0, s0), np.matmul(up, sp))
B = np.transpose(mergeV(v0, vp))
print("A:")
print(A.shape)
print("B:")
print(B.shape)

# Truncate the addition (Grasedyck). With reduced QR factorisations
# A = Qa Ra and B = Qb Rb, A @ B.T = Qa (Ra Rb^T) Qb^T, so only the small
# core Ra Rb^T needs an SVD.
Qa, Ra = np.linalg.qr(A)
print("Qa:")
print(Qa.shape)
print(Qa)
print("Ra:")
print(Ra.shape)
print(Ra)
Qb, Rb = np.linalg.qr(B)
print("Qb:")
print(Qb.shape)
print("Rb:")
print(Rb.shape)
RaRbT = np.matmul(Ra, np.transpose(Rb))
print("RaRbT:")
print(RaRbT.shape)
# full_matrices=False keeps Utemp/Vtemp consistent with Stemp when RaRbT is
# not square.
Utemp, Stemp, Vtemp = np.linalg.svd(RaRbT, full_matrices=False)
print("Utemp:")
print(Utemp.shape)
print("Stemp:")
print(Stemp.shape)
# Keep the leading target_rank singular triplets. svd returns V^T, whose
# rows are the right singular vectors, hence the transpose for V_hat.
U_hat = Utemp[:, :target_rank]
S_hat = np.diag(Stemp[:target_rank])
V_hat = np.transpose(Vtemp[:target_rank, :])
print(U_hat)
U = np.matmul(Qa, U_hat)
V = np.matmul(Qb, V_hat)
print(U)
print(S_hat)
print(np.transpose(V))
print("END PRODUCT:")
result = np.matmul(np.matmul(U, S_hat), np.transpose(V))
print(result)
# The zero summand contributes nothing, so the truncated sum should
# reproduce the product up to the discarded singular values.
print("MAX ABS ERROR:")
print(np.max(np.abs(result - product)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment