Skip to content

Instantly share code, notes, and snippets.

@sinhrks
Created November 4, 2014 15:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sinhrks/846bb4c11a87c44f9f45 to your computer and use it in GitHub Desktop.
Save sinhrks/846bb4c11a87c44f9f45 to your computer and use it in GitHub Desktop.
KalmanFilter (multivariate) with ggplot2 animation
set.seed(1)
# Number of time steps in the observed series. The true path below is built
# from four legs of n / 4 points each, so n should be divisible by 4.
n <- 100
# True (noise-free) positions: the outline of a 10 x 10 square, going
# (0, 0) -> (0, 10) -> (10, 10) -> (10, 0) -> (0, 0).
x <- c(rep(0, n / 4), seq(0, 10, length.out = n / 4),
       rep(10, n / 4), seq(10, 0, length.out = n / 4))
y <- c(seq(0, 10, length.out = n / 4), rep(10, n / 4),
       seq(10, 0, length.out = n / 4), rep(0, n / 4))
actual <- cbind(x, y)
# Dimension of each observation (number of coordinates per time step).
dim <- ncol(actual)
# Observed values: truth plus i.i.d. Gaussian noise with standard deviation 0.1.
observed <- actual + matrix(rnorm(n * dim, sd = 0.1), ncol = dim)
# Filtered state estimates, one row per time step. Row 1 is the initial
# estimate (all zeros); the loop below fills rows 2..n.
xhat <- matrix(0, nrow = n, ncol = dim)
# P[, , k] and K[, , k] store the error covariance and the Kalman gain
# computed at time step k. The initial covariance must be symmetric, so every
# slice starts as 0.1 * I; slices 2..n are overwritten by the loop.
P <- array(diag(0.1, nrow = dim, ncol = dim), dim = c(dim, dim, n))
K <- array(0, dim = c(dim, dim, n))
# Process-noise (Q) and observation-noise (R) covariance matrices.
Q <- diag(0.01, nrow = dim, ncol = dim)
R <- diag(0.01, nrow = dim, ncol = dim)
# Identity matrix.
I <- diag(1, ncol = dim, nrow = dim)
# Kalman filter with an identity state-transition and identity observation
# model: the state is predicted to stay put, then corrected toward the
# observation at each step.
for (k in seq_len(n)[-1]) {
  # Predict: carry the previous estimate forward; uncertainty grows by Q.
  xhat.m <- xhat[k - 1, ]
  P.m <- P[, , k - 1] + Q
  # Update: innovation covariance, gain, corrected state estimate.
  S <- R + P.m
  K[, , k] <- P.m %*% solve(S)
  xhat[k, ] <- xhat.m + K[, , k] %*% (observed[k, ] - xhat.m)
  # Posterior covariance (I - K) P^- is a matrix product, not elementwise.
  P[, , k] <- (I - K[, , k]) %*% P.m
}
library(animation)
library(ggplot2)
library(gridExtra)
# Time index, one entry per time step (kept as a global, as before).
time <- seq_len(n)
# One row per time step: the time index plus the x/y coordinates of the true
# path, the noisy observations and the Kalman-filtered estimates. Columns are
# named explicitly instead of renamed by position after cbind(), so a change
# in column order cannot silently mislabel them.
d <- data.frame(
  time = time,
  actual.x = actual[, 1], actual.y = actual[, 2],
  observed.x = observed[, 1], observed.y = observed[, 2],
  fitted.x = xhat[, 1], fitted.y = xhat[, 2]
)
# Build a time-series panel for one coordinate of the Kalman filter run.
#
# data: data frame with a `time` column plus `observed.<axis>`,
#   `actual.<axis>` and `fitted.<axis>` columns; must have at least one row.
# axis: coordinate suffix to plot (e.g. "x" or "y"); also the y-axis label.
#
# Returns a ggplot object: observations as points, the fitted values as a
# blue path and the true values as a red path, each path labelled at its
# last row. The y-axis is fixed to [-1, 11]; x-axis breaks are every 20 steps
# from 0 up to the global `n`.
klmplot <- function(data, axis) {
  if (nrow(data) == 0) {
    stop("`data` must have at least one row", call. = FALSE)
  }
  obs <- paste0("observed.", axis)
  act <- paste0("actual.", axis)
  fit <- paste0("fitted.", axis)
  # Place the labels at the last row's actual time value; the row count only
  # equals it when `data` starts at time 1.
  last <- nrow(data)
  last_time <- data[["time"]][last]
  ggplot(data = data, aes(x = .data[["time"]])) +
    geom_point(aes(y = .data[[obs]])) +
    geom_path(aes(y = .data[[fit]]), colour = "blue") +
    annotate(geom = "text", x = last_time, y = data[[fit]][last] - 0.5,
             label = "Fitted value", colour = "blue", hjust = 1) +
    geom_path(aes(y = .data[[act]]), colour = "red") +
    annotate(geom = "text", x = last_time, y = data[[act]][last] + 0.5,
             label = "True value", colour = "red", hjust = 1) +
    ylim(-1, 11) +
    scale_x_continuous(breaks = seq(0, n, by = 20)) +
    xlab("time") +
    ylab(axis)
}
# Render the filter run as an animated GIF, growing the plotted history by
# two time steps per frame. Each frame stacks three panels: the 2-D
# trajectory (observations as points, true path in red, filtered path in
# blue) above the x and y components over time as drawn by klmplot().
saveGIF(
  {
    for (rows_shown in seq(2, n, by = 2)) {
      shown <- head(d, rows_shown)
      trajectory <- ggplot(data = shown) +
        geom_point(aes(x = observed.x, y = observed.y)) +
        geom_path(aes(x = actual.x, y = actual.y), colour = "red") +
        geom_path(aes(x = fitted.x, y = fitted.y), colour = "blue") +
        xlim(-1, 11) +
        ylim(-1, 11) +
        xlab("x") +
        ylab("y")
      grid.arrange(
        trajectory,
        klmplot(shown, "x"),
        klmplot(shown, "y"),
        ncol = 1,
        heights = c(3, 1, 1)
      )
    }
  },
  interval = 0.2,
  movie.name = "kalmanfilter03_01.gif",
  ani.width = 600,
  ani.height = 600
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment