@iCorv
Forked from endolith/DFT_ANN.py
Created January 24, 2023 09:10
Training neural network to implement discrete Fourier transform (DFT/FFT)

My third neural network experiment (the second was an FIR filter). The DFT output is just a linear combination of the inputs, so it should be implementable by a single layer with no activation function.
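
To make the "linear combination" concrete, the weights the layer has to learn can be written down directly. The following is a minimal NumPy sketch (the helper name `dft_weight_matrix` is hypothetical, not part of the original code). It assumes the same layout as the training script: real parts stacked before imaginary parts, with inputs on rows and outputs on columns like a Keras `Dense` kernel.

```python
import numpy as np


def dft_weight_matrix(N):
    """Hypothetical helper: real 2N x 2N kernel W such that
    [x.real, x.imag] @ W == [F.real, F.imag], where F = np.fft.fft(x).
    Inputs are rows and outputs are columns, like a Keras Dense kernel."""
    n = np.arange(N)
    # Complex DFT matrix D[n, k] = exp(-2j*pi*n*k/N), so that x @ D == np.fft.fft(x)
    D = np.exp(-2j * np.pi * np.outer(n, n) / N)
    # (xr + j*xi) @ (Dr + j*Di) = (xr @ Dr - xi @ Di) + j*(xr @ Di + xi @ Dr)
    return np.block([[D.real, D.imag],
                     [-D.imag, D.real]])


N = 32
W = dft_weight_matrix(N)
x = np.random.randn(N) + 1j * np.random.randn(N)
y = np.hstack([x.real, x.imag]) @ W      # what the trained layer should output
F_est = y[:N] + 1j * y[N:]
print(np.allclose(F_est, np.fft.fft(x)))  # True
```

Each complex weight becomes a 2×2 block of real weights, which is why the real-valued layer has (2N)² = 4096 weights for N = 32.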

Topology of a 4-point complex DFT
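
For reference, the 4-point DFT computes $F_k = \sum_{n=0}^{3} x_n e^{-2\pi i k n/4}$ for $k = 0, \dots, 3$, which in matrix form is:

$$
\begin{bmatrix} F_0 \\ F_1 \\ F_2 \\ F_3 \end{bmatrix}
=
\begin{bmatrix}
1 & 1 & 1 & 1 \\
1 & -i & -1 & i \\
1 & -1 & 1 & -1 \\
1 & i & -1 & -i
\end{bmatrix}
\begin{bmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \end{bmatrix}
$$

With real and imaginary parts as separate inputs and outputs, this becomes an 8-input, 8-output real layer whose weights are all 0, +1 or −1.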

Animation of weights being trained:

Neural network weights heatmap

Red is positive, blue is negative. The black squares (2336 out of 4096) are unused and could be pruned out to save computation time (if I knew how to do that).
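
Pruning in this sense can be sketched without a framework. Compute the ideal DFT kernel, drop entries whose magnitude is below a tolerance, and keep only the survivors as (row, column, value) triples so that each kept weight costs one multiply. This NumPy sketch assumes the ideal weights rather than the trained ones, and the tolerance `tol` is an arbitrary assumption that would need tuning for learned weights. Because of this, the number of kept weights it prints depends on those choices and need not match the count from the heatmap. For Keras models, the TensorFlow Model Optimization Toolkit's `prune_low_magnitude` is one existing route, not tried here.

```python
import numpy as np

N = 32
n = np.arange(N)
D = np.exp(-2j * np.pi * np.outer(n, n) / N)          # complex DFT matrix
W = np.block([[D.real, D.imag], [-D.imag, D.real]])   # ideal real 2N x 2N kernel

# Magnitude pruning: treat weights below an (assumed) tolerance as unused
tol = 1e-9
mask = np.abs(W) > tol
rows, cols = np.nonzero(mask)
vals = W[rows, cols]                                  # surviving weights only


def sparse_apply(v):
    """Multiply a stacked [real, imag] input by the pruned kernel,
    using one multiply per kept weight."""
    out = np.zeros(2 * N)
    # out[k] += v[n] * W[n, k] for every kept (n, k)
    np.add.at(out, cols, v[rows] * vals)
    return out


x = np.random.randn(N) + 1j * np.random.randn(N)
v = np.hstack([x.real, x.imag])
print(f'kept {mask.sum()} of {W.size} weights')
print(np.allclose(sparse_apply(v), v @ W))
```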

Even with pruning, it would be less efficient than an FFT (O(N²) operations for the dense layer versus O(N log N) for the FFT), so if the FFT output is useful, it's probably better to compute it externally and provide it as separate inputs.
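
As a rough sketch of both points, assume a power-of-two N and a textbook radix-2 FFT. The dense layer needs (2N)² real multiplies per signal. A radix-2 FFT needs at most (N/2)·log₂N complex multiplies, at 4 real multiplies each, before skipping trivial twiddle factors. The helper names `with_fft_features` and `multiply_counts` are hypothetical; the first shows one way to feed a precomputed FFT to a network alongside the raw samples.

```python
import numpy as np


def with_fft_features(sig):
    """Hypothetical helper: stack raw samples and their FFT as features.
    sig: complex array of shape (batch, N) -> real array of shape (batch, 4N)."""
    F = np.fft.fft(sig, axis=-1)
    return np.hstack([sig.real, sig.imag, F.real, F.imag])


def multiply_counts(N):
    """Real multiplies per signal: dense 2N x 2N layer vs. a naive
    radix-2 FFT count (N must be a power of 2)."""
    dense = (2 * N) ** 2
    # log2(N) stages of N/2 butterflies, one complex multiply (= 4 real) each,
    # without skipping trivial twiddle factors such as 1 or -j
    fft = 4 * (N // 2) * (N.bit_length() - 1)
    return dense, fft


sig = np.random.randn(10, 32) + 1j * np.random.randn(10, 32)
features = with_fft_features(sig)
print(features.shape)        # (10, 128)
print(multiply_counts(32))   # (4096, 320)
```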

This does at least demonstrate that a neural network can work out frequency content on its own, if that is useful for the problem.

The loss goes down for a while but then goes up. I don't know why:

loss vs epoch
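
The cause isn't established here, but two checks may help narrow it down. First, the target map is linear and bias-free, so the same kind of training data can be fit in closed form with ordinary least squares. That gives the loss floor for this problem (essentially zero, up to floating-point error) without any gradient descent. Second, an unverified guess is that Adam's fixed learning rate (the Keras default is 0.001) is too large once the loss is tiny, so updates overshoot. Passing a smaller `learning_rate` to the optimizer, or adding a schedule, would be one way to test that. A NumPy sketch of the first check, reusing the data generation from the script:

```python
import numpy as np

N = 32
batch = 10000
rng = np.random.default_rng(0)
sig = rng.standard_normal((batch, N)) + 1j * rng.standard_normal((batch, N))
F = np.fft.fft(sig, axis=-1)
X = np.hstack([sig.real, sig.imag])   # inputs, shape (batch, 2N)
Y = np.hstack([F.real, F.imag])       # targets, shape (batch, 2N)

# Solve min_W ||X @ W - Y||^2 in one step instead of training
W, *_ = np.linalg.lstsq(X, Y, rcond=None)
mse = np.mean((X @ W - Y) ** 2)
print(f'MSE of closed-form fit: {mse:.2e}')

# Compare with the exact DFT kernel in the same [real, imag] layout
n = np.arange(N)
D = np.exp(-2j * np.pi * np.outer(n, n) / N)
W_exact = np.block([[D.real, D.imag], [-D.imag, D.real]])
print(np.allclose(W, W_exact))  # True
```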

"""
Train a neural network to implement the discrete Fourier transform
"""
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np
import matplotlib.pyplot as plt
N = 32
batch = 10000
# Generate random input data and desired output data
sig = np.random.randn(batch, N) + 1j*np.random.randn(batch, N)
F = np.fft.fft(sig, axis=-1)
# First half of inputs/outputs is real part, second half is imaginary part
X = np.hstack([sig.real, sig.imag])
Y = np.hstack([F.real, F.imag])
# Create model with no hidden layers, same number of outputs as inputs.
# No bias needed. No activation function, since DFT is linear.
model = Sequential([Dense(N*2, input_dim=N*2, use_bias=False)])
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(X, Y, epochs=100, batch_size=100)
# Confirm that it works
data = np.arange(N)
def ANN_DFT(x):
if len(x) != N:
raise ValueError(f'Input must be length {N}')
pred = model.predict(np.hstack([x.real, x.imag])[np.newaxis])[0]
result = pred[:N] + 1j*pred[N:]
return result
ANN = ANN_DFT(data)
FFT = np.fft.fft(data)
print(f'ANN matches FFT: {np.allclose(ANN, FFT)}')
# Heat map of neuron weights
plt.imshow(model.get_weights()[0], vmin=-1, vmax=1, cmap='coolwarm')