Skip to content

Instantly share code, notes, and snippets.

@taiga4112
Created December 6, 2016 09:44
Show Gist options
  • Save taiga4112/77c6117168fe5b6362930a34b0d92d77 to your computer and use it in GitHub Desktop.
Save taiga4112/77c6117168fe5b6362930a34b0d92d77 to your computer and use it in GitHub Desktop.
Extended Kalman filter sample for estimating state and parameters at the same time
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
def main():
# 初期化
T = 30 # 観測数
r = 10.0 # 半径
w = 1.0*10/180 * np.pi # 角速度[rad/s](公称値であり、データからこの値を推定する)
z = np.mat([[0.0],[-5.0],[w]]) # 初期位置+推測するパラメータの初期値
Z = [z] # 実際の状態+パラメータ推定値
Y = [z] # 観測+パラメータ推定値
U = np.mat([[r],[w],[0.0]]) # 操作量(推定するwについては逐次推定値を用いる)
# state x = f(z_,u,v), v~N(0,Q)
Q = np.mat([[0.5,0.0,0.0],[0.0,0.5,0.0],[0.0,0.0,0.1]])
# observation Y = z + w, w~N(0,R)
R = np.mat([[0.5,0.0,0.0],[0.0,0.5,0.0],[0.0,0.0,0.1]])
def f(t,z,u):
x0 = u[0,0]*z[2,0]*np.cos(z[2,0]*t)+z[0,0]
x1 = u[0,0]*z[2,0]*np.sin(z[2,0]*t)+z[1,0]
_w = z[2,0]
return np.mat([[x0],[x1],[_w]])
def Jf(t,z,u):
"""
解析的に求めるf(x)のヤコビ行列
"""
return np.mat([[u[0,0]*z[2,0]*np.cos(z[2,0]*t),0,0],[0,u[0,0]*z[2,0]*np.sin(z[2,0]*t),0],[0,0,1]])
# 観測データの生成
for t in range(T):
z = f(t,z,U)+np.random.multivariate_normal([0,0,0],Q,1).T
z[2,0] = w #観測データにおけるwは一定であるとして作成している
Z.append(z)
y = z + np.random.multivariate_normal([0,0,0],R,1).T
Y.append(y)
# EKF
_z = np.mat([[0.0],[-5.0],[w]])
Sigma = np.mat([[1,0,0],[0,1,0],[0,0,1]])
_Z = [_z] # 推定
for t in range(T):
# prediction
A = Jf(t,_z,U)
_z_ = f(t,_z,U)
Sigma_ = Q + A * Sigma * A.T
# update
C = np.mat([[1,0,0],[0,1,0],[0,0,1]])
yi = Y[t+1] - _z_
S = Sigma_ + R
G = Sigma_ * C.T * S.I
_z = _z_ + G * yi
print (G*yi)[2,0]
Sigma = Sigma_ - G * Sigma_
_Z.append(_z)
# 描画
plt.subplot(2, 1, 1)
a, b ,c = np.array(np.concatenate(Z,axis=1))
plt.plot(a,b,'rs-', label="X_correct")
a, b, c = np.array(np.concatenate(Y,axis=1))
plt.plot(a,b,'g^-', label="Y (=X+N(0,R))")
a, b, c = np.array(np.concatenate(_Z,axis=1))
plt.plot(a,b,'bo-', label="X_estimate")
plt.legend(bbox_to_anchor=(0.80, 0.00), loc='lower left', borderaxespad=0)
plt.axis('equal')
plt.subplot(2, 1, 2)
xx = np.arange(T+1)
plt.plot(xx,c*180/np.pi, label="w_estimate")
plt.legend(bbox_to_anchor=(0.80, 0.00), loc='lower left', borderaxespad=0)
plt.show()
if __name__ == '__main__':
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment