Neural network learning FIR filter

My second experiment with Keras (first was single neurons).

Hypothesis: Each output sample of an FIR filter is just a sum of weighted input samples taken from a small chunk of the input:

$$y[n] = \sum_{k=0}^{N-1} b_k \, x[n-k]$$

which is the same structure as a single neuron in a neural net (assuming no activation function and no bias):

$$y = \sum_{i=1}^{N} w_i \, x_i$$

So it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right?

Conclusion: Yep, it works great.
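
As a quick numerical check of that equivalence (a minimal sketch, separate from the script below; the names are illustrative): the dot product of an input chunk with the time-reversed taps reproduces one convolution output sample, which is exactly what a bias-free linear neuron computes.

```python
import numpy as np
from scipy import signal

taps = np.array([0.25, 0.5, 0.25])          # FIR coefficients b_k
x = np.random.randn(10)                      # input signal
y = signal.convolve(x, taps, mode='valid')   # FIR filter output

# A "neuron" whose weights are the time-reversed taps computes the same
# weighted sum (convolution flips the kernel; firwin's taps are symmetric,
# so the flip is invisible there):
neuron_out = x[0:3] @ taps[::-1]             # no bias, no activation
assert np.isclose(neuron_out, y[0])
```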

"""
Train a neural network to learn an FIR filter.
Created on Fri Aug 3 15:00:40 2018
"""
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import Callback
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from soundfile import read
"""
Generate or load a signal to use as input data
"""
# Only learns at the frequencies present in the signal
# https://en.wikipedia.org/wiki/File:Short-beaked_Echidna.ogg
sig, fs = read('echidna.wav')
# Learns at all frequencies with white noise
# sig, fs = np.random.randn(10000), 10000
"""
Create the FIR filter for the ANN to copy
"""
numtaps = 51
# b = signal.firwin(numtaps, 1, fs=fs)
# b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
#                   pass_zero=False)
b = signal.firwin(numtaps, cutoff=[6000, 11000], fs=fs,
                  window='blackmanharris', pass_zero=False)
# TODO: Use an IIR filter and have ANN approximate it as best it can
"""
Training data is chunks of input and output of FIR filter
"""
# filtered = signal.lfilter(b, 1.0, sig)
filtered = signal.convolve(sig, b, mode='valid')
def rolling_window(a, window):
    """
    Return chunks of signal `a` of size `window`, incremented by 1 each time.

    https://gist.github.com/codehacken/708f19ae746784cef6e68b037af65788
    """
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
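# For example (illustrative): rolling_window(np.arange(5), 3) returns
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]
# one length-`window` chunk per row, sharing memory with `a` (no copy).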
X = rolling_window(sig, numtaps)
Y = filtered # Filter outputs 1 sample for each chunk of input samples
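# Note: convolution time-reverses the kernel, so each target is
# Y[i] = X[i] @ b[::-1], and the network should learn the *reversed* taps.
# firwin returns symmetric (linear-phase) taps, so b[::-1] == b and the
# learned kernel can be compared against b directly in the plots below.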
# plt.plot(X[0])
# plt.plot(Y[0])
"""
Create model
Initializer matters because signal might have missing areas of spectrum, and
model will not learn there. So if the initial guess is all zeros, those areas
of the spectrum will stay silenced, while the passband is "built up" for
frequencies that are present.
"""
model = Sequential([
    Dense(1, input_dim=numtaps, use_bias=False,
          # kernel_initializer='random_normal',  # typical usage
          # kernel_initializer='ones',  # boxcar window = running average
          kernel_initializer='zeros',  # nothing (good for non-white input)
          )
])
model.summary()
initial = model.get_weights()
print('Initial weights:')
print(initial)
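# get_weights() returns a list holding one kernel of shape (numtaps, 1);
# with the 'zeros' initializer it is all zeros at this point.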
"""
Make block diagram of network (not from tutorial)
"""
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png', show_shapes=True)
"""
Make graph diagram of network (not from tutorial)
Viewable with `tensorboard --logdir="logs"`
"""
import tensorflow as tf
# TODO: This isn't working like it used to. Replace with TF2.0 conventions.
with tf.compat.v1.Session() as sess:
    writer = tf.compat.v1.summary.FileWriter('logs', sess.graph)
    writer.close()
"""
Node-level graph
"""
# Working version: https://github.com/endolith/ann-visualizer
# from ann_visualizer.visualize import ann_viz
# ann_viz(model, title="Learned FIR filter")
# https://github.com/Dicksonchin93/keras-architecture-visualizer/
# from keras_architecture_visualizer import KerasArchitectureVisualizer
# vis = KerasArchitectureVisualizer()
# vis.visualize(model)
# Compile model
model.compile(loss='mean_squared_error',
              optimizer='adam',
              )
class LossHistory(Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    # Could plot the convergence here
    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

    def on_epoch_end(self, batch, logs={}):
        pass
        # Could plot the convergence here
history = LossHistory()
# Fit the model
print("Fitting...")
model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history])
# Evaluate the model (model.evaluate returns the mean squared error here)
print("Evaluating...")
scores = model.evaluate(X, Y)
print('MSE:', scores)
final = model.get_weights()
fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, num='kernel', sharex=True)
ax1.plot(b, '.-', label='Filter', alpha=0.5, c='gray')
ax0.plot(initial[0], '.-', label='Initial')
ax1.plot(final[0], '.-', label='Learned')
ax0.grid(True, color='0.7', linestyle='-', which='major')
ax0.grid(True, color='0.9', linestyle='-', which='minor')
ax1.grid(True, color='0.7', linestyle='-', which='major')
ax1.grid(True, color='0.9', linestyle='-', which='minor')
ax0.set_title('Kernel')
ax0.legend()
ax1.legend()
plt.figure('frequency response')
w, h = signal.freqz(b, [1.0])
plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Filter',
             alpha=0.5, c='gray')
w, h = signal.freqz(initial[0], [1.0])
plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Initial')
w, h = signal.freqz(final[0], [1.0])
plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Learned', alpha=0.5)
plt.grid(True, color='0.7', linestyle='-', which='major')
plt.grid(True, color='0.9', linestyle='-', which='minor')
plt.xlabel('Frequency [Hz]')
plt.ylabel('Response [dB]')
plt.xlim(None, fs/2)
plt.title('Frequency response')
plt.legend()
plt.figure('loss')
plt.semilogy(history.losses)
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.grid(True, which="both")
plt.title('Loss')
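# Optional sanity check one could add: compare the learned kernel to the
# true taps.  Expect close agreement with white-noise input, and mismatch
# at frequencies the echidna recording doesn't contain.
learned = final[0].ravel()
print('Max |learned - b|:', np.max(np.abs(learned - b)))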
endolith commented Mar 29, 2022

[Attached images: model, frequency_response, kernel, best BP 1001 FR]
