Skip to content

Instantly share code, notes, and snippets.

@sinhrks
Last active August 29, 2015 14:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sinhrks/4023ccd415745c9584c5 to your computer and use it in GitHub Desktop.
Save sinhrks/4023ccd415745c9584c5 to your computer and use it in GitHub Desktop.
Kalman filter (multivariate) with ggplot2 animation, part 2
# Simulate constant-acceleration motion and noisy observations of it ----
set.seed(1)

# Number of time steps in the observation series
n <- 30

# Acceleration, constant at every step
a <- rep(0.3, n)

# True velocity accumulates half the acceleration per step; true position
# accumulates the velocity
v <- cumsum(a * 0.5)
x <- cumsum(v)

# True state: one row per time step, columns position (x) and velocity (v)
actual <- data.frame(x = x, v = v)

# Dimension of the state / observation vector (2 here: x and v)
dim <- ncol(actual)

# Observations: N(0, 0.05^2) noise is drawn first for the position and halved
# (effective sd 0.025), then drawn for the velocity (sd 0.05). The draw order
# matters for reproducibility under set.seed(1).
observed <- cbind(
  x + rnorm(n, sd = 0.05) * 0.5,
  v + rnorm(n, sd = 0.05)
)
# Run a linear Kalman filter over `observed` ----
# Model used below (observation matrix H is the identity):
#   predict: s.m[k] = FM %*% s[k-1],      P.m = FM %*% P[k-1] %*% t(FM) + Q
#   update:  K = P.m %*% solve(P.m + R),  s[k] = s.m[k] + K %*% (z[k] - s.m[k])
#            P[k] = (I - K) %*% P.m

# Predicted (a priori) state estimate per time step
xhat.m <- matrix(0, nrow = n, ncol = dim)
# Corrected (a posteriori) state estimate per time step; row 1 stays at the
# all-zero initial guess
xhat <- matrix(0, nrow = n, ncol = dim)
# State transition F = [1 1; 0 1] (column-major fill):
# position += velocity, velocity carried over
FM <- matrix(c(1, 0, 1, 1), ncol = 2)
# Per-step results: P[, , k] is the a posteriori error covariance and
# K[, , k] the Kalman gain at time k. Only P[, , 1] is read before being
# written; the loop overwrites slices 2..n.
# NOTE(review): the recycled fill c(0.0, 0.1, 0, 0.1) gives
# P[, , 1] = [0 0; 0.1 0.1], which is not symmetric — confirm intended.
P <- array(c(0.0, 0.1, 0, 0.1), dim = c(dim, dim, n))
K <- array(0, dim = c(dim, dim, n))
# Process (Q) and observation (R) noise covariances
Q <- diag(0.01, nrow = dim, ncol = dim)
R <- diag(0.01, nrow = dim, ncol = dim)
# Identity matrix; it also plays the role of H, which is why the update uses
# P.m directly instead of H %*% P.m %*% t(H)
I <- diag(1, ncol = dim, nrow = dim)
# seq_len(n)[-1] is empty when n < 2 (seq(2, n) would count down instead)
for (k in seq_len(n)[-1]) {
  # Predict: propagate the previous estimate and its covariance
  xhat.m[k, ] <- FM %*% xhat[k - 1, ]
  P.m <- FM %*% P[, , k - 1] %*% t(FM) + Q
  # Update: innovation covariance, gain, corrected state
  S <- R + P.m
  K[, , k] <- P.m %*% solve(S)
  xhat[k, ] <- xhat.m[k, ] + K[, , k] %*% (observed[k, ] - xhat.m[k, ])
  # Covariance update needs the matrix product; element-wise `*` would
  # silently zero the off-diagonal covariance terms
  P[, , k] <- (I - K[, , k]) %*% P.m
}
library(animation)
library(ggplot2)
library(gridExtra)
# One row per time step with everything the animation plots: true state,
# observation, corrected ("fitted") and predicted estimates.
time <- seq(1, n)
d <- data.frame(
  time = time,
  actual.x = actual$x,
  actual.v = actual$v,
  observed.x = observed[, 1],
  observed.v = observed[, 2],
  fitted.x = xhat[, 1],
  fitted.v = xhat[, 2],
  predict.x = xhat.m[, 1],
  predict.v = xhat.m[, 2]
)
# Render the filter step by step as an animated GIF ----
# For each time step i the sub-frames build up cumulatively:
#   1: true path (red), past observations (points), past fitted path (blue)
#   2: + dashed segment to the predicted state
#   3: + the new observation
#   4: + solid segment to the updated (fitted) state
# Every frame also shows P and K as the image of the unit square, scaled to
# unit determinant, with their det() values as labels.
saveGIF({
  # Number of most recent time steps kept on screen
  disp.m <- 10
  # Sub-frames per time step; must stay 4 to match the stages above
  frames <- 4
  # Axis upper limits, interpolated between consecutive true states so the
  # view pans smoothly. Element (t - 1) * frames + s holds
  # x[t] + (s - 1) * diff(x)[t] / frames (likewise for v).
  steps <- seq_len(frames) - 1
  xbase <- head(x, -1)
  xdiff <- diff(x) / frames
  xmax <- as.vector(t(xbase + outer(xdiff, steps)))
  vbase <- head(v, -1)
  vdiff <- diff(v) / frames
  vmax <- as.vector(t(vbase + outer(vdiff, steps)))
  for (i in seq(3, n - 2)) {
    # Window of at most disp.m rows ending at time step i
    current <- tail(head(d, i), disp.m)
    # Everything before step i, and the last fitted state before it
    prev <- head(current, -1)
    prev.x <- prev[nrow(prev), 'fitted.x']
    prev.v <- prev[nrow(prev), 'fitted.v']
    latest <- tail(current, 1)
    p <- P[, , i]
    k <- K[, , i]
    xmin <- current[1, 'actual.x']
    ymin <- current[1, 'actual.v']
    for (j in seq_len(frames)) {
      # Axis limits for this sub-frame. NOTE(review): the index points at
      # step i + 1 rather than i, which leaves headroom beyond the latest
      # point; presumably intended — confirm.
      lim <- i * frames + j
      xmax.j <- xmax[lim]
      vmax.j <- vmax[lim]
      # 1: true trajectory, past observations and past fitted path
      p1 <- ggplot() +
        geom_path(data = d, aes(x = actual.x, y = actual.v), colour = 'red') +
        geom_point(data = prev, aes(x = observed.x, y = observed.v)) +
        geom_path(data = prev, aes(x = fitted.x, y = fitted.v),
                  colour = 'blue') +
        xlim(xmin, xmax.j) + ylim(ymin, vmax.j) +
        xlab('x') + ylab('v')
      # Points are held in pt.x / pt.v, not x / v: x and v are the true
      # trajectory read above, and (assuming saveGIF evaluates this block in
      # the calling environment, as R lazy evaluation does) reassigning them
      # would clobber it for any later use, e.g. re-running this block.
      # 2: predicted state
      if (j >= 2) {
        pt.x <- latest$predict.x
        pt.v <- latest$predict.v
        p1 <- p1 +
          annotate(geom = 'segment', x = prev.x, y = prev.v,
                   xend = pt.x, yend = pt.v,
                   colour = 'blue', linetype = 'dashed') +
          annotate(geom = 'text', x = pt.x, y = pt.v,
                   label = 'Predict', colour = 'blue', hjust = -0.2)
      }
      # 3: new observation
      if (j >= 3) {
        pt.x <- latest$observed.x
        pt.v <- latest$observed.v
        p1 <- p1 +
          annotate(geom = 'point', x = pt.x, y = pt.v, size = 5) +
          annotate(geom = 'text', x = pt.x, y = pt.v,
                   label = 'Observed', colour = 'black', hjust = -0.2)
      }
      # 4: updated (fitted) state
      if (j >= 4) {
        pt.x <- latest$fitted.x
        pt.v <- latest$fitted.v
        p1 <- p1 +
          annotate(geom = 'segment', x = prev.x, y = prev.v,
                   xend = pt.x, yend = pt.v,
                   colour = 'blue') +
          annotate(geom = 'text', x = pt.x, y = pt.v,
                   label = 'Updated', colour = 'blue', hjust = -0.2)
      }
      # Draw P and K: the corners of the unit square (rows of `unit`, closed
      # back to the origin) multiplied by each matrix, divided by
      # sqrt(det) so the 2x2 image has unit determinant, then placed near the
      # top-left of the panel in units of 1/20 of the axis ranges.
      xunit <- (xmax.j - xmin) / 20
      yunit <- (vmax.j - ymin) / 20
      unit <- matrix(c(0, 1, 1, 0, 0,
                       0, 0, 1, 1, 0), ncol = 2)
      dp <- det(p)
      dk <- det(k)
      # NOTE(review): sqrt() gives NaN if a determinant is not positive
      pk <- data.frame(cbind(unit %*% p / sqrt(dp),
                             unit %*% k / sqrt(dk)))
      colnames(pk) <- c('px', 'py', 'kx', 'ky')
      pk$px <- pk$px * xunit + xmin + xunit * 3
      pk$py <- pk$py * yunit + ymin + yunit * 16
      pk$kx <- pk$kx * xunit + xmin + xunit * 7
      pk$ky <- pk$ky * yunit + ymin + yunit * 16
      p1 <- p1 +
        geom_path(data = pk, aes(x = px, y = py), colour = 'red') +
        annotate(geom = 'text', x = xmin + xunit * 3,
                 y = ymin + yunit * 15, colour = 'red',
                 label = paste0('det(P) = ', signif(dp, digits = 3))) +
        geom_path(data = pk, aes(x = kx, y = ky), colour = 'green') +
        annotate(geom = 'text', x = xmin + xunit * 8,
                 y = ymin + yunit * 15, colour = 'green',
                 label = paste0('det(K) = ', signif(dk, digits = 3)))
      print(p1)
    }
  }
}, interval = 0.3, movie.name = "kalmanfilter04_01.gif",
ani.width = 600, ani.height = 600)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment