# GitHub gist by @kumanna: kumanna/dbf490004c0ec22e3b2c0aa12f46e423
# Created March 6, 2024 06:47.
import numpy as np
import scipy.linalg as la
import numpy.linalg as LA
import matplotlib.pyplot as plt
def gradient_descent(V1, V2, subopt=False, mu=1, epsilon=1e-11, permute=True,
                     max_iter=None):
    """Find a short geodesic between two square matrices on the flag manifold.

    The relative rotation ``V0 = V1^H V2`` is (optionally) column-permuted by
    ``P`` and column-phase-aligned by the diagonal unitary ``D0`` so that its
    diagonal is real. Gradient descent then searches over a diagonal phase
    matrix ``D = diag(exp(1j * phi))`` to drive the diagonal of
    ``B = logm(V0 P D0 D)`` towards zero.

    Parameters
    ----------
    V1, V2 : array_like
        Matrices whose product ``V1^H V2`` is square (``logm`` requires it).
    subopt : bool, default False
        If True, skip the descent and return the logarithm of the
        phase-aligned ``V0 P D0`` directly.
    mu : float, default 1
        Step size of the phase update.
    epsilon : float, default 1e-11
        Stop once ``norm(diag(B)) <= epsilon``.
    permute : bool, default True
        If True, right-multiply ``V0`` by ``find_permutation_matrix(...)``.
        NOTE(review): that helper is not defined in this file; it must be
        supplied by the surrounding environment, otherwise ``permute=True``
        raises NameError.
    max_iter : int or None, default None
        Upper bound on the number of descent steps. ``None`` keeps the
        original unbounded loop. When the bound is reached a RuntimeWarning
        is issued and the current iterate (B and D are always consistent)
        is returned.

    Returns
    -------
    B : numpy.matrix
        Matrix logarithm at the final iterate.
    Q : numpy.matrix
        Right factor with ``V1^H V2 == expm(B) * Q`` (assuming ``P`` is a
        permutation, i.e. unitary); hence ``V2 == V1 * expm(B) * Q``
        whenever ``V1`` is unitary.

    Raises
    ------
    ValueError
        If ``V1^H V2`` is not square.
    """
    import warnings

    V1 = np.matrix(V1)
    V2 = np.matrix(V2)
    V0 = V1.H * V2
    n_rows, n_cols = V0.shape
    if n_rows != n_cols:
        raise ValueError(
            f"V1^H V2 must be square for logm, got shape {V0.shape}")
    n = n_cols

    # Permutation-invariance step: reorder the columns of V0.
    if permute:
        P = np.matrix(find_permutation_matrix(np.eye(n), V0, ret_dist=False))
        V0 = V0 * P
    else:
        P = np.matrix(np.eye(n))

    # Rotate each column by a unit-modulus phase so diag(V0 * D0) is real.
    D0 = np.matrix(np.diag(np.exp(-1j * np.angle(np.diag(np.array(V0))))))
    V = V0 * D0

    # Initial tangent vector: logm(V0 P D0).
    B = np.matrix(la.logm(V))
    if subopt:
        # P.H belongs here as well: V1^H V2 = expm(B) * D0^H * P^H.
        return B, D0.H * P.H

    # Cost function: the norm of diag(B), which the descent drives to zero.
    metric = la.norm(np.diag(B))
    phi = np.zeros(n)
    D = np.matrix(np.eye(n))
    steps = 0
    while metric > epsilon:
        if max_iter is not None and steps >= max_iter:
            warnings.warn(
                f"gradient_descent: reached max_iter={max_iter} with "
                f"norm(diag(B))={metric:.3e} > epsilon={epsilon}",
                RuntimeWarning,
                stacklevel=2,
            )
            break
        # Only the imaginary part of diag(B) is updated through the phases.
        phi = phi - mu * np.imag(np.diag(B))
        D = np.matrix(np.diag(np.exp(1j * phi)))
        B = np.matrix(la.logm(V * D))
        metric = la.norm(np.diag(B))
        steps += 1
    return B, D.H * D0.H * P.H
if __name__ == "__main__":
    # Demo: geodesic from the identity to the unitary factor Vh returned by
    # the SVD of a random complex Gaussian matrix. Guarded so importing this
    # module does not run a random trial as a side effect.
    N = 3
    V1 = np.eye(N)
    H = (np.random.randn(N, N) + np.random.randn(N, N) * 1j)
    _, _, V2 = np.linalg.svd(H)
    B, Q = gradient_descent(V1, V2, subopt=False, mu=0.01, epsilon=1e-7,
                            permute=False)
    # gradient_descent returns Q with V1^H V2 = expm(B) * Q; V1 is the
    # identity here, so expm(B) * Q should reproduce V2.
    err = LA.norm(V1 @ la.expm(np.asarray(B)) @ Q - V2)
    print("B =\n", B)
    print("reconstruction error:", err)
# (GitHub gist page footer: sign in to comment on this gist.)