
Last active May 27, 2024 00:52
endolith/98863221204541bf017b6cae71cb0a89
Training a neural network to implement the discrete Fourier transform (DFT/FFT)

My third neural network experiment (the second was an FIR filter). The DFT output is just a linear combination of the inputs, so it should be implementable by a single layer with no activation function.
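Because the DFT is linear, the weights that the layer should converge to can be written down directly. This is a minimal NumPy sketch, assuming the same convention as the training script: inputs and outputs are stacked as [real parts, imaginary parts], and Keras `Dense` computes `inputs @ kernel`. `dft_weight_matrix` is a hypothetical helper name.

```python
import numpy as np

def dft_weight_matrix(N):
    """Real 2N x 2N matrix W such that [x.real, x.imag] @ W == [X.real, X.imag],
    where X = np.fft.fft(x). Block layout assumes the layer computes inputs @ W."""
    n = np.arange(N)
    D = np.exp(-2j * np.pi * np.outer(n, n) / N)  # complex DFT matrix, X = D @ x
    # D is symmetric, so x @ D == D @ x for a 1-D x. Splitting into parts:
    #   Re(X) = x.real @ D.real - x.imag @ D.imag
    #   Im(X) = x.real @ D.imag + x.imag @ D.real
    return np.block([[D.real, D.imag],
                     [-D.imag, D.real]])

N = 32
W = dft_weight_matrix(N)  # same shape as the trained kernel, model.get_weights()[0]
x = np.random.randn(N) + 1j * np.random.randn(N)
y = np.hstack([x.real, x.imag]) @ W
X = np.fft.fft(x)
print(np.allclose(y[:N] + 1j * y[N:], X))  # expected: True
```

A well-trained kernel should approach this matrix, up to training error.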

[Figure: Topology of a 4-point complex DFT]

Animation of weights being trained:

[Animation: Neural network weights heatmap]

Red weights are positive and blue weights are negative. The black squares (2336 of 4096) are unused and could be pruned to save computation time (if I knew how to do that).
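One simple way to prune is magnitude pruning: zero every weight whose absolute value falls below a threshold, and keep a mask of the survivors. This is a minimal NumPy sketch that assumes the trained kernel is available as an array. In Keras that array is `model.get_weights()[0]`, and it could be written back with `model.set_weights([pruned])`, since this model has no bias. Here the "trained" matrix is simulated as the exact DFT weights plus small noise, and `threshold=0.01` is an arbitrary choice. `magnitude_prune` is a hypothetical helper.

```python
import numpy as np

def magnitude_prune(W, threshold):
    """Zero out weights with |w| < threshold; return pruned weights and keep-mask."""
    mask = np.abs(W) >= threshold
    return W * mask, mask

N = 32
n = np.arange(N)
D = np.exp(-2j * np.pi * np.outer(n, n) / N)
W_exact = np.block([[D.real, D.imag], [-D.imag, D.real]])

# Simulated "trained" weights: exact DFT weights plus small noise (an assumption)
rng = np.random.default_rng(0)
W_trained = W_exact + rng.normal(scale=1e-3, size=W_exact.shape)

W_pruned, mask = magnitude_prune(W_trained, threshold=0.01)
print(f'Kept {mask.sum()} of {mask.size} weights')

# The surviving weights are exactly the nonzero entries of the true DFT weights
print(np.array_equal(mask, np.abs(W_exact) > 1e-9))  # expected: True
```

Zeroed weights only save computation if the product is then computed in a sparse format (for example with `scipy.sparse`). A dense matrix multiply still touches every entry.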

Even with pruning, it would be less efficient than an FFT. So if the FFT output is useful, it's probably best to compute it externally and provide it as separate inputs.
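A rough count of real multiplications shows how large the gap is. This sketch makes two assumptions. First, one pass through the dense 2N×2N layer costs (2N)² real multiplies. Second, a textbook radix-2 FFT does (N/2)·log2(N) complex multiplies, counted as 4 real multiplies each, with trivial twiddle factors included. Optimized FFT implementations do fewer, so treat these numbers as ballpark figures.

```python
def dense_layer_mults(N):
    """Real multiplies for one pass through a dense 2N x 2N layer."""
    return (2 * N) ** 2

def radix2_fft_mults(N):
    """Real multiplies for a naive radix-2 FFT: (N/2)*log2(N) complex
    multiplies, each counted as 4 real multiplies (trivial twiddles included)."""
    assert N > 0 and N & (N - 1) == 0, 'N must be a power of 2'
    log2_N = N.bit_length() - 1
    return 4 * (N // 2) * log2_N

for N in (32, 256, 1024):
    d, f = dense_layer_mults(N), radix2_fft_mults(N)
    print(f'N={N:5d}: dense {d:8d}, radix-2 FFT {f:6d}, ratio {d / f:.1f}')
```

For the N = 32 used here, the count is 4096 versus 320 real multiplies. The ratio is 2N/log2(N), so it keeps growing with N.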

This does at least demonstrate that a neural network can learn to extract frequency content on its own, if that is useful for the problem.

The loss goes down for a while but then goes up. I don't know why:

[Figure: Loss vs. epoch]

"""Train a neural network to implement the discrete Fourier transform"""
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
import numpy as np
import matplotlib.pyplot as plt

N = 32         # DFT length
batch = 10000  # number of training examples

# Generate random complex input data and the desired (FFT) output data
sig = np.random.randn(batch, N) + 1j*np.random.randn(batch, N)
F = np.fft.fft(sig, axis=-1)

# First half of inputs/outputs is real part, second half is imaginary part
X = np.hstack([sig.real, sig.imag])
Y = np.hstack([F.real, F.imag])

# Create model with no hidden layers, same number of outputs as inputs.
# No bias needed. No activation function, since DFT is linear.
model = Sequential([Input(shape=(N*2,)), Dense(N*2, use_bias=False)])
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(X, Y, epochs=100, batch_size=100)

# Confirm that it works
data = np.arange(N)

def ANN_DFT(x):
    """Compute the DFT of a length-N vector x using the trained network."""
    if len(x) != N:
        raise ValueError(f'Input must be length {N}')
    pred = model.predict(np.hstack([x.real, x.imag])[np.newaxis])[0]
    result = pred[:N] + 1j*pred[N:]
    return result

ANN = ANN_DFT(data)
FFT = np.fft.fft(data)
print(f'ANN matches FFT: {np.allclose(ANN, FFT)}')
print(f'Max abs error: {np.max(np.abs(ANN - FFT)):.3g}')

# Heat map of neuron weights
plt.imshow(model.get_weights()[0], vmin=-1, vmax=1, cmap='coolwarm')
plt.colorbar()
plt.show()

@rajb245 So are they replacing the dense connection layer with a butterfly connection layer?


rajb245 commented Sep 21, 2020

@endolith That's one idea they explored in the paper, yes. They showed that their method can learn the underlying sparse structure of otherwise "fully connected" layers. Used that way, it is a new kind of matrix/model compression, and it contrasts sharply with pruning, the other approach that's popular in DL.

They also combine learning the weights of matrices with butterfly sparsity patterns with learning a permutation out of a discrete set of choices. With this, they show that Adam does indeed find solutions that are the usual "fast" linear transforms from signal processing (FFT, fast DCT, fast DST, etc.).
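The butterfly structure can be made concrete for the ordinary FFT. This is a minimal NumPy sketch of the textbook radix-2 decimation-in-time factorization, not the method from the paper under discussion. It shows that the dense N×N DFT matrix equals log2(N) sparse butterfly matrices, each with only 2N nonzeros, applied after a bit-reversal permutation. The helper names `bit_reversal_matrix` and `butterfly_stage` are hypothetical.

```python
import numpy as np

def bit_reversal_matrix(N):
    """Permutation matrix P with (P @ x)[i] == x[bit_reverse(i)]."""
    bits = N.bit_length() - 1
    perm = [int(format(i, f'0{bits}b')[::-1], 2) for i in range(N)]
    P = np.zeros((N, N))
    P[np.arange(N), perm] = 1
    return P

def butterfly_stage(N, m):
    """Sparse N x N matrix for one radix-2 decimation-in-time stage that
    combines pairs of size-m/2 DFTs into size-m DFTs (2 nonzeros per row)."""
    B = np.zeros((N, N), dtype=complex)
    h = m // 2
    for k in range(0, N, m):                  # each block of size m
        for j in range(h):
            w = np.exp(-2j * np.pi * j / m)   # twiddle factor
            B[k + j, k + j] = 1
            B[k + j, k + j + h] = w
            B[k + j + h, k + j] = 1
            B[k + j + h, k + j + h] = -w
    return B

N = 32
stages = [butterfly_stage(N, 2 ** s) for s in range(1, N.bit_length())]
F = bit_reversal_matrix(N)
for B in stages:
    F = B @ F  # accumulate F = B_L ... B_2 B_1 P

n = np.arange(N)
D = np.exp(-2j * np.pi * np.outer(n, n) / N)   # dense DFT matrix
print(np.allclose(F, D))                        # expected: True
print([np.count_nonzero(B) for B in stages])    # [64, 64, 64, 64, 64]
```

For N = 32, the five butterfly stages hold 320 nonzero complex weights, compared with 1024 in the dense DFT matrix. That sparsity pattern is roughly what a butterfly-structured layer restricts itself to.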


Hello. I tried your code and compared computing the FFT with NumPy against computing it with the neural network, and there was a big difference. Can you provide some insight into how to fix that? I have seen multiple recently published papers in which the authors say that using a neural network as an FFT greatly reduces the time complexity.


endolith commented Nov 19, 2020


Yes, it's highly inefficient, as I said in the description and the comments. It's even less efficient than a direct DFT, because all the zero weights are still multiplied unnecessarily. This isn't something you should actually do. It was just an experiment while I was teaching myself neural nets. If your neural net would benefit from frequency-domain information, it's better to compute a NumPy FFT and pass the output to the net (possibly taking the magnitude afterward, since that nonlinearity was much harder for the net to learn in my tests).
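The suggestion above can be sketched as follows. This is a minimal example that assumes real-valued input signals in a NumPy array of shape (batch, N). `fft_features` is a hypothetical helper. Whether to feed magnitudes, raw real/imaginary parts, or both is a modeling choice rather than anything prescribed here.

```python
import numpy as np

def fft_features(signals, magnitude=True):
    """Compute frequency-domain features outside the network.

    signals: real array of shape (batch, N).
    Returns FFT magnitudes of shape (batch, N//2 + 1), or stacked
    real/imaginary parts of shape (batch, 2*(N//2 + 1)) if magnitude=False.
    """
    spectrum = np.fft.rfft(signals, axis=-1)  # one-sided spectrum of real input
    if magnitude:
        # Magnitude: the nonlinearity reported above as hard for the net to learn
        return np.abs(spectrum)
    return np.hstack([spectrum.real, spectrum.imag])

batch, N = 8, 32
signals = np.random.randn(batch, N)
features = fft_features(signals)
# Raw samples and spectral features side by side as the network's input
net_input = np.hstack([signals, features])
print(net_input.shape)  # (8, 49)
```

`np.fft.rfft` returns only the non-negative frequency bins. For real input, the remaining bins are complex conjugates of these, so they carry no extra information.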

If you need the net to learn FFT-like transforms in general, look at the comment above about butterfly networks.


Can you shed some light on the motivation for this? Why do you want to approximate an existing algorithm with a (likely less accurate and less efficient) black-box approach?


endolith commented Jul 8, 2021

@sebraun-msr It was just an experiment. I thought that ANN nodes could be equivalent to signal processing nodes and tried implementing some things that way. I've said multiple times that you shouldn't actually do this. :)

For example, I learned that:


endolith commented Mar 31, 2022


ShriAka commented Dec 22, 2022

This work is cited here,


@ShriAka That's cool, thanks for pointing it out :)


The demo uses random noise data. If specific data, such as time series or image data, is used instead, is it easy to learn the appropriate weights?


@Chenhubget It only learns the transform at the frequencies that are present in the data.
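Here is a small sketch of why. With no hidden layers, fitting the weights is a linear least-squares problem, and data containing only a few frequencies constrains the weights only along those directions. The sketch uses `np.linalg.lstsq` as a stand-in for gradient-descent training, which is an assumption and not what the script does. The training signals are complex and built from FFT bins 1, 3 and 5 only; those bin choices are arbitrary.

```python
import numpy as np

N, batch = 32, 1000
rng = np.random.default_rng(0)

def stack(z):
    """Stack complex vectors as [real parts, imaginary parts], as in the script."""
    return np.hstack([z.real, z.imag])

# Band-limited training data: random complex amplitudes in bins 1, 3, 5 only
spectra = np.zeros((batch, N), dtype=complex)
for k in (1, 3, 5):
    spectra[:, k] = rng.normal(size=batch) + 1j * rng.normal(size=batch)
sig = np.fft.ifft(spectra, axis=-1)

X, Y = stack(sig), stack(spectra)
# Minimum-norm least-squares weights; rcond treats round-off singular values as zero
W, *_ = np.linalg.lstsq(X, Y, rcond=1e-10)

def net_dft(x):
    """Apply the fitted linear 'network' to one complex length-N vector."""
    y = stack(x) @ W
    return y[:N] + 1j * y[N:]

t = np.arange(N)
in_band = np.exp(2j * np.pi * 3 * t / N)    # tone at bin 3 (seen in training)
out_band = np.exp(2j * np.pi * 10 * t / N)  # tone at bin 10 (never seen)
print(np.allclose(net_dft(in_band), np.fft.fft(in_band)))    # expected: True
print(np.allclose(net_dft(out_band), np.fft.fft(out_band)))  # expected: False
```

In the Keras training, band-limited data likewise does nothing to pin down the weights along the directions it never excites. The network's output at unseen frequencies therefore need not match the DFT.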
