Created
May 12, 2019 13:01
-
-
Save prabindh/cf1008baef52db0b7ab7d74d5200255a to your computer and use it in GitHub Desktop.
Phase learning https://stackoverflow.com/questions/56098924
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import Sequential | |
from keras.layers import Dense | |
import numpy as np | |
def normalize_angles(phases):
    """Rescale angles in radians from [-pi, pi] to [0, 1].

    Args:
        phases: Array-like (ndarray, nested sequence, or scalar) of angles
            in radians, e.g. the output of ``np.angle``.

    Returns:
        A new NumPy array (or NumPy scalar for scalar input) holding
        ``(phases + pi) / (2 * pi)``. The input is never modified.
    """
    # asarray lets plain lists/tuples through; ndarrays pass without a copy.
    shifted = np.asarray(phases) + np.pi
    return shifted / (2 * np.pi)
def build_fourier_mnist(path="train_features.npy", image_shape=(28, 28)):
    """Load flattened images and return their 2-D Fourier transforms.

    Args:
        path: ``.npy`` file to load. It is expected to hold one flattened
            image per row (the defaults correspond to MNIST stored as an
            ``(N, 784)`` matrix).
        image_shape: ``(height, width)`` each row is reshaped to before the
            2-D FFT; ``height * width`` must equal the row length.

    Returns:
        A ``complex128`` array with the same shape as the loaded data, where
        row ``i`` is the flattened ``np.fft.fft2`` of image ``i``.
    """
    mnist = np.load(path)
    # np.complex was removed in NumPy 1.24; complex128 is the dtype it aliased.
    fourier_mnist = np.zeros(mnist.shape, dtype=np.complex128)
    for i, flat_image in enumerate(mnist):
        # Back to a matrix so the transform is a true 2-D FFT.
        spectrum = np.fft.fft2(np.reshape(flat_image, image_shape))
        # Flatten again and store it in the matching output row.
        fourier_mnist[i, :] = spectrum.ravel()
    return fourier_mnist
# Split the complex spectra into the network's inputs (magnitudes) and
# targets (phases rescaled by normalize_angles).
fourier_mnist = build_fourier_mnist()
amplitudes = np.abs(fourier_mnist)
phases = normalize_angles(np.angle(fourier_mnist))

# Fully connected network: an input-facing layer, three more hidden layers
# of the same width, and one output unit per phase value.
model = Sequential()
model.add(Dense(784, input_dim=amplitudes.shape[1], activation='sigmoid'))
for _ in range(3):
    model.add(Dense(784, activation='sigmoid'))
model.add(Dense(phases.shape[1], activation='sigmoid'))

# Mean-squared-error regression trained with Adam.
model.compile(loss='mean_squared_error', optimizer='adam')

# Train, then write the fitted model to disk.
model.fit(amplitudes, phases, epochs=400, batch_size=100)
model.save("phase_retriever2.h5")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment