Last active
December 23, 2018 14:16
-
-
Save unaoya/f62be3de20f971f2d5963b978abd748f to your computer and use it in GitHub Desktop.
検索量を用いた状態空間モデルによる売上予測
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data { | |
int N; //学習期間の長さ | |
int N_pred; //予測期間の長さ | |
vector[N] Y; //販売台数データ | |
} | |
parameters { | |
vector[N] alpha; //状態のトレンド成分 | |
vector[N] season; //状態の季節成分 | |
real<lower=0> s_Y; //観測誤差の分散 | |
real<lower=0> s_a; //トレンド成分の分散 | |
real<lower=0> s_season; //季節成分の分散 | |
} | |
transformed parameters { | |
vector[N] y_mean; | |
y_mean = alpha + season; | |
} | |
model { | |
alpha[3:N] ~ normal(2*alpha[2:(N-1)] - alpha[1:(N-2)], s_a); //状態モデル | |
for(t in 12:N){ | |
season[t] ~ normal(-sum(season[(t-11):(t-1)]), s_season); //季節成分 | |
} | |
Y ~ normal(y_mean, s_Y); //観測モデル | |
} | |
generated quantities { | |
vector[N+N_pred] alpha_all; | |
vector[N+N_pred] season_all; | |
vector[N+N_pred] Y_all; | |
alpha_all[1:N] = alpha; | |
season_all[1:N] = season; | |
Y_all[1:N] = y_mean; | |
for (t in 1:N_pred) { | |
season_all[N+t] = normal_rng(-sum(season_all[(N+t-11):(N+t-1)]), s_season); | |
alpha_all[N+t] = normal_rng(2*alpha_all[N+t-1] - alpha_all[N+t-2], s_a); | |
Y_all[N+t] = normal_rng(alpha_all[N+t]+season_all[t], s_Y); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(rstan) | |
library(ggmcmc) | |
library(dplyr) | |
library(bayesplot) | |
# データの読み込み | |
df_sail <- read.csv('data/car_sail.csv') | |
df_search <- read.csv('data/car_search.csv') | |
# 販売量データは2014/1から2018/6まで(長さ54) | |
# 検索データは2013/7から2018/7まで(長さ61) | |
(sail <- df_sail$ノート) | |
(search <- df_search$ノート) | |
# 学習に2017/6まで、テストに2017/7から2018/6まで | |
N_test <- 12 | |
data <- list(Y = sail[1:(length(sail)-N_test)], | |
X = search[7:(length(sail)-N_test+6)], | |
N = length(sail) - N_test, | |
N_pred = N_test) | |
# ベースラインモデル | |
fit_base <- stan(file = "./model/base.stan", | |
data = data, | |
iter = 10000) | |
# 推定結果 | |
mcmc_rhat(rhat(fit_base)) | |
# 結果の図示 | |
ggs(fit_base) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(`2.5%` = quantile(value, probs=.025), | |
`10%` = quantile(value, probs=.1), | |
`50%` = quantile(value, probs=.5), | |
`90%` = quantile(value, probs=.9), | |
`97.5%`= quantile(value, probs=.975)) %>% | |
mutate(sail = sail) %>% | |
ggplot() + | |
geom_ribbon(mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) + | |
geom_ribbon(mapping=aes(x=x, ymin=`10%`, ymax=`90%`), alpha=2/6) + | |
geom_line(mapping=aes(x=x, y=`50%`)) + | |
geom_line(aes(x=x, y=sail), shape=1, size=2) + | |
labs(x='月', y='販売台数') + | |
ggtitle ("ベースラインモデル") + | |
theme_gray (base_family = "HiraKakuPro-W3") | |
# 予測誤差 | |
# 予測値の平均と実測値の二乗誤差 | |
ggs(fit_base) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(mean = mean(value)) %>% | |
mutate(sail = sail) %>% | |
mutate(sqe=(sail-mean)^2) %>% | |
select(sqe) %>% | |
slice((length(sail)-11):length(sail)) %>% | |
summarize(sum(sqe)) | |
# 提案モデル | |
fit_uni0 <- stan(file = "./model/uni0.stan", | |
data = data, | |
iter = 10000) | |
# 推定結果 | |
mcmc_rhat(rhat(fit_uni0)) | |
# 結果の図示 | |
ggs(fit_uni0) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(`2.5%` = quantile(value, probs=.025), | |
`10%` = quantile(value, probs=.1), | |
`50%` = quantile(value, probs=.5), | |
`90%` = quantile(value, probs=.9), | |
`97.5%`= quantile(value, probs=.975)) %>% | |
mutate(sail = sail) %>% | |
ggplot() + | |
geom_ribbon(mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) + | |
geom_ribbon(mapping=aes(x=x, ymin=`10%`, ymax=`90%`), alpha=2/6) + | |
geom_line(mapping=aes(x=x, y=`50%`)) + | |
geom_line(aes(x=x, y=sail), shape=1, size=2) + | |
labs(x='月', y='販売台数') + | |
ggtitle ("同期モデル") + | |
theme_gray (base_family = "HiraKakuPro-W3") | |
# 予測誤差 | |
# 予測値の平均と実測値の二乗誤差 | |
ggs(fit_uni0) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(mean = mean(value)) %>% | |
mutate(sail = sail) %>% | |
mutate(sqe=(sail-mean)^2) %>% | |
select(sqe) %>% | |
slice((length(sail)-11):length(sail)) %>% | |
summarize(sum(sqe)) | |
# 一期前のみ | |
fit_uni1 <- stan(file = "./model/uni1.stan", | |
data = data, | |
iter = 10000) | |
# 推定結果 | |
mcmc_rhat(rhat(fit_uni1)) | |
# 結果の図示 | |
ggs(fit_uni1) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(`2.5%` = quantile(value, probs=.025), | |
`10%` = quantile(value, probs=.1), | |
`50%` = quantile(value, probs=.5), | |
`90%` = quantile(value, probs=.9), | |
`97.5%`= quantile(value, probs=.975)) %>% | |
mutate(sail = sail) %>% | |
ggplot() + | |
geom_ribbon(mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) + | |
geom_ribbon(mapping=aes(x=x, ymin=`10%`, ymax=`90%`), alpha=2/6) + | |
geom_line(mapping=aes(x=x, y=`50%`)) + | |
geom_line(aes(x=x, y=sail), shape=1, size=2) + | |
labs(x='月', y='販売台数') + | |
ggtitle ("一期前モデル") + | |
theme_gray (base_family = "HiraKakuPro-W3") | |
# 予測誤差 | |
# 予測値の平均と実測値の二乗誤差 | |
ggs(fit_uni1) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(mean = mean(value)) %>% | |
mutate(sail = sail) %>% | |
mutate(sqe=(sail-mean)^2) %>% | |
select(sqe) %>% | |
slice((length(sail)-11):length(sail)) %>% | |
summarize(sum(sqe)) | |
# 二期前のみ | |
fit_uni2 <- stan(file = "./model/uni2.stan", | |
data = data, | |
iter = 10000) | |
# 推定結果 | |
mcmc_rhat(rhat(fit_uni2)) | |
# 結果の図示 | |
ggs(fit_uni2) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(`2.5%` = quantile(value, probs=.025), | |
`10%` = quantile(value, probs=.1), | |
`50%` = quantile(value, probs=.5), | |
`90%` = quantile(value, probs=.9), | |
`97.5%`= quantile(value, probs=.975)) %>% | |
mutate(sail = sail) %>% | |
ggplot() + | |
geom_ribbon(mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) + | |
geom_ribbon(mapping=aes(x=x, ymin=`10%`, ymax=`90%`), alpha=2/6) + | |
geom_line(mapping=aes(x=x, y=`50%`)) + | |
geom_line(aes(x=x, y=sail), shape=1, size=2) + | |
labs(x='月', y='販売台数') + | |
ggtitle ("二期前モデル") + | |
theme_gray (base_family = "HiraKakuPro-W3") | |
# 予測誤差 | |
# 予測値の平均と実測値の二乗誤差 | |
ggs(fit_uni2) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(mean = mean(value)) %>% | |
mutate(sail = sail) %>% | |
mutate(sqe=(sail-mean)^2) %>% | |
select(sqe) %>% | |
slice((length(sail)-11):length(sail)) %>% | |
summarize(sum(sqe)) | |
# 二期前まで全て使う | |
fit_mult <- stan(file = "./model/mult.stan", | |
data = data, | |
iter = 10000) | |
# 推定結果 | |
mcmc_rhat(rhat(fit_mult)) | |
# 結果の図示 | |
ggs(fit_mult) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(`2.5%` = quantile(value, probs=.025), | |
`10%` = quantile(value, probs=.1), | |
`50%` = quantile(value, probs=.5), | |
`90%` = quantile(value, probs=.9), | |
`97.5%`= quantile(value, probs=.975)) %>% | |
mutate(sail = sail) %>% | |
ggplot() + | |
geom_ribbon(mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) + | |
geom_ribbon(mapping=aes(x=x, ymin=`10%`, ymax=`90%`), alpha=2/6) + | |
geom_line(mapping=aes(x=x, y=`50%`)) + | |
geom_line(aes(x=x, y=sail), shape=1, size=2) + | |
labs(x='月', y='販売台数') + | |
ggtitle ("複数期モデル") + | |
theme_gray (base_family = "HiraKakuPro-W3") | |
# 予測誤差 | |
# 予測値の平均と実測値の二乗誤差 | |
ggs(fit_mult) %>% | |
filter(grepl('^Y_all\\[\\d+\\]$', Parameter)) %>% | |
tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]', convert=TRUE) %>% | |
group_by(Parameter, x) %>% | |
summarize(mean = mean(value)) %>% | |
mutate(sail = sail) %>% | |
mutate(sqe=(sail-mean)^2) %>% | |
select(sqe) %>% | |
slice((length(sail)-11):length(sail)) %>% | |
summarize(sum(sqe)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data { | |
int N; //学習期間の長さ | |
int N_pred; //予測期間の長さ | |
vector[N] Y; //販売台数データ | |
vector[N] X; //検索量データ | |
} | |
parameters { | |
vector[N] alpha_X; //検索量の状態トレンド成分 | |
vector[N] season_X; //検索量の状態季節成分 | |
vector[N] alpha_Y; //販売台数の状態トレンド成分 | |
vector[N] season_Y; //販売台数の状態季節成分 | |
real<lower=0> s_X; //検索量の観測誤差の分散 | |
real<lower=0> s_Y; //販売台数の観測誤差の分散 | |
real<lower=0> s_a_X; //検索量の状態トレンドの分散 | |
real<lower=0> s_season_X; //検索量の季節成分の分散 | |
real<lower=0> s_a_Y; //販売台数の状態トレンドの分散 | |
real<lower=0> s_season_Y; //販売台数の季節成分の分散 | |
vector[3] b; //販売台数に対する検索量のトレンドの重み | |
} | |
transformed parameters { | |
vector[N] x_mean; | |
vector[N] y_mean; | |
x_mean = alpha_X + season_X; | |
y_mean = alpha_Y + season_Y; | |
} | |
model { | |
alpha_X[3:N] ~ normal(2*alpha_X[2:(N-1)] - alpha_X[1:(N-2)], s_a_X); //検索量トレンドの状態モデル | |
for(t in 12:N){ | |
season_X[t] ~ normal(-sum(season_X[(t-11):(t-1)]), s_season_X); //検索量季節成分 | |
} | |
X ~ normal(x_mean, s_X); //検索量の観測モデル | |
alpha_Y[3:N] ~ normal(2*alpha_Y[2:(N-1)] - alpha_Y[1:(N-2)], s_a_Y); //販売台数トレンドの状態モデル | |
for(t in 12:N){ | |
season_Y[t] ~ normal(-sum(season_Y[(t-11):(t-1)]), s_season_Y); //販売台数季節成分 | |
} | |
// 以下が違う | |
Y[1] ~ normal(y_mean[1], s_Y); | |
Y[2] ~ normal(y_mean[2], s_Y); | |
Y[3:N] ~ normal(y_mean[3:N] + b[1] * alpha_X[3:N] + b[2] * alpha_X[2:(N-1)] + b[3] * alpha_X[1:(N-2)], s_Y); //販売台数の観測モデル | |
} | |
generated quantities { | |
vector[N+N_pred] alpha_X_all; | |
vector[N+N_pred] season_X_all; | |
vector[N+N_pred] alpha_Y_all; | |
vector[N+N_pred] season_Y_all; | |
vector[N+N_pred] X_all; | |
vector[N+N_pred] Y_all; | |
alpha_X_all[1:N] = alpha_X; | |
season_X_all[1:N] = season_X; | |
alpha_Y_all[1:N] = alpha_Y; | |
season_Y_all[1:N] = season_Y; | |
X_all[1:N] = x_mean; | |
Y_all[1:N] = y_mean; | |
for (t in 1:N_pred) { | |
season_X_all[N+t] = normal_rng(-sum(season_X_all[(N+t-11):(N+t-1)]), s_season_X); | |
alpha_X_all[N+t] = normal_rng(2*alpha_X_all[N+t-1] - alpha_X_all[N+t-2], s_a_X); | |
X_all[N+t] = normal_rng(alpha_X_all[N+t]+season_X_all[t], s_X); | |
} | |
for (t in 1:N_pred) { | |
season_Y_all[N+t] = normal_rng(-sum(season_Y_all[(N+t-11):(N+t-1)]), s_season_Y); | |
alpha_Y_all[N+t] = normal_rng(2*alpha_Y_all[N+t-1] - alpha_Y_all[N+t-2], s_a_Y); | |
Y_all[t+N] = normal_rng(alpha_Y_all[N+t]+season_Y_all[t]+b[1]*X_all[N+t]+b[2]*X_all[N+t-1]+b[3]*X_all[N+t-2], s_Y); //ここも違う | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data { | |
int N; //学習期間の長さ | |
int N_pred; //予測期間の長さ | |
vector[N] Y; //販売台数データ | |
vector[N] X; //検索量データ | |
} | |
parameters { | |
vector[N] alpha_X; //検索量の状態トレンド成分 | |
vector[N] season_X; //検索量の状態季節成分 | |
vector[N] alpha_Y; //販売台数の状態トレンド成分 | |
vector[N] season_Y; //販売台数の状態季節成分 | |
real<lower=0> s_X; //検索量の観測誤差の分散 | |
real<lower=0> s_Y; //販売台数の観測誤差の分散 | |
real<lower=0> s_a_X; //検索量の状態トレンドの分散 | |
real<lower=0> s_season_X; //検索量の季節成分の分散 | |
real<lower=0> s_a_Y; //販売台数の状態トレンドの分散 | |
real<lower=0> s_season_Y; //販売台数の季節成分の分散 | |
real b; //販売台数に対する検索量のトレンドの重み | |
} | |
transformed parameters { | |
vector[N] x_mean; | |
vector[N] y_mean; | |
x_mean = alpha_X + season_X; | |
y_mean = alpha_Y + season_Y; | |
} | |
model { | |
alpha_X[3:N] ~ normal(2*alpha_X[2:(N-1)] - alpha_X[1:(N-2)], s_a_X); //検索量トレンドの状態モデル | |
for(t in 12:N){ | |
season_X[t] ~ normal(-sum(season_X[(t-11):(t-1)]), s_season_X); //検索量季節成分 | |
} | |
X ~ normal(x_mean, s_X); //検索量の観測モデル | |
alpha_Y[3:N] ~ normal(2*alpha_Y[2:(N-1)] - alpha_Y[1:(N-2)], s_a_Y); //販売台数トレンドの状態モデル | |
for(t in 12:N){ | |
season_Y[t] ~ normal(-sum(season_Y[(t-11):(t-1)]), s_season_Y); //販売台数季節成分 | |
} | |
Y ~ normal(y_mean + b * alpha_X, s_Y); //販売台数の観測モデル | |
} | |
generated quantities { | |
vector[N+N_pred] alpha_X_all; | |
vector[N+N_pred] season_X_all; | |
vector[N+N_pred] alpha_Y_all; | |
vector[N+N_pred] season_Y_all; | |
vector[N+N_pred] X_all; | |
vector[N+N_pred] Y_all; | |
alpha_X_all[1:N] = alpha_X; | |
season_X_all[1:N] = season_X; | |
alpha_Y_all[1:N] = alpha_Y; | |
season_Y_all[1:N] = season_Y; | |
X_all[1:N] = x_mean; | |
Y_all[1:N] = y_mean; | |
for (t in 1:N_pred) { | |
season_X_all[N+t] = normal_rng(-sum(season_X_all[(N+t-11):(N+t-1)]), s_season_X); | |
alpha_X_all[N+t] = normal_rng(2*alpha_X_all[N+t-1] - alpha_X_all[N+t-2], s_a_X); | |
X_all[N+t] = normal_rng(alpha_X_all[N+t]+season_X_all[t], s_X); | |
} | |
for (t in 1:N_pred) { | |
season_Y_all[N+t] = normal_rng(-sum(season_Y_all[(N+t-11):(N+t-1)]), s_season_Y); | |
alpha_Y_all[N+t] = normal_rng(2*alpha_Y_all[N+t-1] - alpha_Y_all[N+t-2], s_a_Y); | |
Y_all[N+t] = normal_rng(alpha_Y_all[N+t]+season_Y_all[t]+b*X_all[N+t], s_Y); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data { | |
int N; //学習期間の長さ | |
int N_pred; //予測期間の長さ | |
vector[N] Y; //販売台数データ | |
vector[N] X; //検索量データ | |
} | |
parameters { | |
vector[N] alpha_X; //検索量の状態トレンド成分 | |
vector[N] season_X; //検索量の状態季節成分 | |
vector[N] alpha_Y; //販売台数の状態トレンド成分 | |
vector[N] season_Y; //販売台数の状態季節成分 | |
real<lower=0> s_X; //検索量の観測誤差の分散 | |
real<lower=0> s_Y; //販売台数の観測誤差の分散 | |
real<lower=0> s_a_X; //検索量の状態トレンドの分散 | |
real<lower=0> s_season_X; //検索量の季節成分の分散 | |
real<lower=0> s_a_Y; //販売台数の状態トレンドの分散 | |
real<lower=0> s_season_Y; //販売台数の季節成分の分散 | |
real b; //販売台数に対する検索量のトレンドの重み | |
} | |
transformed parameters { | |
vector[N] x_mean; | |
vector[N] y_mean; | |
x_mean = alpha_X + season_X; | |
y_mean = alpha_Y + season_Y; | |
} | |
model { | |
alpha_X[3:N] ~ normal(2*alpha_X[2:(N-1)] - alpha_X[1:(N-2)], s_a_X); //検索量トレンドの状態モデル | |
for(t in 12:N){ | |
season_X[t] ~ normal(-sum(season_X[(t-11):(t-1)]), s_season_X); //検索量季節成分 | |
} | |
X ~ normal(x_mean, s_X); //検索量の観測モデル | |
alpha_Y[3:N] ~ normal(2*alpha_Y[2:(N-1)] - alpha_Y[1:(N-2)], s_a_Y); //販売台数トレンドの状態モデル | |
for(t in 12:N){ | |
season_Y[t] ~ normal(-sum(season_Y[(t-11):(t-1)]), s_season_Y); //販売台数季節成分 | |
} | |
// 以下が違う | |
Y[1] ~ normal(y_mean[1], s_Y); | |
Y[2:N] ~ normal(y_mean[2:N] + b * alpha_X[1:(N-1)], s_Y); //販売台数の観測モデル | |
} | |
generated quantities { | |
vector[N+N_pred] alpha_X_all; | |
vector[N+N_pred] season_X_all; | |
vector[N+N_pred] alpha_Y_all; | |
vector[N+N_pred] season_Y_all; | |
vector[N+N_pred] X_all; | |
vector[N+N_pred] Y_all; | |
alpha_X_all[1:N] = alpha_X; | |
season_X_all[1:N] = season_X; | |
alpha_Y_all[1:N] = alpha_Y; | |
season_Y_all[1:N] = season_Y; | |
X_all[1:N] = x_mean; | |
Y_all[1:N] = y_mean; | |
for (t in 1:N_pred) { | |
season_X_all[N+t] = normal_rng(-sum(season_X_all[(N+t-11):(N+t-1)]), s_season_X); | |
alpha_X_all[N+t] = normal_rng(2*alpha_X_all[N+t-1] - alpha_X_all[N+t-2], s_a_X); | |
X_all[N+t] = normal_rng(alpha_X_all[N+t]+season_X_all[t], s_X); | |
} | |
for (t in 1:N_pred) { | |
season_Y_all[N+t] = normal_rng(-sum(season_Y_all[(N+t-11):(N+t-1)]), s_season_Y); | |
alpha_Y_all[N+t] = normal_rng(2*alpha_Y_all[N+t-1] - alpha_Y_all[N+t-2], s_a_Y); | |
Y_all[N+t] = normal_rng(alpha_Y_all[N+t]+season_Y_all[N+t]+b*X_all[N+t-1], s_Y); //ここも違う | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data { | |
int N; //学習期間の長さ | |
int N_pred; //予測期間の長さ | |
vector[N] Y; //販売台数データ | |
vector[N] X; //検索量データ | |
} | |
parameters { | |
vector[N] alpha_X; //検索量の状態トレンド成分 | |
vector[N] season_X; //検索量の状態季節成分 | |
vector[N] alpha_Y; //販売台数の状態トレンド成分 | |
vector[N] season_Y; //販売台数の状態季節成分 | |
real<lower=0> s_X; //検索量の観測誤差の分散 | |
real<lower=0> s_Y; //販売台数の観測誤差の分散 | |
real<lower=0> s_a_X; //検索量の状態トレンドの分散 | |
real<lower=0> s_season_X; //検索量の季節成分の分散 | |
real<lower=0> s_a_Y; //販売台数の状態トレンドの分散 | |
real<lower=0> s_season_Y; //販売台数の季節成分の分散 | |
real b; //販売台数に対する検索量のトレンドの重み | |
} | |
transformed parameters { | |
vector[N] x_mean; | |
vector[N] y_mean; | |
x_mean = alpha_X + season_X; | |
y_mean = alpha_Y + season_Y; | |
} | |
model { | |
alpha_X[3:N] ~ normal(2*alpha_X[2:(N-1)] - alpha_X[1:(N-2)], s_a_X); //検索量トレンドの状態モデル | |
for(t in 12:N){ | |
season_X[t] ~ normal(-sum(season_X[(t-11):(t-1)]), s_season_X); //検索量季節成分 | |
} | |
X ~ normal(x_mean, s_X); //検索量の観測モデル | |
alpha_Y[3:N] ~ normal(2*alpha_Y[2:(N-1)] - alpha_Y[1:(N-2)], s_a_Y); //販売台数トレンドの状態モデル | |
for(t in 12:N){ | |
season_Y[t] ~ normal(-sum(season_Y[(t-11):(t-1)]), s_season_Y); //販売台数季節成分 | |
} | |
// 以下が違う | |
Y[1] ~ normal(y_mean[1], s_Y); | |
Y[2] ~ normal(y_mean[2], s_Y); | |
Y[3:N] ~ normal(y_mean[3:N] + b * alpha_X[1:(N-2)], s_Y); //販売台数の観測モデル | |
} | |
generated quantities { | |
vector[N+N_pred] alpha_X_all; | |
vector[N+N_pred] season_X_all; | |
vector[N+N_pred] alpha_Y_all; | |
vector[N+N_pred] season_Y_all; | |
vector[N+N_pred] X_all; | |
vector[N+N_pred] Y_all; | |
alpha_X_all[1:N] = alpha_X; | |
season_X_all[1:N] = season_X; | |
alpha_Y_all[1:N] = alpha_Y; | |
season_Y_all[1:N] = season_Y; | |
X_all[1:N] = x_mean; | |
Y_all[1:N] = y_mean; | |
for (t in 1:N_pred) { | |
season_X_all[N+t] = normal_rng(-sum(season_X_all[(N+t-11):(N+t-1)]), s_season_X); | |
alpha_X_all[N+t] = normal_rng(2*alpha_X_all[N+t-1] - alpha_X_all[N+t-2], s_a_X); | |
X_all[N+t] = normal_rng(alpha_X_all[N+t]+season_X_all[t], s_X); | |
} | |
for (t in 1:N_pred) { | |
season_Y_all[N+t] = normal_rng(-sum(season_Y_all[(N+t-11):(N+t-1)]), s_season_Y); | |
alpha_Y_all[N+t] = normal_rng(2*alpha_Y_all[N+t-1] - alpha_Y_all[N+t-2], s_a_Y); | |
Y_all[t+N] = normal_rng(alpha_Y_all[N+t]+season_Y_all[t]+b*X_all[N+t-2], s_Y); //ここも違う | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment