Skip to content

Instantly share code, notes, and snippets.

@mryndzionek
Last active June 28, 2024 18:13
Show Gist options
  • Save mryndzionek/7dc8351115cbdf819faff694233f9267 to your computer and use it in GitHub Desktop.
Save mryndzionek/7dc8351115cbdf819faff694233f9267 to your computer and use it in GitHub Desktop.

Kalman filter PLL

import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from filterpy.kalman import (
ExtendedKalmanFilter,
UnscentedKalmanFilter,
JulierSigmaPoints,
)
from filterpy.common import Q_discrete_white_noise
class DDS:
    """Direct digital synthesizer producing a phase-continuous cosine.

    Each call to ``update`` returns ``cos(phi)`` and then advances the phase
    accumulator by ``dphi = 2*pi*ts*f``. Changing the frequency with
    ``update_freq`` only changes the step size, so the output never jumps in
    phase.

    Parameters
    ----------
    ts : float
        Sample period [s].
    f : float
        Initial output frequency [Hz].
    """

    def __init__(self, ts, f):
        self.ts = ts
        self.phi = 0.0
        self.dphi = 2 * np.pi * ts * f

    def update(self):
        """Return the current sample and advance the phase by one step."""
        a = np.cos(self.phi)
        # Keep the accumulator in [0, 2*pi) so floating-point precision of the
        # phase does not degrade as it grows over long runs.
        self.phi = (self.phi + self.dphi) % (2 * np.pi)
        return a

    def update_freq(self, f):
        """Set a new output frequency ``f`` [Hz], keeping the phase continuous."""
        # Use this instance's sample period, not a module-level global: a
        # global ``ts`` may be undefined or differ from the one given to
        # ``__init__``.
        self.dphi = 2 * np.pi * self.ts * f
class EKFPLL1:
    """Extended Kalman filter PLL tracking the phase of ``y = cos(phase)``.

    State ``x = [phase, phase rate, phase acceleration]`` (rad, rad/s,
    rad/s^2), propagated with a constant-acceleration transition matrix. The
    measurement model is ``y = cos(x[0])``.

    Parameters
    ----------
    ts : float
        Sample period [s].
    stddev_meas : float
        Standard deviation of the measurement noise.
    """

    def __init__(self, ts, stddev_meas=0.1):
        half_ts_sq = (ts**2) / 2
        self.F = np.array(
            [
                [1.0, ts, half_ts_sq],
                [0.0, 1.0, ts],
                [0.0, 0.0, 1.0],
            ]
        )
        kf = ExtendedKalmanFilter(3, 1)
        kf.x = np.array([0.0001, 0.001, 10])
        kf.F = self.F
        # Large initial covariance: the starting state is essentially unknown.
        kf.P = 1e5 * np.eye(kf.dim_x)
        kf.Q = Q_discrete_white_noise(3, ts, 1)
        kf.R = stddev_meas**2 * np.eye(1)
        self.kf = kf

    @staticmethod
    def _hx(x):
        """Measurement function: the predicted sample ``cos(phase)``."""
        return np.cos(x[0])

    @staticmethod
    def _h_jacobian(x):
        """Jacobian of ``_hx`` w.r.t. the state; only the phase term is non-zero."""
        return np.array([[-np.sin(x[0]), 0.0, 0.0]])

    def update(self, _t, y):
        """Run one predict+update step on sample ``y`` and return the state.

        ``_t`` is unused; it keeps the ``update(t, y)`` signature used by the
        plotting code.
        """
        self.kf.predict_update(y, self._h_jacobian, self._hx)
        return self.kf.x
class EKFPLL2:
    """Extended Kalman filter PLL with a time-dependent measurement model.

    State ``x = [phase offset, frequency, frequency rate]``. The transition
    matrix keeps the phase offset constant and integrates the frequency rate
    into the frequency. The measurement model is
    ``y = cos(2*pi*x[1]*t + x[0])``, evaluated at the sample time ``t``.

    Parameters
    ----------
    ts : float
        Sample period [s].
    stddev_meas : float
        Standard deviation of the measurement noise.
    """

    def __init__(self, ts, stddev_meas=0.1):
        self.F = np.array(
            [
                [1.0, 0.0, 0.0],
                [0.0, 1.0, ts],
                [0.0, 0.0, 1.0],
            ]
        )
        kf = ExtendedKalmanFilter(3, 1)
        kf.x = np.array([0.0001, 0.001, 10])
        kf.F = self.F
        kf.P = 1e6 * np.eye(kf.dim_x)
        kf.Q = Q_discrete_white_noise(3, ts, 0.1)
        kf.R = stddev_meas**2 * np.eye(1)
        self.kf = kf

    @staticmethod
    def _hx(x, t):
        """Predicted sample at time ``t``."""
        return np.cos(2 * np.pi * x[1] * t + x[0])

    @staticmethod
    def _h_jacobian(x, t):
        """Jacobian of ``_hx`` w.r.t. [phase offset, frequency, frequency rate]."""
        sin_arg = np.sin(2 * np.pi * x[1] * t + x[0])
        return np.array([[-sin_arg, -2 * np.pi * t * sin_arg, 0.0]])

    def update(self, _t, y):
        """Run one predict+update step on sample ``y`` taken at time ``_t``."""
        time_args = (_t,)
        self.kf.predict_update(y, self._h_jacobian, self._hx, time_args, time_args)
        return self.kf.x
class UKFPLL1:
    """Unscented Kalman filter PLL for ``y = cos(2*pi*f*t + phase)``.

    State ``x = [phase offset, frequency, frequency rate]``. The process model
    keeps the phase offset constant, integrates the frequency rate into the
    frequency, and then clamps the result into a sane range (see ``_fx``).

    Parameters
    ----------
    ts : float
        Sample period [s]; also passed to the UKF as its time step.
    stddev_meas : float
        Standard deviation of the measurement noise.
    """

    def __init__(self, ts, stddev_meas=0.1):
        self.F = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, ts], [0.0, 0.0, 1.0]])
        self.ts = ts
        # Time of the sample currently being processed; read by ``_hx``.
        self.t = 0.0
        # Nyquist frequency for the constructor's sample period.
        self._f_nyquist = 1.0 / (2 * ts)
        points = JulierSigmaPoints(3)
        self.kf = UnscentedKalmanFilter(3, 1, ts, self._hx, self._fx, points)
        self.kf.x = np.array([0.0001, 0.001, 0])
        self.kf.P = 1e6 * np.eye(3)
        self.kf.Q = Q_discrete_white_noise(3, ts, 1.0)
        self.kf.R = stddev_meas**2 * np.eye(1)

    def _fx(self, x, dt):
        """Process model applied to each sigma point.

        ``dt`` is part of the filterpy ``fx`` interface but unused here; the
        step size is already built into ``self.F``.
        """
        x_next = self.F.dot(x)
        # Heuristic guards (no in-depth analysis was done) that keep phase and
        # frequency from drifting to absurd values:
        # wrap the phase into [-pi, pi) ...
        x_next[0] = (x_next[0] + np.pi) % (2 * np.pi) - np.pi
        # ... make the frequency non-negative, and reduce it modulo the
        # Nyquist frequency when it exceeds that limit.
        x_next[1] = np.abs(x_next[1])
        if x_next[1] > self._f_nyquist:
            x_next[1] = x_next[1] % self._f_nyquist
        # NOTE(review): wrapping each sigma point separately can bias the
        # unscented mean when points straddle +/-pi; a circular mean via the
        # UKF's x_mean_fn/residual_x hooks might be more robust - unverified.
        return x_next

    def _hx(self, x):
        """Measurement model evaluated at the current sample time ``self.t``."""
        return np.array([np.cos(2 * np.pi * x[1] * self.t + x[0])])

    def update(self, _t, y):
        """Process sample ``y`` taken at time ``_t`` and return the state."""
        self.t = _t
        self.kf.predict()
        self.kf.update(y)
        return self.kf.x
def _plot_band(ax, time, mean, sd, label):
    """Draw ``mean`` on ``ax`` with a shaded band spanning ``mean +/- sd``."""
    ax.fill_between(time, mean + sd, mean - sd, color="#AAAAAA60")
    ax.plot(time, mean, label=label)
    ax.legend(loc="lower left")


def plot(pll, time, ys, eval_fn, title, stddev=0.0, scale=1.0):
    """Run ``pll`` over a noisy copy of ``ys`` and plot its output and state.

    Parameters
    ----------
    pll : object
        Filter wrapper with an ``update(t, y)`` method returning the state
        vector (phase, frequency, frequency rate) and a ``kf.P`` covariance.
    time : array_like
        Sample timestamps [s].
    ys : array_like
        Clean input samples, same length as ``time``. Not modified.
    eval_fn : callable
        ``eval_fn(x, t)`` reconstructs the PLL output from state ``x`` at ``t``.
    title : str
        Title of the top subplot.
    stddev : float
        Standard deviation of the Gaussian noise added to ``ys``.
    scale : float
        Factor applied to state components 1 and 2 and to their one-sigma
        bands before plotting, e.g. ``1/(2*pi)`` to turn rad/s into Hz.
    """
    time = np.asarray(time)
    # Build a new array instead of ``ys += noise``. In-place addition would
    # mutate the caller's array, and on a Python list it would *extend* the
    # list with the noise samples instead of adding them element-wise.
    ys = np.asarray(ys, dtype=float) + np.random.normal(0, stddev, len(time))

    xs = []
    phase = []
    freq = []
    dfreq = []
    sigma = []
    for t, y in zip(time, ys):
        x = pll.update(t, y)
        xs.append(eval_fn(x, t))
        # Display the phase wrapped into [-pi, pi).
        phase.append((x[0] + np.pi) % (2 * np.pi) - np.pi)
        freq.append(x[1] * scale)
        dfreq.append(x[2] * scale)
        sigma.append(np.sqrt(np.diag(pll.kf.P)))
    phase = np.asarray(phase)
    freq = np.asarray(freq)
    dfreq = np.asarray(dfreq)
    # One row per state component. Rows 1 and 2 are plotted in scaled units,
    # so their standard deviations need the same factor.
    sigma = np.array(sigma).T
    sigma[1:] *= abs(scale)

    _, axs = plt.subplots(
        4, figsize=(40, 50), gridspec_kw={"height_ratios": [3, 1, 1, 1]}
    )
    for ax in axs:
        ax.grid(True)
        ax.set_xlim((0, np.max(time)))

    axs[0].scatter(time, ys, marker="x", s=12, label="input signal")
    axs[0].plot(time, xs, color="red", label="PLL output")
    axs[0].legend(loc="lower left")
    axs[0].set_title(title)

    _plot_band(axs[1], time, phase, sigma[0], "phase [rad]")
    axs[1].set_ylim(-np.pi, np.pi)

    _plot_band(axs[2], time, freq, sigma[1], "frequency [Hz]")
    axs[2].set_ylim(-200, 200)

    _plot_band(axs[3], time, dfreq, sigma[2], "frequency change [Hz/s]")
    axs[3].set_ylim(-100, 100)
    axs[3].set_xlabel("time [s]")

    plt.show()
# ---------------------------------------------------------------------------
# Demo
# ---------------------------------------------------------------------------


def _eval_cos_phase(x, _t):
    """PLL output for state ``[phase, ...]``: ``cos(phase)``."""
    return np.cos(x[0])


def _eval_cos_phase_freq(x, t):
    """PLL output for state ``[phase offset, frequency, ...]`` at time ``t``."""
    return np.cos(2 * np.pi * x[1] * t + x[0])


# Available filters: (class, output reconstruction, scale that converts state
# components 1 and 2 into Hz and Hz/s for plotting).
PLL_VARIANTS = {
    # EKFPLL1 tracks the phase rate in rad/s, hence the 1/(2*pi) scale.
    "EKFPLL1": (EKFPLL1, _eval_cos_phase, 1.0 / (2 * np.pi)),
    "EKFPLL2": (EKFPLL2, _eval_cos_phase_freq, 1.0),
    "UKFPLL1": (UKFPLL1, _eval_cos_phase_freq, 1.0),
}
SELECTED_PLL = "EKFPLL2"  # change to try another filter

np.random.seed(0)
ts = 0.001  # sample period [s]
fs = round(1 / ts)  # sample rate [Hz]
meas_stddev = 0.1  # standard deviation of the additive measurement noise

klf, eval_fn, sf = PLL_VARIANTS[SELECTED_PLL]

# 1) Fixed-frequency tone: 15 periods at fs / 20 with a 45 degree offset.
f_tone = fs / 20
pll = klf(ts, meas_stddev)
time = np.arange(0, 15 * (1 / f_tone), ts)
ys = np.cos(2 * np.pi * f_tone * time + np.pi / 4)
plot(pll, time, ys, eval_fn, "Fixed frequency", meas_stddev, sf)

# 2) Linear sweep from fs / 50 to fs / 10 over one second.
pll = klf(ts, meas_stddev)
# Exactly fs samples spaced by ts, matching the period the filters and the DDS
# assume (np.linspace(0, 1, fs) would space them by 1 / (fs - 1)).
time = np.arange(fs) * ts
# scipy's chirp takes ``phi`` in degrees; convert to keep the 45 degree offset.
ys = signal.chirp(
    time,
    f0=fs / 50,
    f1=fs / 10,
    t1=1,
    method="linear",
    phi=np.degrees(np.pi / 4),
)
plot(pll, time, ys, eval_fn, "Frequency sweep", meas_stddev, sf)

# 3) Phase-continuous frequency jumps 10 Hz -> 20 Hz -> 50 Hz, one third of
# the time axis each; the last segment absorbs any remainder.
dds = DDS(ts, 10)
seg_len = len(time) // 3
segments = ((10, seg_len), (20, seg_len), (50, len(time) - 2 * seg_len))
ys = []
for f_seg, n_seg in segments:
    dds.update_freq(f_seg)
    ys.extend(dds.update() for _ in range(n_seg))
# Use an ndarray so noise is added element-wise, never list-concatenated.
ys = np.asarray(ys)
pll = klf(ts, meas_stddev)
plot(pll, time, ys, eval_fn, "Frequency jumps", meas_stddev, sf)
contourpy==1.2.1
cycler==0.12.1
filterpy==1.4.5
fonttools==4.53.0
kiwisolver==1.4.5
matplotlib==3.9.0
numpy==2.0.0
packaging==24.1
pillow==10.3.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
scipy==1.14.0
six==1.16.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment