Skip to content

Instantly share code, notes, and snippets.

@smzn
Created January 30, 2024 19:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smzn/5d00799d8c28f603e2c1f97bec6a0124 to your computer and use it in GitHub Desktop.
Save smzn/5d00799d8c28f603e2c1f97bec6a0124 to your computer and use it in GitHub Desktop.
MCMC
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gamma
# 観測データ
x = daily_sales.values # 観測データの例
with pm.Model() as model:
mu = pm.Gamma('mu', alpha=2, beta=1) #事前分布
obs = pm.Poisson('obs', mu=mu, observed=x) #尤度
with model:
trace = pm.sample(3000, return_inferencedata=False) #定義されたモデルを使って、観測データxからサンプリングされるmuの値を3000個算出している
idata = pm.to_inference_data(trace) #複数のchainを並行してサンプリング
az.plot_trace(idata) #複数のchainでのmuの分布
print(az.summary(idata)) #統計値の確認
az.plot_posterior(idata) #muの変動の様子
with model: #事後予測分布(事後分布を使って実際のデータの予測)
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) #このidataからサンプルされたパラメータ(この場合はmu)を使って、新しい観測データのサンプルを生成
idata_ppc = pm.to_inference_data(posterior_predictive=ppc)
# 記述統計量(平均・分散)によるチェック
ppc_samples = ppc['obs']
ppc_samples = ppc_samples.reshape(-1, 50)
ppc_mean = ppc_samples.mean(axis=1)
ppc_var = ppc_samples.var(axis=1)
print('PPC_Mean = {0}'.format(ppc_mean))
print('PPC_Var = {0}'.format(ppc_var))
# 分布によるチェック
az.plot_ppc(idata_ppc, kind='kde', num_pp_samples=50, figsize=(12, 4));
az.plot_ppc(idata_ppc, kind='kde', num_pp_samples=3000, figsize=(12, 4));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment