Created
January 30, 2024 19:06
MCMC
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymc as pm | |
import arviz as az | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.stats import gamma | |
# 観測データ | |
x = daily_sales.values # 観測データの例 | |
with pm.Model() as model: | |
mu = pm.Gamma('mu', alpha=2, beta=1) #事前分布 | |
obs = pm.Poisson('obs', mu=mu, observed=x) #尤度 | |
with model: | |
trace = pm.sample(3000, return_inferencedata=False) #定義されたモデルを使って、観測データxからサンプリングされるmuの値を3000個算出している | |
idata = pm.to_inference_data(trace) #複数のchainを並行してサンプリング | |
az.plot_trace(idata) #複数のchainでのmuの分布 | |
print(az.summary(idata)) #統計値の確認 | |
az.plot_posterior(idata) #muの変動の様子 | |
with model: #事後予測分布(事後分布を使って実際のデータの予測) | |
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) #このidataからサンプルされたパラメータ(この場合はmu)を使って、新しい観測データのサンプルを生成 | |
idata_ppc = pm.to_inference_data(posterior_predictive=ppc) | |
# 記述統計量(平均・分散)によるチェック | |
ppc_samples = ppc['obs'] | |
ppc_samples = ppc_samples.reshape(-1, 50) | |
ppc_mean = ppc_samples.mean(axis=1) | |
ppc_var = ppc_samples.var(axis=1) | |
print('PPC_Mean = {0}'.format(ppc_mean)) | |
print('PPC_Var = {0}'.format(ppc_var)) | |
# 分布によるチェック | |
az.plot_ppc(idata_ppc, kind='kde', num_pp_samples=50, figsize=(12, 4)); | |
az.plot_ppc(idata_ppc, kind='kde', num_pp_samples=3000, figsize=(12, 4)); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment