Skip to content

Instantly share code, notes, and snippets.

@IshitaTakeshi
Last active August 29, 2015 14:17
Show Gist options
  • Save IshitaTakeshi/102121ad70f5a632ee25 to your computer and use it in GitHub Desktop.
Save IshitaTakeshi/102121ad70f5a632ee25 to your computer and use it in GitHub Desktop.
The Kalman filter
#!/usr/bin/python
# -*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot
from numpy.random import multivariate_normal
class StateViewer(object):
    """Record true vs. estimated states and plot them against time.

    Every recorded state is read as a 2x1 column whose first entry is a
    position and whose second entry is a velocity (see ``add``).
    """

    def __init__(self, dt, n_iterations):
        """
        Args:
            dt: sampling period [s] between consecutive recorded states;
                used to build the time axis of the plot.
            n_iterations: number of states the caller intends to record
                (kept for reference; the plot uses however many states
                were actually added).
        """
        self.dt = dt
        self.n_iterations = n_iterations
        self.true_positions = []
        self.true_velocities = []
        self.estimated_positions = []
        self.estimated_velocities = []

    def add(self, x_true, x_est):
        """Append one true state and its estimate.

        Args:
            x_true: true state, a 2x1 column (position, velocity).
            x_est: estimated state, a 2x1 column (position, velocity).
        """
        self.true_positions.append(x_true.item((0, 0)))
        self.true_velocities.append(x_true.item((1, 0)))
        self.estimated_positions.append(x_est.item((0, 0)))
        self.estimated_velocities.append(x_est.item((1, 0)))

    def _time_axis(self):
        """Return the time stamp of each recorded sample: 0, dt, 2*dt, ...

        Built from integer sample indices times ``self.dt``. Its length
        therefore always equals the number of recorded samples.
        ``np.arange`` with a float step can produce one element too many
        or too few.
        """
        return np.arange(len(self.true_positions)) * self.dt

    def show(self):
        """Plot position (blue) and velocity (green) over time.

        Solid lines are the true values and dashed lines are the estimates.
        """
        t = self._time_axis()
        true_position_line = pyplot.plot(t, self.true_positions, 'b-')
        estimated_position_line = pyplot.plot(t,
                                              self.estimated_positions,
                                              'b--')
        true_velocity_line = pyplot.plot(t, self.true_velocities, 'g-')
        estimated_velocity_line = pyplot.plot(t,
                                              self.estimated_velocities,
                                              'g--')
        pyplot.legend((true_position_line[0], estimated_position_line[0],
                       true_velocity_line[0], estimated_velocity_line[0]),
                      ('true position', 'estimated position',
                       'true velocity', 'estimated velocity'))
        pyplot.show()
def get_driver_input():
    """Return a random scalar control input u drawn uniformly from [-1, 1).

    The value comes from NumPy's global random state, so seeding
    ``np.random`` makes the sequence of inputs reproducible.
    """
    lower, upper = -1, 1
    sample = np.random.uniform(lower, upper)
    return sample
dt = 0.1  # sampling period [s]: the state is measured and updated every dt seconds
n_iterations = 2000

# Covariances of the process noise (Q) and the measurement noise (R).
Q = np.identity(2) * 3
R = np.identity(2) * 3
# Control model: a scalar input u held for dt changes position by u*dt**2/2
# and velocity by u*dt.
B = np.matrix([dt**2/2, dt]).T
# State transition: position advances by velocity*dt; velocity is carried over.
F = np.matrix([[1, dt],
               [0, 1]])
# Observation model: each measured component is the state scaled by 2/340.
# NOTE(review): looks like a round-trip time-of-flight factor for sound at
# 340 m/s -- confirm.
H = np.matrix([[2/340.0, 0],
               [0, 2/340.0]])


def kalman_step(x_est, P, u, z, F, B, H, Q, R):
    """Run one predict + update cycle of a linear Kalman filter.

    Operands are expected to be ``np.matrix`` objects (``Q``/``R`` may be
    plain arrays), so ``*`` is the matrix product.

    Args:
        x_est: previous state estimate (column).
        P: previous estimate error covariance.
        u: control input applied during this step.
        z: measurement taken after this step (column).
        F, B, H: state-transition, control and observation matrices.
        Q, R: process- and measurement-noise covariances.

    Returns:
        ``(x_est, P)``: the updated state estimate and error covariance.
    """
    # Predict: propagate the previous estimate through the motion model.
    x_pred = F*x_est + B*u
    P_pred = F*P*F.T + Q
    # Update: blend the prediction with the measurement.
    S = R + H*P_pred*H.T      # innovation covariance
    K = P_pred*H.T*S.I        # Kalman gain K = P H^T S^-1 (was P + H^T S^-1)
    x_new = x_pred + K*(z - H*x_pred)
    P_new = P_pred - K*H*P_pred
    return x_new, P_new


def main():
    """Simulate the noisy system, filter it, and plot true vs. estimated state."""
    x_true = np.matrix([0, 0]).T  # initial position 0 [m], initial velocity 0 [m/s]
    x_est = x_true  # the initial state is known exactly
    # Estimate error covariance; the filter aims to minimize its trace.
    P = np.matrix([[0.01, 0],
                   [0, 0.01]])
    viewer = StateViewer(dt, n_iterations)
    for _ in range(n_iterations):
        u = get_driver_input()
        w = multivariate_normal([0, 0], Q).reshape(2, 1)
        x_true = F*x_true + B*u + w
        v = multivariate_normal([0, 0], R).reshape(2, 1)
        z = H*x_true + v
        x_est, P = kalman_step(x_est, P, u, z, F, B, H, Q, R)
        viewer.add(x_true, x_est)
    viewer.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment