Skip to content

Instantly share code, notes, and snippets.

@insaneyilin
Created December 7, 2021 03:13
Show Gist options
  • Save insaneyilin/d1ac8487ba9bb5bd1d248841f53876f9 to your computer and use it in GitHub Desktop.
Save insaneyilin/d1ac8487ba9bb5bd1d248841f53876f9 to your computer and use it in GitHub Desktop.
IMM (Interacting Multiple Model) Kalman filter
# reference: https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
import copy
import numpy as np
from scipy.linalg import block_diag
from filterpy.kalman import IMMEstimator
from filterpy.kalman import KalmanFilter
from filterpy.common import Q_discrete_white_noise
def make_cv_filter(dt, *, q=0.001, r=1.0):
    """Build a 2-D constant-velocity (CV) Kalman filter for the IMM bank.

    The state vector is ``[px, py, vx, vy, ax, ay]``. Positions integrate
    velocities over ``dt``. The acceleration entries are carried over
    unchanged by ``F`` and never feed into position or velocity; they exist
    only so every filter in the bank shares the same 6-D state layout. The
    measurement is ``[px, py, vx, vy]``.

    Args:
        dt: Time step between successive predictions.
        q: Scale of the isotropic process noise, ``Q = q * I``. The default
            reproduces the previously hard-coded value.
        r: Scale of the isotropic measurement noise, ``R = r * I``. The
            default reproduces the previously hard-coded value.

    Returns:
        A configured ``KalmanFilter`` with ``dim_x=6`` and ``dim_z=4``.
    """
    cvfilter = KalmanFilter(dim_x=6, dim_z=4)
    cvfilter.x = np.zeros(6)
    cvfilter.P = np.identity(6)
    cvfilter.R = np.identity(4) * r
    cvfilter.Q = np.identity(6) * q

    F = np.identity(6)
    F[0, 2] = dt  # px += vx * dt
    F[1, 3] = dt  # py += vy * dt
    cvfilter.F = F

    # Observe the first four state entries (px, py, vx, vy) directly.
    cvfilter.H = np.eye(4, 6)
    return cvfilter
def make_ca_filter(dt, *, q=0.001, r=1.0):
    """Build a 2-D constant-acceleration (CA) Kalman filter for the IMM bank.

    The state vector is ``[px, py, vx, vy, ax, ay]``. Over one step ``dt``:

    - positions advance by ``v * dt + 0.5 * a * dt**2``;
    - velocities advance by ``a * dt``;
    - accelerations are held constant.

    The measurement is ``[px, py, vx, vy]``.

    Args:
        dt: Time step between successive predictions.
        q: Scale of the isotropic process noise, ``Q = q * I``. The default
            reproduces the previously hard-coded value.
        r: Scale of the isotropic measurement noise, ``R = r * I``. The
            default reproduces the previously hard-coded value.

    Returns:
        A configured ``KalmanFilter`` with ``dim_x=6`` and ``dim_z=4``.
    """
    cafilter = KalmanFilter(dim_x=6, dim_z=4)
    cafilter.x = np.zeros(6)
    cafilter.P = np.identity(6)
    cafilter.R = np.identity(4) * r
    cafilter.Q = np.identity(6) * q

    half_dt2 = 0.5 * dt * dt
    F = np.identity(6)
    F[0, 2] = dt        # px += vx * dt
    F[1, 3] = dt        # py += vy * dt
    F[0, 4] = half_dt2  # px += 0.5 * ax * dt^2
    F[1, 5] = half_dt2  # py += 0.5 * ay * dt^2
    F[2, 4] = dt        # vx += ax * dt
    F[3, 5] = dt        # vy += ay * dt
    cafilter.F = F

    # Observe the first four state entries (px, py, vx, vy) directly.
    cafilter.H = np.eye(4, 6)
    return cafilter
def make_cp_filter(dt, *, q=0.001, r=1.0):
    """Build a 2-D constant-position (CP) Kalman filter for the IMM bank.

    The state vector is ``[px, py, vx, vy, ax, ay]``. The transition matrix
    is the identity, so every state entry is held constant between steps and
    changes only through process noise and measurement updates. The
    measurement is ``[px, py, vx, vy]``.

    Args:
        dt: Time step. Unused, because an identity transition does not depend
            on it. It is accepted so all filter factories share one signature.
        q: Scale of the isotropic process noise, ``Q = q * I``. The default
            reproduces the previously hard-coded value.
        r: Scale of the isotropic measurement noise, ``R = r * I``. The
            default reproduces the previously hard-coded value.

    Returns:
        A configured ``KalmanFilter`` with ``dim_x=6`` and ``dim_z=4``.
    """
    cpfilter = KalmanFilter(dim_x=6, dim_z=4)
    cpfilter.x = np.zeros(6)
    cpfilter.P = np.identity(6)
    cpfilter.R = np.identity(4) * r
    cpfilter.Q = np.identity(6) * q
    # Identity transition: the state does not move on its own.
    cpfilter.F = np.identity(6)
    # Observe the first four state entries (px, py, vx, vy) directly.
    cpfilter.H = np.eye(4, 6)
    return cpfilter
dt = 0.1
cv_model = make_cv_filter(dt)
ca_model = make_ca_filter(dt)
cp_model = make_cp_filter(dt)

# Initial model probabilities for [CV, CA, CP]; they sum to 1.
mu = np.array([0.3, 0.3, 0.4])
# Model-transition (Markov) matrix for [CV, CA, CP]. Each row sums to 1,
# i.e. row i holds the probabilities of moving out of model i.
M = np.array([[0.8, 0.1, 0.1], [0.6, 0.2, 0.2], [0.1, 0.1, 0.8]])

# Simulated ground truth: straight-line motion at constant velocity.
start_px = 0.0
start_py = 0.0
vx = 2.0
vy = 0.5

# Initial state [px, py, vx, vy, ax, ay], derived from the ground truth above
# rather than duplicated as literals.
x0 = np.array([start_px, start_py, vx, vy, 0., 0.])

filters = [cv_model, ca_model, cp_model]
# IMMEstimator.predict() mixes each sub-filter's own x/P. Assigning only
# bank.x has no effect on the filters and would leave them starting from
# zeros, so seed the initial state into every filter before building the bank.
for kf in filters:
    kf.x = x0.copy()

bank = IMMEstimator(filters, mu, M)
bank.x = x0.copy()

for i in range(1, 100):
    # Noise-free measurement [px, py, vx, vy] of the true trajectory at step i.
    z = np.array([(start_px + i * dt * vx), (start_py + i * dt * vy), vx, vy])
    bank.predict()
    bank.update(z)
    print(bank.x)
    print(bank.mu)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment