Skip to content

Instantly share code, notes, and snippets.

@smzn
Created January 30, 2024 19:06
MCMC
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gamma
# 観測データ
x = daily_sales.values # 観測データの例
with pm.Model() as model:
mu = pm.Gamma('mu', alpha=2, beta=1) #事前分布
obs = pm.Poisson('obs', mu=mu, observed=x) #尤度
with model:
trace = pm.sample(3000, return_inferencedata=False) #定義されたモデルを使って、観測データxからサンプリングされるmuの値を3000個算出している
idata = pm.to_inference_data(trace) #複数のchainを並行してサンプリング
az.plot_trace(idata) #複数のchainでのmuの分布
print(az.summary(idata)) #統計値の確認
az.plot_posterior(idata) #muの変動の様子
with model: #事後予測分布(事後分布を使って実際のデータの予測)
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) #このidataからサンプルされたパラメータ(この場合はmu)を使って、新しい観測データのサンプルを生成
idata_ppc = pm.to_inference_data(posterior_predictive=ppc)
# 記述統計量(平均・分散)によるチェック
ppc_samples = ppc['obs']
ppc_samples = ppc_samples.reshape(-1, 50)
ppc_mean = ppc_samples.mean(axis=1)
ppc_var = ppc_samples.var(axis=1)
print('PPC_Mean = {0}'.format(ppc_mean))
print('PPC_Var = {0}'.format(ppc_var))
# 分布によるチェック
az.plot_ppc(idata_ppc, kind='kde', num_pp_samples=50, figsize=(12, 4));
az.plot_ppc(idata_ppc, kind='kde', num_pp_samples=3000, figsize=(12, 4));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment