Last active
February 25, 2019 02:44
-
-
Save kastnerkyle/caba86e86e51e8904e9b0787b9189197 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Kyle Kastner
# License: BSD 3-Clause
# A NumPy port of the code from
# "Griffin-Lim Like Phase Recovery via Alternating Direction Method of Multipliers", Yoshiki Masuyama, Kohei Yatabe, Yasuhiro Oikawa
# https://ieeexplore.ieee.org/document/8552369
# https://codeocean.com/capsule/1284665/tree/v1
import numpy as np | |
import scipy.signal as sg | |
import scipy | |
from scipy.io import wavfile | |
import matplotlib | |
matplotlib.use("TkAgg") | |
import matplotlib.pyplot as plt | |
import copy | |
import itertools | |
# Periodic Hann window (MATLAB / Octave ``hann(N, 'periodic')``): the cosine
# argument is divided by ``windowsize``, whereas numpy's ``np.hanning`` divides
# by ``windowsize - 1``.
def matlab_hann(windowsize):
    """Return a periodic Hann window of ``windowsize`` float64 samples.

    ``w[n] = 0.5 * (1 - cos(2 * pi * n / windowsize))`` for
    ``n = 0 .. windowsize - 1``. ``windowsize == 0`` gives an empty array.
    """
    n = np.arange(windowsize)
    # vectorized: one ufunc call instead of a Python loop over numpy scalars
    return 0.5 * (1 - np.cos(2 * np.pi * n / windowsize))
def make_windual(win, step):
    """Return the canonical dual of window ``win`` for hop size ``step``.

    Each window sample is divided by the total energy of every sample that
    shares its position modulo ``step``, i.e. the samples it overlaps when the
    window is shifted by multiples of ``step``.
    Reference:
    http://splab.cz/wp-content/uploads/2013/11/Gabor-dual-windows-using-convex-optimization.pdf

    Returns a ``(len(win), 1)`` column array.
    """
    n_win = len(win)
    n_rows = int(n_win // step) + 1
    # zero-pad up to a whole number of hops so the window folds into rows
    padded = np.concatenate([win, np.zeros(n_rows * step - n_win)])
    rows = padded.reshape(n_rows, step)
    # energy for each position modulo ``step``, summed over all hop shifts
    energy = np.sum(np.abs(rows) ** 2, axis=0)
    dual = (rows / energy).ravel()
    return dual[:n_win].reshape(-1, 1)
def gl_STFT(sig, win, skip, winLen, Ls):
    """Forward STFT used by the Griffin-Lim variants in this file.

    Steps:
    1. Circularly shift ``sig`` by ``winLen // 2``.
    2. Cut it into ``winLen``-sample frames starting every ``skip`` samples
       (starts ``0, skip, ...`` up to ``Ls - winLen``).
    3. Multiply each frame by ``win``.
    4. FFT after ``ifftshift`` and keep the ``winLen // 2 + 1``
       non-negative-frequency bins.
    5. Multiply bin ``k`` of frame ``m`` by
       ``exp(-2j*pi*((k*m*skip) mod winLen)/winLen)`` (presumably the
       time-invariant phase convention of the MATLAB original -- confirm).

    Parameters
    ----------
    sig : array, expected length ``Ls`` (indexed up to ``Ls - 1``).
    win : analysis window of length ``winLen``, either 1-D or a
        ``(winLen, 1)`` column (the shape ``make_windual`` returns).
    skip : hop size in samples.
    winLen : frame length and FFT size.
    Ls : signal length that determines how many frames are taken.

    Returns
    -------
    Complex array of shape ``(winLen // 2 + 1, n_frames)``.
    """
    # accept 1-D or column windows; win[:, None] on a column would produce a
    # 3-D broadcast instead of a per-sample weight
    win = np.asarray(win).reshape(-1, 1)
    sigLC = np.roll(sig, winLen // 2)
    starts = np.arange(0, Ls - winLen + 1, skip)
    # default (platform) int: int32 indices could overflow on long signals
    idx = np.arange(winLen)[:, None] + starts[None]
    block = sigLC.ravel()[idx] * win
    C = np.fft.fft(np.fft.ifftshift(block, axes=0), axis=0)
    hWL = int(winLen // 2)
    lcl = np.arange(hWL + 1)[:, None] * np.arange(C.shape[1])[None]
    phase = np.exp(-2j * np.pi * (np.mod(lcl * skip, winLen) / float(winLen)))
    return C[:hWL + 1, :] * phase
def gl_ISTFT(C, win, skip, winLen, Ls):
    """Inverse STFT that mirrors ``gl_STFT`` step by step.

    Steps:
    1. Remove the per-frame phase term (``exp(+...)``).
    2. Take the Hermitian inverse FFT of each column and ``fftshift`` it.
    3. Multiply by the synthesis window ``win``.
    4. Overlap-add the frames at starts ``0, skip, ...`` up to ``Ls - winLen``.
    5. Undo ``gl_STFT``'s circular shift of ``winLen // 2``.

    In this file the synthesis window is the ``make_windual`` dual of the
    analysis window, which is presumably what makes this an exact inverse
    (TODO confirm for other windows).

    Parameters
    ----------
    C : complex array of shape ``(winLen // 2 + 1, n_frames)``.
    win : synthesis window of length ``winLen``, 1-D or a ``(winLen, 1)`` column.
    skip : hop size in samples.
    winLen : frame length and FFT size (odd or even).
    Ls : output signal length.

    Returns
    -------
    Real float array of length ``Ls``.
    """
    win = np.asarray(win).reshape(-1, 1)
    hWL = int(winLen // 2)
    lcl = np.arange(hWL + 1)[:, None] * np.arange(C.shape[1])[None] * skip
    C = C * np.exp(2j * np.pi * (np.mod(lcl, winLen) / float(winLen)))
    # n=winLen is required for odd winLen: irfft's default length 2 * hWL
    # would be one sample short of the window
    frames = np.fft.fftshift(np.fft.irfft(C, n=winLen, axis=0), axes=0) * win
    starts = np.arange(0, Ls - winLen + 1, skip)
    idx = np.arange(winLen)[:, None] + starts[None]
    # overlap-add: accumulate every windowed sample at its absolute position;
    # minlength keeps the output at Ls even when the frames stop short of it
    sigr = np.bincount(idx.ravel(), weights=frames.ravel(), minlength=int(Ls))
    # exact inverse of gl_STFT's roll by +(winLen // 2); the former
    # ``-winLen // 2`` equals -(winLen // 2) only for even winLen
    return np.roll(sigr, -(winLen // 2))
def _project_magnitude(X, A):
    """Give X the magnitude A while keeping its phase: ``A * X / |X|``.

    Entries where ``X == 0`` have no phase and become 0, like ``np.sign(0)``.
    The plain ``X / np.abs(X)`` would produce NaN there (0/0), and that NaN
    then spreads through every following ISTFT/STFT pass. For real X this
    equals ``A * np.sign(X)``, so it also covers the real-valued first iterate.
    """
    X = np.asarray(X)
    mag = np.abs(X)
    unit = np.zeros(X.shape, dtype=np.result_type(X.dtype, np.float64))
    np.divide(X, mag, out=unit, where=mag > 0)
    return A * unit


def _project_consistent(X, win, windual, skip, winLen, Ls):
    """Return ``gl_STFT(gl_ISTFT(X))``, the spectrogram of X's resynthesis."""
    sig = gl_ISTFT(X, windual, skip, winLen, Ls)
    return gl_STFT(sig, win, skip, winLen, Ls)


def GLA(X, A, Iter, win, windual, skip, winLen, Ls):
    """Griffin-Lim: alternate the magnitude and consistency steps ``Iter`` times.

    X : initial spectrogram estimate. In this file it is the magnitude A
        itself, i.e. zero phase.
    A : target magnitude, same shape as X.
    win / windual : analysis / synthesis windows for gl_STFT / gl_ISTFT.
    Returns the time-domain signal ``gl_ISTFT`` of the final estimate.
    """
    for _ in range(Iter):
        X = _project_consistent(_project_magnitude(X, A),
                                win, windual, skip, winLen, Ls)
    return gl_ISTFT(X, windual, skip, winLen, Ls)


def FGLA(X, A, Iter, alpha, win, windual, skip, winLen, Ls):
    """Fast Griffin-Lim: GLA with an ``alpha``-weighted momentum step.

    Update: ``X_n = P1(P2(Y_{n-1}))`` and ``Y_n = X_n + alpha * (X_n - X_{n-1})``.
    Here P2 is the magnitude step and P1 the consistency step; ``alpha = 0``
    reduces to GLA. Other arguments are as in ``GLA``. Returns the
    time-domain signal of the final X.
    """
    Y = X
    for _ in range(Iter):
        X_prev = X
        X = _project_consistent(_project_magnitude(Y, A),
                                win, windual, skip, winLen, Ls)
        Y = X + alpha * (X - X_prev)
    return gl_ISTFT(X, windual, skip, winLen, Ls)


def ADMMGLA(X, A, Iter, rho, win, windual, skip, winLen, Ls):
    """ADMM-based Griffin-Lim-like phase recovery.

    ``rho = 0`` gives the paper's Alg. 1. ``rho > 0`` gives Alg. 2, which
    coincides with GLA at ``rho = 1``. X is the initial spectrogram estimate;
    other arguments are as in ``GLA``. Returns the time-domain signal of the
    final magnitude-projected estimate X.
    """
    Z = np.asarray(X)
    U = np.zeros(Z.shape)  # scaled dual variable, starts at zero
    for _ in range(Iter):
        X = _project_magnitude(Z - U, A)
        Y = X + U
        Z = (rho * Y + _project_consistent(Y, win, windual, skip, winLen, Ls)) / (1. + rho)
        U = U + X - Z
    return gl_ISTFT(X, windual, skip, winLen, Ls)
def normalize(x):
    """Scale ``x`` so its largest absolute value is 1; returns a float array.

    All-zero input is returned unscaled instead of becoming NaN (0/0), and
    empty input returns an empty array instead of raising.
    """
    x = np.asarray(x)
    # initial=0.0 makes the reduction defined for empty input
    peak = np.max(np.abs(x), initial=0.0)
    return x / (peak if peak != 0 else 1.0)
# MATLAB / Octave ``soundsc``-style scaling to int16, with a gain control.
def soundsc(X, gain=0.9, copy=True):
    """Map the range of ``X`` onto int16 samples.

    ``[X.min(), X.max()]`` is mapped linearly onto ``[-1, 1]``, multiplied by
    ``gain``, scaled by ``2 ** 15`` and truncated toward zero to int16.
    Results outside the int16 range (e.g. the maximum when ``gain >= 1``) are
    clipped instead of wrapping around to the opposite sign. A constant
    signal has no range to map and yields all-zero samples.

    ``copy`` is unused and kept for call compatibility; ``X`` is never
    modified in place.
    """
    X = np.asarray(X)
    lo = float(X.min())
    hi = float(X.max())
    span = hi - lo
    if span == 0:
        return np.zeros(X.shape, dtype="int16")
    midpoint = (hi + lo) / 2
    unit = (X - midpoint) * (2 / span)
    scaled = unit * gain * 2 ** 15
    return np.clip(scaled, -2 ** 15, 2 ** 15 - 1).astype("int16")
def wavsc(X, gain=0.9, copy=True):
    """Convert a float signal to int16 samples.

    ``X`` is clipped to ``[-1, 1]``, multiplied by ``gain`` and ``2 ** 15``,
    then truncated toward zero to int16. The product is clipped to the int16
    range first, so ``gain >= 1`` saturates at 32767 instead of wrapping to
    -32768.

    ``copy`` is unused and kept for call compatibility; ``X`` is never
    modified in place.
    """
    clipped = np.clip(np.asarray(X), -1.0, 1.0)
    scaled = clipped * gain * 2 ** 15
    return np.clip(scaled, -2 ** 15, 2 ** 15 - 1).astype("int16")
n_iter = 10
# -- Hyperparameters ---------------------------------------------
# If alpha = 0.0, FGLA is equal to GLA
# If rho = 0.0, ADMMGLA is the proposed Alg. 1.
# If rho > 0.0, ADMMGLA becomes the proposed Alg. 2, and it coincides with
# GLA when rho = 1.0.
alpha = 0.99  # hyperparameter for FGLA (0.00 ~ 1.00)
rho = 0.00  # hyperparameter for ADMMGLA (0.00 ~ 1.00)
windowsize = 512
step = 216
fs, target = wavfile.read("target.wav")
if target.ndim > 1:
    # multi-channel file: keep the first channel, the STFT code expects 1-D
    target = target[:, 0]
if np.issubdtype(target.dtype, np.integer):
    # map integer PCM onto [-1, 1): int16 is divided by 2 ** 15 (as before);
    # unsigned formats such as 8-bit PCM are re-centred around zero first
    info = np.iinfo(target.dtype)
    offset = (int(info.max) + int(info.min) + 1) // 2
    target = (target.astype(np.float64) - offset) / float(int(info.max) - offset + 1)
else:
    # float data is used as-is (assumed to already be in [-1, 1])
    target = target.astype(np.float64)
win = matlab_hann(windowsize)
windual = make_windual(win, step)
# Padded length: (Ls - windowsize) is a whole number of hops, so the frames
# taken by gl_STFT / gl_ISTFT tile [0, Ls) exactly, and there is room for at
# least (windowsize - step) zeros on each side of the signal.
Ls = (int((len(target) + 2 * (windowsize - step) - windowsize) // step) + 1) * step + windowsize
# (windowsize - step) zeros in front; all remaining padding up to Ls goes after
lpad = np.zeros((windowsize - step,))
rpad = np.zeros((Ls - len(target) - (windowsize - step),))
target = np.concatenate([lpad, target, rpad])
C = gl_STFT(target, win, step, windowsize, Ls)
A = np.abs(C)
# every method starts from the true magnitude with zero phase
X0 = A
sig_gla = GLA(X0, A, n_iter, win, windual, step, windowsize, Ls)
sig_fgla = FGLA(X0, A, n_iter, alpha, win, windual, step, windowsize, Ls)
sig_admmgla = ADMMGLA(X0, A, n_iter, rho, win, windual, step, windowsize, Ls)
wavfile.write("output_gla.wav", fs, wavsc(normalize(sig_gla)))
wavfile.write("output_fgla.wav", fs, wavsc(normalize(sig_fgla)))
wavfile.write("output_admmgla.wav", fs, wavsc(normalize(sig_admmgla)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment