Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Last active February 25, 2019 02:44
Show Gist options
  • Save kastnerkyle/caba86e86e51e8904e9b0787b9189197 to your computer and use it in GitHub Desktop.
Save kastnerkyle/caba86e86e51e8904e9b0787b9189197 to your computer and use it in GitHub Desktop.
# Author: Kyle Kastner
# License: BSD 3-Clause
# A port of the code from
# "Griffin-Lim Like Phase Recovery via Alternating Direction Method of Multipliers", Yoshiki Masuyama, Kohei Yatabe, Yasuhiro Oikawa
# https://ieeexplore.ieee.org/document/8552369
# https://codeocean.com/capsule/1284665/tree/v1
# into numpy
import numpy as np
import scipy.signal as sg
import scipy
from scipy.io import wavfile
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import copy
import itertools
# matlab / octave hann divides by windowsize, not windowsize - 1 as in numpy
def matlab_hann(windowsize):
    """Return a periodic Hann window of length ``windowsize``.

    Sample ``n`` equals ``0.5 * (1 - cos(2 * pi * n / windowsize))``, i.e.
    the denominator is ``windowsize`` itself rather than ``windowsize - 1``.

    Parameters
    ----------
    windowsize : int
        Number of samples in the window.

    Returns
    -------
    numpy.ndarray
        1-D float array of shape ``(windowsize,)``.
    """
    # Evaluate every sample in one vectorized expression instead of a
    # per-sample Python loop.
    n = np.arange(windowsize)
    return 0.5 * (1 - np.cos((2 * np.pi * n) / windowsize))
def make_windual(win, step):
    """Compute the canonical dual window of ``win`` for hop size ``step``.

    See http://splab.cz/wp-content/uploads/2013/11/Gabor-dual-windows-using-convex-optimization.pdf

    Parameters
    ----------
    win : numpy.ndarray
        1-D analysis window.
    step : int
        Hop size in samples.

    Returns
    -------
    numpy.ndarray
        Column vector of shape ``(len(win), 1)``.
    """
    n_samples = len(win)
    # Zero-pad up to the next multiple of ``step`` strictly greater than
    # ``n_samples // step`` hops, so the window splits into whole rows.
    n_rows = int(n_samples // step) + 1
    tail = np.zeros((n_rows * step - n_samples,))
    frames = np.concatenate([win, tail]).reshape((n_rows, step))
    # Each column holds the samples that share the same offset modulo
    # ``step``; normalise every sample by the energy of its column.
    energy = np.sum(np.abs(frames) ** 2, axis=0, keepdims=True)
    dual = frames / energy
    # Flatten back to sample order, drop the padding, return as a column.
    return dual.ravel()[:n_samples][:, None]
def gl_STFT(sig, win, skip, winLen, Ls):
    """Short-time Fourier transform of ``sig``, keeping non-negative bins.

    The signal is circularly shifted by ``winLen // 2`` samples, cut into
    frames of ``winLen`` samples starting every ``skip`` samples, windowed,
    and transformed. Each coefficient is then multiplied by a phase factor
    that depends on its bin index and frame index.

    Parameters
    ----------
    sig : numpy.ndarray
        Input signal; it is flattened before framing and is indexed up to
        sample ``Ls - 1`` at most.
    win : numpy.ndarray
        1-D analysis window of length ``winLen``.
    skip : int
        Hop size in samples.
    winLen : int
        Frame / FFT length.
    Ls : int
        Signal length used to decide how many frames are taken.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(winLen // 2 + 1, n_frames)``.
    """
    half = int(winLen // 2)
    shifted = np.roll(sig, winLen // 2).ravel()

    # Gather all frames at once: rows are in-frame offsets, columns frames.
    frame_starts = np.arange(0, Ls - winLen + 1, skip, dtype="int32")
    offsets = np.arange(winLen, dtype="int32")
    frames = shifted[offsets[:, None] + frame_starts[None]] * win[:, None]

    # ifftshift moves the frame centre to index 0 before the FFT.
    spectrum = np.fft.fft(np.fft.ifftshift(frames, axes=0), axis=0)
    spectrum = spectrum[:half + 1, :]

    # Phase factor per (bin, frame) pair.
    bins = np.arange(half + 1)[:, None]
    frame_ids = np.arange(spectrum.shape[1])[None]
    turns = np.mod(bins * frame_ids * skip, winLen) / float(winLen)
    return spectrum * np.exp(-2j * np.pi * turns)
def gl_ISTFT(C, win, skip, winLen, Ls):
    """Inverse short-time Fourier transform by windowed overlap-add.

    Removes the per-(bin, frame) phase factor from ``C``, inverts each
    frame with a real inverse FFT, applies the synthesis window, and
    overlap-adds the frames every ``skip`` samples. The result is finally
    rotated back by ``winLen // 2`` samples.

    Parameters
    ----------
    C : numpy.ndarray
        Complex array of shape ``(winLen // 2 + 1, n_frames)``.
    win : numpy.ndarray
        Synthesis window of length ``winLen``; either 1-D or a column of
        shape ``(winLen, 1)``.
    skip : int
        Hop size in samples.
    winLen : int
        Frame / FFT length (even or odd).
    Ls : int
        Signal length; frames start at ``0, skip, ...`` up to
        ``Ls - winLen``, and that count must equal ``C.shape[1]``.

    Returns
    -------
    numpy.ndarray
        1-D real signal whose length is the last frame start plus
        ``winLen``.
    """
    hWL = int(winLen // 2)
    n_frames = C.shape[1]
    lcl = np.arange(hWL + 1)[:, None] * np.arange(n_frames)[None] * skip
    C = C * np.exp(2j * np.pi * (np.mod(lcl, winLen) / float(winLen)))

    # irfft fills in the conjugate-symmetric half itself. Pass n explicitly:
    # the default output length 2 * (rows - 1) is one sample short when
    # winLen is odd.
    frames = np.fft.irfft(C, n=winLen, axis=0)
    # Accept both a 1-D window and a (winLen, 1) column.
    win_col = np.reshape(win, (-1, 1))
    frames = np.fft.fftshift(frames, axes=0) * win_col

    # Overlap-add: accumulate every frame sample at its position in the
    # output signal.
    idx = np.arange(winLen)[:, None] + np.arange(0, Ls - winLen + 1, skip)[None]
    sigr = np.bincount(idx.ravel(), weights=frames.ravel())

    # Parenthesised on purpose: -winLen // 2 would floor towards -inf and
    # rotate one sample too far for odd winLen.
    return np.roll(sigr, -(winLen // 2))
def GLA(X, A, Iter, win, windual, skip, winLen, Ls):
    """Recover a signal from the magnitudes ``A`` by alternating projections.

    Each iteration replaces the magnitudes of the current estimate with
    ``A`` (keeping its phase) and then passes the result through
    ``gl_ISTFT`` followed by ``gl_STFT``.

    Parameters
    ----------
    X : numpy.ndarray
        Initial spectrogram estimate (real or complex); not modified.
    A : numpy.ndarray
        Target magnitudes, same shape as ``X``.
    Iter : int
        Number of iterations.
    win, windual : numpy.ndarray
        Analysis window (for ``gl_STFT``) and synthesis window (for
        ``gl_ISTFT``).
    skip, winLen, Ls : int
        Hop size, frame length and signal length, forwarded unchanged to
        ``gl_STFT`` / ``gl_ISTFT``.

    Returns
    -------
    numpy.ndarray
        Output of ``gl_ISTFT`` applied to the final estimate.
    """
    def pc2(X_):
        # Magnitude projection: A times the unit phasor of X_. Bins whose
        # magnitude is exactly zero get a zero phasor rather than the NaN
        # produced by 0 / 0, which would otherwise spread to every frame.
        mag = np.abs(X_)
        unit = np.zeros(np.shape(X_), dtype=complex)
        np.divide(X_, mag, out=unit, where=mag > 0)
        return A * unit

    def pc1(X_):
        # Synthesize a signal from X_ and analyze it again.
        p1 = gl_ISTFT(X_, windual, skip, winLen, Ls)
        return gl_STFT(p1, win, skip, winLen, Ls)

    # Private complex copy so the caller's array is never touched.
    X = np.array(X, dtype=complex)
    for _ in range(Iter):
        X = pc1(pc2(X))
    return gl_ISTFT(X, windual, skip, winLen, Ls)
def FGLA(X, A, Iter, alpha, win, windual, skip, winLen, Ls):
    """Alternating-projection phase recovery with an extrapolation step.

    Same two projections as ``GLA``, but each iteration projects the
    extrapolated point ``Y = X + alpha * (X - X_previous)`` instead of
    ``X`` itself. With ``alpha = 0`` the update reduces to the plain
    ``GLA`` iteration.

    Parameters
    ----------
    X : numpy.ndarray
        Initial spectrogram estimate (real or complex); not modified.
    A : numpy.ndarray
        Target magnitudes, same shape as ``X``.
    Iter : int
        Number of iterations.
    alpha : float
        Extrapolation weight.
    win, windual : numpy.ndarray
        Analysis window (for ``gl_STFT``) and synthesis window (for
        ``gl_ISTFT``).
    skip, winLen, Ls : int
        Hop size, frame length and signal length, forwarded unchanged to
        ``gl_STFT`` / ``gl_ISTFT``.

    Returns
    -------
    numpy.ndarray
        Output of ``gl_ISTFT`` applied to the final estimate.
    """
    def pc2(X_):
        # Magnitude projection: A times the unit phasor of X_. Bins whose
        # magnitude is exactly zero get a zero phasor rather than the NaN
        # produced by 0 / 0, which would otherwise spread to every frame.
        mag = np.abs(X_)
        unit = np.zeros(np.shape(X_), dtype=complex)
        np.divide(X_, mag, out=unit, where=mag > 0)
        return A * unit

    def pc1(X_):
        # Synthesize a signal from X_ and analyze it again.
        p1 = gl_ISTFT(X_, windual, skip, winLen, Ls)
        return gl_STFT(p1, win, skip, winLen, Ls)

    # Private complex copy so the caller's array is never touched.
    X = np.array(X, dtype=complex)
    Y = X
    for _ in range(Iter):
        # X is rebound (never mutated in place), so keeping a reference to
        # the previous estimate is enough.
        Xold = X
        X = pc1(pc2(Y))
        Y = X + alpha * (X - Xold)
    return gl_ISTFT(X, windual, skip, winLen, Ls)
def ADMMGLA(X, A, Iter, rho, win, windual, skip, winLen, Ls):
    """Phase recovery using the ADMM-style update with variables X, Z, U.

    Every iteration performs, in order::

        X = pc2(Z - U)
        Y = X + U
        Z = (rho * Y + pc1(Y)) / (1 + rho)
        U = U + X - Z

    where ``pc2`` imposes the magnitudes ``A`` and ``pc1`` is ``gl_ISTFT``
    followed by ``gl_STFT``.

    Parameters
    ----------
    X : numpy.ndarray
        Initial spectrogram estimate (real or complex); not modified.
    A : numpy.ndarray
        Target magnitudes, same shape as ``X``.
    Iter : int
        Number of iterations.
    rho : float
        Weight of ``Y`` in the ``Z`` update.
    win, windual : numpy.ndarray
        Analysis window (for ``gl_STFT``) and synthesis window (for
        ``gl_ISTFT``).
    skip, winLen, Ls : int
        Hop size, frame length and signal length, forwarded unchanged to
        ``gl_STFT`` / ``gl_ISTFT``.

    Returns
    -------
    numpy.ndarray
        Output of ``gl_ISTFT`` applied to the final ``X``.
    """
    def pc2(X_):
        # Magnitude projection: A times the unit phasor of X_. Bins whose
        # magnitude is exactly zero get a zero phasor rather than the NaN
        # produced by 0 / 0, which would otherwise spread to every frame.
        mag = np.abs(X_)
        unit = np.zeros(np.shape(X_), dtype=complex)
        np.divide(X_, mag, out=unit, where=mag > 0)
        return A * unit

    def pc1(X_):
        # Synthesize a signal from X_ and analyze it again.
        p1 = gl_ISTFT(X_, windual, skip, winLen, Ls)
        return gl_STFT(p1, win, skip, winLen, Ls)

    # Private complex copy so the caller's array is never touched.
    X = np.array(X, dtype=complex)
    Z = X
    U = np.zeros_like(X)
    for _ in range(Iter):
        X = pc2(Z - U)
        Y = X + U
        Z = (rho * Y + pc1(Y)) / (1. + rho)
        U = U + X - Z
    return gl_ISTFT(X, windual, skip, winLen, Ls)
def normalize(x):
    """Scale ``x`` so that its largest absolute value becomes 1.

    Parameters
    ----------
    x : numpy.ndarray
        Non-empty input array.

    Returns
    -------
    numpy.ndarray
        ``x`` divided by ``max(abs(x))``. An all-zero input is returned as
        zeros instead of NaN.
    """
    peak = np.max(np.abs(x))
    if peak == 0:
        # Nothing to scale; dividing by zero would turn silence into NaN.
        peak = 1.0
    return x / peak
# matlab / octave soundsc with gain control
def soundsc(X, gain=0.9, copy=True):
    """Rescale ``X`` to span the int16 range (times ``gain``) and cast.

    The input is shifted by the midpoint of its minimum and maximum and
    scaled so that the extremes map to ``-1`` and ``+1``, then multiplied
    by ``gain`` and ``2 ** 15``.

    Parameters
    ----------
    X : array_like
        Non-empty input samples; never modified.
    gain : float
        Fraction of full scale to use.
    copy : bool
        Unused; kept so existing callers keep working.

    Returns
    -------
    numpy.ndarray
        ``int16`` array with the shape of ``X``. A constant input yields
        all zeros.
    """
    X = np.asarray(X)
    X_min = float(X.min())
    X_max = float(X.max())
    span = X_max - X_min
    if span == 0:
        # Constant input: there is no range to stretch, and 2 / span would
        # raise ZeroDivisionError.
        return np.zeros(X.shape, dtype="int16")
    # Midpoint of the range (named "median" historically).
    X_median = (X_max + X_min) / 2
    X_s = (X - X_median) * (2 / span)
    X_s = X_s * gain
    # Clip before casting: +1.0 * 2 ** 15 is one past the largest int16
    # and would wrap around to a negative sample.
    X_o = np.clip(X_s * 2 ** 15, -2 ** 15, 2 ** 15 - 1)
    return X_o.astype("int16")
def wavsc(X, gain=0.9, copy=True):
    """Clip ``X`` to ``[-1, 1]``, apply ``gain`` and convert to int16.

    Parameters
    ----------
    X : array_like
        Input samples, nominally in ``[-1, 1]``; never modified.
    gain : float
        Fraction of full scale to use.
    copy : bool
        Unused; kept so existing callers keep working.

    Returns
    -------
    numpy.ndarray
        ``int16`` array with the shape of ``X``.
    """
    scaled = np.clip(X, -1.0, 1.0) * gain * 2 ** 15
    # Clip before casting: +1.0 * 2 ** 15 is one past the largest int16
    # and would wrap around to a negative sample.
    return np.clip(scaled, -2 ** 15, 2 ** 15 - 1).astype("int16")
# -- Hyperparameters ---------------------------------------------
# If alpha = 0.0, FGLA is equal to GLA.
# If rho = 0.0, ADMMGLA is the proposed Alg. 1.
# If rho > 0.0, ADMMGLA becomes the proposed Alg. 2, and it coincides with
# GLA when rho = 1.0.
n_iter = 10
alpha = 0.99  # hyperparameter for FGLA (0.00 ~ 1.00)
rho = 0.00  # hyperparameter for ADMMGLA (0.00 ~ 1.00)
windowsize = 512
step = 216

# -- Input ---------------------------------------------------------
fs, target = wavfile.read("target.wav")
# Scale by 2 ** 15 (assumes 16-bit integer samples -- confirm for other files).
target = target / float(2 ** 15)

win = matlab_hann(windowsize)
windual = make_windual(win, step)

# -- Padding -------------------------------------------------------
# Choose Ls so that (Ls - windowsize) is a whole number of hops, which is
# what gl_STFT / gl_ISTFT need to cover the padded signal with full frames.
overlap = windowsize - step
n_hops = int((len(target) + 2 * overlap - windowsize) // step) + 1
Ls = n_hops * step + windowsize
# `overlap` zeros on both sides, plus rpad1 to fill up to exactly Ls samples.
lpad = np.zeros((overlap,))
rpad1 = np.zeros((Ls - len(target) - 2 * overlap,))
rpad2 = np.zeros((overlap,))
target = np.concatenate([lpad, target, rpad1, rpad2])

# -- Phase recovery from the magnitude spectrogram -----------------
C = gl_STFT(target, win, step, windowsize, Ls)
A = np.abs(C)
X0 = A  # start every method from the magnitudes themselves
sig_gla = GLA(X0, A, n_iter, win, windual, step, windowsize, Ls)
sig_fgla = FGLA(X0, A, n_iter, alpha, win, windual, step, windowsize, Ls)
sig_admmgla = ADMMGLA(X0, A, n_iter, rho, win, windual, step, windowsize, Ls)

# -- Output --------------------------------------------------------
wavfile.write("output_gla.wav", fs, wavsc(normalize(sig_gla)))
wavfile.write("output_fgla.wav", fs, wavsc(normalize(sig_fgla)))
wavfile.write("output_admmgla.wav", fs, wavsc(normalize(sig_admmgla)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment