Last active
August 29, 2015 14:08
-
-
Save sinhrks/4023ccd415745c9584c5 to your computer and use it in GitHub Desktop.
KalmanFilter (multivariate) with ggplot2 animation Pt2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
set.seed(1)

## ---- Simulated ground truth ------------------------------------------
# Length of the series (number of time steps).
n <- 30
# Constant acceleration at every step.
a <- rep(0.3, times = n)
# Velocity accumulates half of the acceleration per step; position
# accumulates the velocity.
v <- cumsum(0.5 * a)
x <- cumsum(v)
actual <- data.frame(x = x, v = v)
# State / observation dimension (position and velocity).
dim <- ncol(actual)

## ---- Noisy observations ----------------------------------------------
# Both components receive N(0, 0.05^2) draws; the position draw is further
# scaled by 0.5, so its effective standard deviation is 0.025.
# cbind() evaluates its arguments left to right, so the position noise is
# drawn first -- keep this order to reproduce the seeded random stream.
observed <- cbind(
  x + 0.5 * rnorm(n, sd = 0.05),
  v + rnorm(n, sd = 0.05)
)

## ---- Filter storage and model matrices -------------------------------
# One row per time step: one-step-ahead predictions (xhat.m) and corrected
# estimates (xhat). Row 1 of xhat is the (all-zero) initial state.
xhat.m <- matrix(0, nrow = n, ncol = dim)
xhat <- matrix(0, nrow = n, ncol = dim)
# State transition F = [1 1; 0 1]: position += velocity, velocity unchanged.
FM <- matrix(c(1, 1,
               0, 1), nrow = 2, byrow = TRUE)
# Per-step covariance P[, , k] and Kalman gain K[, , k]; P[, , 1] acts as the
# initial covariance.
# NOTE(review): the recycled initial covariance is [0 0; 0.1 0.1], which is
# neither symmetric nor invertible -- confirm whether diag(0.1, 2) was meant.
P <- array(c(0.0, 0.1, 0, 0.1), dim = c(dim, dim, n))
K <- array(0, dim = c(dim, dim, n))
# Process noise Q and observation noise R, both 0.01 * identity.
Q <- diag(0.01, nrow = dim, ncol = dim)
R <- Q
# Identity matrix used in the covariance update.
I <- diag(1, nrow = dim, ncol = dim)
# Kalman filter for the constant-velocity model. Both position and velocity
# are observed directly, so the observation matrix H is the identity and is
# omitted below (S = P.m + R, K = P.m S^-1, innovation = z - xhat.m).
# Row 1 of xhat and slice 1 of P serve as the initial state and covariance;
# seq_len(n)[-1] yields nothing (instead of counting down) when n < 2.
for (k in seq_len(n)[-1]) {
  # predict: propagate the previous estimate and covariance through F
  xhat.m[k, ] <- FM %*% xhat[k - 1, ]
  P.m <- FM %*% P[, , k - 1] %*% t(FM) + Q
  # update: innovation covariance, gain and corrected state
  S <- R + P.m
  K[, , k] <- P.m %*% solve(S)
  xhat[k, ] <- xhat.m[k, ] + K[, , k] %*% (observed[k, ] - xhat.m[k, ])
  # posterior covariance (I - K H) P.m: this must be a matrix product;
  # an element-wise `*` would not yield the posterior covariance
  P[, , k] <- (I - K[, , k]) %*% P.m
}
library(animation) | |
library(ggplot2) | |
library(gridExtra) | |
# Gather truth, observations, corrected estimates and predictions into a
# single data frame with one row per time step, for plotting.
time <- seq_len(n)
d <- data.frame(time, actual, observed, xhat, xhat.m)
names(d) <- c(
  'time',
  'actual.x', 'actual.v',
  'observed.x', 'observed.v',
  'fitted.x', 'fitted.v',
  'predict.x', 'predict.v'
)
# Animate the filter as a GIF. For each displayed time step i, `frames`
# frames are drawn in stages:
#   j == 1: true path (red), past observations (points), past estimates (blue)
#   j >= 2: dashed segment from the previous estimate to the prediction
#   j >= 3: the new observation at step i
#   j >= 4: solid segment to the corrected (updated) estimate
# An inset draws the unit square mapped through P[, , i] (red) and K[, , i]
# (green), each divided by sqrt(det), labelled with its determinant.
#
# NOTE: the braced expression passed to saveGIF() is evaluated in the
# calling environment, so every assignment inside it lands there as well.
# Annotation coordinates therefore use their own names (ann.x / ann.v)
# so the simulated trajectories `x` and `v` are not overwritten.
saveGIF({
  # number of most recent time steps kept on screen
  disp.m <- 10
  # frames drawn per time step (the staged drawing below uses j = 1..4)
  frames <- 4
  # Upper axis limits, interpolated along the true path so the view pans
  # smoothly: element (c - 1) * frames + r equals
  # x[c] + (r - 1) * (x[c + 1] - x[c]) / frames (likewise for v).
  xdiff <- diff(x) / frames
  vdiff <- diff(v) / frames
  offsets <- seq_len(frames) - 1
  xmax <- rep(head(x, -1), each = frames) + as.vector(outer(offsets, xdiff))
  vmax <- rep(head(v, -1), each = frames) + as.vector(outer(offsets, vdiff))
  # closed unit square, one (x, y) vertex per row
  unit <- matrix(c(0, 1, 1, 0, 0,
                   0, 0, 1, 1, 0), ncol = 2)
  for (i in seq(3, n - 2)) {
    # last `disp.m` rows up to and including step i
    current <- tail(head(d, i), disp.m)
    prev <- head(current, -1)
    prev.x <- prev[nrow(prev), 'fitted.x']
    prev.v <- prev[nrow(prev), 'fitted.v']
    latest <- tail(current, 1)
    p <- P[, , i]
    k <- K[, , i]
    # NOTE(review): a non-positive determinant makes sqrt() return NaN, so
    # the corresponding inset path would have NaN coordinates
    dp <- det(p)
    dk <- det(k)
    xmin <- current[1, 'actual.x']
    ymin <- current[1, 'actual.v']
    for (j in seq_len(frames)) {
      idx <- i * frames + j
      x.hi <- xmax[idx]
      v.hi <- vmax[idx]
      # 1: actual
      p1 <- ggplot() +
        geom_path(data = d, aes(x = actual.x, y = actual.v), colour = 'red') +
        geom_point(data = prev, aes(x = observed.x, y = observed.v)) +
        geom_path(data = prev, aes(x = fitted.x, y = fitted.v), colour = 'blue') +
        xlim(xmin, x.hi) + ylim(ymin, v.hi) +
        xlab('x') + ylab('v')
      # 2: predict
      if (j >= 2) {
        ann.x <- latest$predict.x
        ann.v <- latest$predict.v
        p1 <- p1 +
          annotate(geom = 'segment', x = prev.x, y = prev.v,
                   xend = ann.x, yend = ann.v,
                   colour = 'blue', linetype = 'dashed') +
          annotate(geom = 'text', x = ann.x, y = ann.v,
                   label = 'Predict', colour = 'blue', hjust = -0.2)
      }
      # 3: observation
      if (j >= 3) {
        ann.x <- latest$observed.x
        ann.v <- latest$observed.v
        p1 <- p1 +
          annotate(geom = 'point', x = ann.x, y = ann.v, size = 5) +
          annotate(geom = 'text', x = ann.x, y = ann.v,
                   label = 'Observed', colour = 'black', hjust = -0.2)
      }
      # 4: update
      if (j >= 4) {
        ann.x <- latest$fitted.x
        ann.v <- latest$fitted.v
        p1 <- p1 +
          annotate(geom = 'segment', x = prev.x, y = prev.v,
                   xend = ann.x, yend = ann.v,
                   colour = 'blue') +
          annotate(geom = 'text', x = ann.x, y = ann.v,
                   label = 'Updated', colour = 'blue', hjust = -0.2)
      }
      # plot P and kalman gain, sized relative to the current view
      xunit <- (x.hi - xmin) / 20
      yunit <- (v.hi - ymin) / 20
      pk <- data.frame(cbind(unit %*% p / sqrt(dp),
                             unit %*% k / sqrt(dk)))
      colnames(pk) <- c('px', 'py', 'kx', 'ky')
      pk$px <- pk$px * xunit + xmin + xunit * 3
      pk$py <- pk$py * yunit + ymin + yunit * 16
      pk$kx <- pk$kx * xunit + xmin + xunit * 7
      pk$ky <- pk$ky * yunit + ymin + yunit * 16
      p1 <- p1 +
        geom_path(data = pk, aes(x = px, y = py), colour = 'red') +
        annotate(geom = 'text', x = xmin + xunit * 3,
                 y = ymin + yunit * 15, colour = 'red',
                 label = paste0('det(P) = ', signif(dp, digits = 3))) +
        geom_path(data = pk, aes(x = kx, y = ky), colour = 'green') +
        annotate(geom = 'text', x = xmin + xunit * 8,
                 y = ymin + yunit * 15, colour = 'green',
                 label = paste0('det(K) = ', signif(dk, digits = 3)))
      print(p1)
    }
  }
}, interval = 0.3, movie.name = "kalmanfilter04_01.gif",
ani.width = 600, ani.height = 600)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment