Skip to content

Instantly share code, notes, and snippets.

@Piyush3dB
Created September 18, 2015 15:01
Show Gist options
  • Save Piyush3dB/336ea12a295538eaeb3b to your computer and use it in GitHub Desktop.
Save Piyush3dB/336ea12a295538eaeb3b to your computer and use it in GitHub Desktop.
-- Import packages
require 'torch'
require 'gnuplot'
--------------------------------------
-- Define the Kalman filter 'class' --
--------------------------------------
--- Prototype table for Kalman filter objects.
-- Instances get this table as their metatable, and because
-- `Kalman.__index` points back at it, method calls such as
-- `obj:project(...)` are resolved here.
Kalman = {}
Kalman.__index = Kalman

--- Constructor: build a named filter holding the initial estimate.
-- @param name  label used in this object's log messages
-- @param X0    initial state mean
-- @param S0    initial state variance
-- @return a new table whose metatable is `Kalman`
function Kalman.create(name, X0, S0)
   local obj = setmetatable({ name = name, X = X0, S = S0 }, Kalman)
   print("Create a " .. name)
   return obj
end
--- Print this filter's `name` field to stdout.
function Kalman:printName()
   local label = self.name
   print(label)
end
--- Time update ('project'): propagate the previous estimate forward.
-- Predicted mean:       GX = Gt * X
-- Predicted covariance: Rt = Gt * S * Gt' + Wt
-- The results are stored on the object (self.GX, self.Rt) for a later
-- call to Kalman:correct.
-- @param Gt  state-transition matrix
-- @param Wt  state (process) noise covariance
function Kalman:project(Gt, Wt)
   print("project" .. self.name)
   self.GX = Gt * self.X
   -- (Gt * S) first, then times Gt', matching left-to-right evaluation
   local GtS = Gt * self.S
   self.Rt = GtS * Gt:t() + Wt
end
--- Measurement update ('correct'): fold the observation Yt into the
-- prediction. Reads self.GX and self.Rt, which Kalman:project sets, so
-- project must have been called first.
-- @param Ft  measurement matrix (maps state space to measurement space)
-- @param Vt  measurement noise covariance
-- @param Yt  observed measurement (column tensor)
-- Sets self.Kt (gain), self.X (corrected mean), self.S (corrected covariance).
function Kalman:correct(Ft, Vt, Yt)
   print("correct" .. self.name)
   local RFt = self.Rt * Ft:t()
   -- Innovation covariance F R F' + V, inverted once and reused.
   local innovInv = torch.inverse(Vt + Ft * RFt)
   -- Kalman gain: K = R F' (F R F' + V)^-1
   self.Kt = RFt * innovInv
   -- Mean correction: the residual is the measurement minus the
   -- *predicted measurement* F*GX (not the predicted state itself).
   self.X = self.GX + self.Kt * (Yt - Ft * self.GX)
   -- Covariance correction: S = R - K F R  (i.e. (I - K F) R)
   self.S = self.Rt - self.Kt * Ft * self.Rt
end
---------------------------------------------
-- Kalman filter simulation --
---------------------------------------------
-- State dimension N (length of the state vector X) and measurement
-- dimension M (length of Yt). Both are scalar in this demo.
local N = 1
local M = 1
-- Initial system state mean X0 (N x 1) and covariance S0 (N x N)
local X = torch.zeros(N, 1)
X[1][1] = 10              -- Theta
local S = torch.eye(N)    -- Sigma = 1
-- System matrices
local Gt = torch.eye(N)     -- state transition (N x N)
local Ft = torch.eye(M, N)  -- measurement model (M x N)
-- Noise covariances
local Wt = torch.eye(N)         -- state noise, variance 1 (N x N)
local Vt = torch.eye(M):mul(2)  -- measurement noise, variance 2 (M x M)
-- Measurement (M x 1)
local Yt = torch.zeros(M, 1)
-- create and use a Kalman filter
local KS = Kalman.create("Simple Kalman", X, S)
KS:printName()
-- single iteration here
KS:project(Gt, Wt)
Yt[1][1] = 10
KS:correct(Ft, Vt, Yt)
--
print(KS)
-- End of simple-kalman.lua
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment