# Linear Kalman filter and animation
# (R script originally published as a GitHub Gist.)
# multivariate normal Kalman filter
require(dplyr)
require(tidyr)
require(ggplot2)
require(animation)
# Simulated data ----
# An ARMA(1,1) series (AR coefficient phi1, MA coefficient theta1, innovation
# sd sigma) with a linear trend of slope delta added on top.
N <- 50       # number of time points
phi1 <- .5    # AR(1) coefficient
theta1 <- .2  # MA(1) coefficient
sigma <- 1    # innovation standard deviation
delta <- .5   # slope of the linear trend
set.seed(42)
# Keep `innov` as an inline argument: arima.sim() evaluates it lazily, so
# moving the rnorm() call out of the call could change the order of random
# draws (and therefore the simulated series).
y <- arima.sim(
  model = list(ar = phi1, ma = theta1),
  n = N,
  innov = sigma * rnorm(N)
) + delta * seq_len(N)
# Quick visual check of the simulated series.
ts.plot(y)
# One measurement-update + time-update step of a linear Gaussian Kalman filter.
#
# State-space model:
#   x[t+1] = Z x[t] + G v[t],   v[t] ~ N(0, Q)
#   y[t]   = H x[t] + w[t],     w[t] ~ N(0, R)
#
# Args:
#   Z: state transition matrix (k x k).
#   G: system-noise loading matrix (k x m); need not be square.
#   H: observation matrix (p x k).
#   Q: system-noise covariance (m x m).
#   R: observation-noise covariance (p x p).
#   y: observation at time t (length p). If any element is NA the
#      measurement update is skipped.
#   xhat: prior state estimate for time t (one-step-ahead prediction made
#         at t-1), length k.
#   P: prior state covariance for time t (k x k).
#
# Returns a list with:
#   xpost: filtered state estimate at t (k x 1 matrix),
#   xpred: one-step-ahead state prediction for t+1 (k x 1 matrix),
#   Ppost: filtered state covariance at t,
#   Ppred: predicted state covariance for t+1,
#   K:     Kalman gain at t.
Kf.filter.linear <- function(Z, G, H, Q, R, y, xhat, P) {
  y <- matrix(y, ncol = 1)
  xhat <- matrix(xhat, ncol = 1)

  # Innovation (prediction error) and its covariance.
  v <- y - H %*% xhat
  Vv <- H %*% P %*% t(H) + R
  # Kalman gain.
  K <- P %*% t(H) %*% solve(Vv)

  # Measurement update.
  if (!anyNA(v)) {
    xpost <- xhat + K %*% v
    Ppost <- P - K %*% Vv %*% t(K)
  } else {
    # Missing observation: nothing to update with, carry the prior forward.
    xpost <- xhat
    Ppost <- P
  }

  # Time update (one-step-ahead prediction). The system noise term G v[t]
  # has covariance G Q G'.
  xpred <- Z %*% xpost
  Ppred <- Z %*% Ppost %*% t(Z) + G %*% Q %*% t(G)

  list(xpost = xpost, xpred = xpred,
       Ppost = Ppost, Ppred = Ppred,
       K = K)
}
# Store the observations as an N x 1 matrix (one row per time point).
y <- matrix(y, nrow = N, ncol = 1)
# Missing-data variant: uncomment to hide observations 20-29 from the filter.
# y[20:29] <- NA

# Model parameters ----
# system: x[t+1] = A x[t] + C v[t]
# obs:    y[t]   = B x[t] + e[t]
# v[t] ~ NID(0, Q), e[t] ~ NID(0, R)
# The values are fixed by hand here; in practice they would have to be
# estimated.
# NOTE(review): A[1, 1] = .5 equals phi1, but C[2, 1] = .5 is not theta1
# (= .2) used in the simulation. If C[2, 1] is meant to be the MA
# coefficient, confirm the value.
A <- rbind(
  c(.5, 1, 1, 0),
  c(0, 0, 0, 0),
  c(0, 0, 1, 1),
  c(0, 0, 0, 1)
)
B <- rbind(c(1, 0, 0, 0))
C <- matrix(0, nrow = 4, ncol = 4)
C[1:2, 1] <- c(1, .5)
Q <- diag(4)
R <- matrix(0, nrow = 1, ncol = 1)

# Reference only: the same per-step loop is implemented in kalman.l() below.
# for(i in 1:N){
# temp <- Kf.filter.linear(Z = A, G = C, H = B,
# Q = Q, R = R,
# y = y[i, ],
# xhat = xpri[i, ],
# P = P[[i]])
# xpost[i,] <- temp$xpost
# if(i < N)
# xpri[i+1,] <- temp$xpred
# P[[i+1]] <- temp$Ppred
# K[[i]] <- temp$K
# }
# Run Kf.filter.linear() over the first N observations (filtering where y is
# observed, pure prediction where a row of y contains NA).
#
# Model (note the argument order A, C, B):
#   x[t+1] = A x[t] + C v[t],  v[t] ~ N(0, Q)
#   y[t]   = B x[t] + e[t],    e[t] ~ N(0, R)
#
# Args:
#   A, C, B, Q, R: model matrices as above.
#   y: observation matrix, one row per time point. A row containing any NA
#      makes Kf.filter.linear() skip the measurement update at that time.
#   xini: prior state mean for t = 1.
#   Pini: prior state covariance for t = 1.
#   N: number of time points to process; defaults to nrow(y). Must be between
#      1 and nrow(y).
#
# Returns a list with:
#   xpost: N x k matrix of filtered state means (row t = time t),
#   xpri:  N x k matrix of prior (one-step-ahead) state means,
#   P:     list of N + 1 prior covariances; P[[t]] is the prior covariance for
#          time t and P[[N + 1]] the prediction one step past the data,
#   K:     list of N Kalman gains.
kalman.l <- function(A, C, B, Q, R, y, xini, Pini, N = NULL) {
  if (is.null(N)) {
    N <- nrow(y)
  }
  stopifnot(N >= 1, N <= nrow(y))

  k <- nrow(A)
  xpri <- matrix(0, nrow = N, ncol = k)
  xpost <- matrix(0, nrow = N, ncol = k)
  # Preallocated: P gets one extra slot for the prediction after the last step.
  P <- vector("list", N + 1)
  K <- vector("list", N)

  # Initial prior state estimate.
  xpri[1, ] <- xini
  P[[1]] <- Pini

  for (i in seq_len(N)) {
    step <- Kf.filter.linear(Z = A, G = C, H = B,
                             Q = Q, R = R,
                             y = y[i, ],
                             xhat = xpri[i, ],
                             P = P[[i]])
    xpost[i, ] <- step$xpost
    if (i < N) {
      xpri[i + 1, ] <- step$xpred
    }
    # Stored for every i (including i == N), so P has N + 1 elements.
    P[[i + 1]] <- step$Ppred
    K[[i]] <- step$K
  }
  list(xpost = xpost, xpri = xpri, P = P, K = K)
}
# Run the filter over the full sample ----
result <- kalman.l(
  A = A, C = C, B = B, Q = Q, R = R, y = y,
  xini = c(0, 0, delta, delta),
  Pini = diag(1.5, 4)
)
# Plot data: observations plus the prior and posterior estimates of the first
# state component, one row per time point.
df <- data.frame(
  raw = y[, 1],
  t = seq_len(N),
  xpri = result$xpri[, 1],
  xpost = result$xpost[, 1]
)
### Plotting function ###
# Print the frames of the Kalman filter animation.
#
# For each time point `now` two ggplot frames are printed: first the prior
# (one-step-ahead) estimate at `now`, then the posterior (filtered) estimate.
# Earlier time points show both estimates joined by a step line. While `now`
# is before the last time point, the filter is re-run with every observation
# after `now` masked, and the resulting forecast path and its `ci` interval
# are overlaid on a grey-shaded forecast region.
#
# Args:
#   df: data frame with columns t, raw (observations), xpri and xpost
#       (prior and filtered estimates of the first state component).
#   ci: coverage probability of the forecast interval.
#
# Returns:
#   A list whose element [[now]] is the forecast data frame (columns t, xpost,
#   xlow, xup; rows with t > now) built at time `now`, for
#   now = 1, ..., max(df$t) - 1.
#
# NOTE(review): this function also reads the top-level globals y, A, B, C, Q,
# R and delta.
drawKalman <- function(df, ci = .95) {
  test <- list()
  t_max <- max(df$t)
  y.range <- range(y, na.rm = TRUE)
  g.raw <- ggplot(df) +
    geom_point(aes(x = t, y = raw), color = "red4", size = 1)

  for (now in seq_len(t_max)) {
    y.pred <- df$xpost[now]

    # Forecasting: hide every observation after `now` and re-run the filter.
    if (now < t_max) {
      y.temp <- y
      y.temp[(now + 1):t_max, ] <- NA
      temp <- kalman.l(A, C, B, Q, R, y.temp,
                       c(0, 0, delta, delta), diag(1.5, 4))
      n.obs <- nrow(y.temp)
      # P[[t]] is the prior covariance for time t. For masked times the
      # filtered estimate equals the prior, so P[[t]] is its covariance.
      # (P[[t + 1]] would be the covariance one step further ahead.)
      sd.x <- vapply(temp$P[seq_len(n.obs)],
                     function(p) sqrt(p[1, 1]), numeric(1))
      df.for <- data.frame(
        t = seq_len(n.obs),
        xpost = temp$xpost[, 1],
        xlow = qnorm(p = (1 - ci) / 2, mean = temp$xpost[, 1], sd = sd.x),
        xup = qnorm(p = (1 - ci) / 2, mean = temp$xpost[, 1], sd = sd.x,
                    lower.tail = FALSE)
      ) %>% dplyr::filter(t > now)
      test[[now]] <- df.for
    }

    # Rows strictly before `now` (none when now == 1).
    past <- df[seq_len(now - 1), ]

    # kalmanstep: TRUE = one-step-ahead prediction frame,
    #             FALSE = filtering frame.
    for (kalmanstep in c(TRUE, FALSE)) {
      g <- g.raw
      if (now > 1) {
        # Earlier rounds: step line through prior -> posterior, plus points.
        g <- g +
          geom_step(data = dplyr::select(past, -raw) %>%
                      gather(key = type, value = state, -t) %>%
                      arrange(t),
                    aes(x = t, y = state), color = "blue", size = .6) +
          geom_point(data = past, aes(x = t, y = xpri),
                     size = 2, color = "green") +
          geom_point(data = past, aes(x = t, y = xpost),
                     size = 2, color = "blue")
      }
      # Current round: prior or posterior estimate at `now`.
      if (kalmanstep) {
        g <- g + geom_point(data = df[now, ],
                            aes(x = t, y = xpri), size = 2, color = "green") +
          annotate("text", x = now, y = y.pred * .8, hjust = 0,
                   label = "prior estimation")
      } else {
        g <- g + geom_point(data = df[now, ],
                            aes(x = t, y = xpost), size = 2, color = "blue") +
          annotate("text", x = now, y = y.pred * .8, hjust = 0,
                   label = "posterior estimation")
      }
      if (now < t_max) {
        # Forecast path and its interval.
        g <- g +
          geom_line(data = df.for, aes(x = t, y = xpost),
                    color = "blue", linetype = "dashed", alpha = .5, size = 1) +
          geom_ribbon(data = df.for, aes(x = t, ymin = xlow, ymax = xup),
                      alpha = .2, fill = "blue") +
          geom_line(data = df.for, aes(x = t, y = xlow),
                    color = "black", linetype = "dashed") +
          geom_line(data = df.for, aes(x = t, y = xup),
                      color = "black", linetype = "dashed")
        # Shade the forecast region.
        # NOTE(review): with no `data`, this layer inherits df and is likely
        # drawn once per row, so the shading is darker than alpha = .02
        # alone suggests. Confirm before switching to annotate("rect").
        g <- g + geom_rect(aes(xmin = now + 1, xmax = t_max,
                               ymin = -Inf, ymax = Inf),
                           fill = "grey", alpha = .02)
      }
      # Theme, labels and fixed y-axis range.
      g <- g + theme_bw() + guides(col = guide_legend()) +
        labs(y = "y", title = "Kalman filter") +
        coord_cartesian(ylim = y.range) +
        theme(plot.title = element_text(hjust = .5)) +
        annotate("text", x = t_max, y = y.range[1], hjust = 1, vjust = 1,
                 label = "http://ill-identified.hatenablog.com/")
      print(g)
    }
  }
  test
}
# Build the GIF ----
# Render the frames printed by drawKalman(df) into kalman.gif
# (0.2 s per frame, 400 x 250 px).
animation::saveGIF(
  drawKalman(df),
  movie.name = "kalman.gif",
  interval = 0.2,
  ani.width = 400,
  ani.height = 250
)
# End of the original GitHub Gist.