Skip to content

Instantly share code, notes, and snippets.

@agramfort
Created July 14, 2024 08:58
Show Gist options
  • Save agramfort/9eab9f7488d7a32f77745875375ec2ad to your computer and use it in GitHub Desktop.
Save agramfort/9eab9f7488d7a32f77745875375ec2ad to your computer and use it in GitHub Desktop.
NLMS: Normalized Least-Mean-Squares
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import lfilter
# Step 1: Synthetic system-identification data
w_true = np.array([0.5, -0.3, 0.1])  # FIR taps of the "unknown" system to identify
N = 1000  # how many samples to generate
np.random.seed(42)  # fixed seed so every run draws identical signals
x = np.random.standard_normal(N)  # white Gaussian excitation signal
e_ = np.random.standard_normal(N) * 0.01  # additive measurement noise, std 0.01
d = e_ + lfilter(w_true, 1, x)  # desired signal: input passed through w_true, plus noise
# Step 2: NLMS Algorithm Implementation
def nlms(x, d, mu=0.01, eps=1e-6, M=3):
    """Identify an M-tap FIR filter with the normalized LMS (NLMS) algorithm.

    At each sample ``i`` the regressor is the window of the ``M`` most recent
    inputs ``x[i-M+1], ..., x[i]``. The a-priori error
    ``e[i] = d[i] - <window, w>`` drives the update
    ``w += mu / (||window||**2 + eps) * e[i] * window``.

    Parameters
    ----------
    x : array_like, shape (n,)
        Input signal.
    d : array_like, shape (>= n,)
        Desired signal; only its first ``len(x)`` samples are used.
    mu : float
        Step size. NLMS is mean-square stable for ``0 < mu < 2``.
    eps : float
        Regularizer added to the window energy to avoid dividing by zero.
    M : int
        Number of filter taps (must be >= 1).

    Returns
    -------
    w : ndarray, shape (M,)
        Estimated coefficients in FIR order: ``w[k]`` multiplies ``x[i-k]``
        (the same convention as the ``b`` argument of ``scipy.signal.lfilter``).
    e : ndarray, shape (n,)
        A-priori error signal. The first ``M-1`` entries are 0 because no
        full input window exists before index ``M-1``.

    Raises
    ------
    ValueError
        If ``M < 1`` or ``d`` is shorter than ``x``.
    """
    x = np.asarray(x)
    d = np.asarray(d)
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")
    n = len(x)
    if len(d) < n:
        raise ValueError(f"d has {len(d)} samples but x has {n}")

    w = np.zeros(M)  # initial filter coefficients
    e = np.zeros(n)  # a-priori error signal
    # i = M - 1 is the first index with a full window x[0:M]; starting the
    # loop at M would silently skip that sample.
    for i in range(M - 1, n):
        # Window is in ascending time order, so internal w[j] pairs with
        # x[i-M+1+j]; it is reversed on return to obtain FIR order.
        x_i = x[i - M + 1:i + 1]
        e[i] = d[i] - np.dot(x_i, w)
        norm = np.dot(x_i, x_i) + eps
        w += (mu / norm) * e[i] * x_i
    return w[::-1], e
# Run NLMS
w_est, e = nlms(x, d)
# Step 3: Performance Metrics and Plots
# Mean Squared Error (MSE) over the whole run. The adaptive filter starts from
# all-zero coefficients, so this average also contains the start-up transient
# and can be far above the error level reached once the filter has converged.
mse = np.mean(e**2)
# Error power over the second half of the run, where the filter is assumed to
# have settled (reasonable for these N / mu values; adjust if they change).
mse_steady = np.mean(e[len(e) // 2:] ** 2)
# Distance between the estimated and the known true coefficients.
coef_err = np.linalg.norm(w_est - w_true)
# Plotting
plt.figure(figsize=(14, 5))
plt.subplot(1, 2, 1)
plt.plot(e, label='Error Signal')
plt.title('Error Signal Over Time')
plt.xlabel('Sample')
plt.ylabel('Error')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(w_true, label='True Coefficients', marker='o')
plt.plot(w_est, label='Estimated Coefficients', marker='x')
plt.xticks(range(len(w_true)))  # one integer tick per filter tap
plt.title('Filter Coefficients')
plt.xlabel('Coefficient Index')
plt.ylabel('Value')
plt.legend()
plt.tight_layout()
plt.show()
# Interpretations
print(f"True Coefficients: {w_true}")
print(f"Final Estimated Coefficients: {w_est}")
print(f"Coefficient Error ||w_est - w_true||: {coef_err}")
print(f"Mean Squared Error: {mse}")
print(f"Steady-State MSE (second half): {mse_steady}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment