Skip to content

Instantly share code, notes, and snippets.

@alphaville
Last active March 30, 2023 21:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alphaville/4b6b6cb931e591c7c7b6fb64f15fef52 to your computer and use it in GitHub Desktop.
Save alphaville/4b6b6cb931e591c7c7b6fb64f15fef52 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 20:41:07 2023
@author: Pantelis Sopasakis
"""
import numpy as np
import scipy.special as sp
import scipy.stats as ss
import pymc3 as pm
import arviz as az
import matplotlib.pyplot as plt
# %% Generate data from dynamical system
# State update function
# x(t+1) = f(x(t); a),
# where a is an unknown parameter to be estimated
#
# Default parameter value: a = 1 (the data below are generated with a = 0.7)
def f(x, a=1):
    """State-update map of the system: return ``a * sin(x)``.

    ``x`` is the current state and ``a`` is the multiplicative parameter
    (defaults to 1). The sine is evaluated with ``np.sin``.
    """
    sine_of_state = np.sin(x)
    return a * sine_of_state
# Output function, y = g(x)
def g(x):
    """Output map of the system: return the state ``x`` shifted up by one."""
    shifted = x + 1
    return shifted
# Number of output samples to simulate
N = 8
# Simulate the system from x = 1 and record noisy outputs in an (N, 1) array
y = np.zeros((N, 1))
x = 1
for i in range(N):
    # Noisy measurement of the current state (noise std 0.01)
    measurement_noise = np.random.normal(0, 0.01)
    y[i] = g(x) + measurement_noise
    # Advance the state with a = 0.7 and a small disturbance (std 0.0001)
    process_noise = np.random.normal(0, 0.0001)
    x = f(x, a=0.7) + process_noise
# %% MCMC
with pm.Model() as toy_model:
    # Prior on the initial state: Normal(mu=0.8, sigma=2.0), bounded below by 1e-6
    BoundedNormalDist = pm.Bound(pm.Normal, lower=1e-6)
    x0 = BoundedNormalDist('x0', mu=0.8, sigma=2.0)
    # Prior on the unknown parameter (registered as 'alpha'): Beta, bounded below by 1e-6
    # NOTE(review): the mean of a Beta distribution lies in (0, 1), so mu=8
    # looks out of range -- confirm the intended prior.
    BoundedBeta = pm.Bound(pm.Beta, lower=1e-6)
    a = BoundedBeta('alpha', mu=8, sigma=4)
    # Unroll the dynamics for N steps, starting from the random initial state
    w = []
    x_rand = x0
    for i in range(N):
        # Process noise for this step: w_i ~ N(0, 0.001^2)
        w.append(pm.Normal(f'w{i}', mu=0, sigma=0.001))
        # Noise-free output of the current state
        y_rand = pm.Deterministic(f'y{i}', g(x_rand))
        # Likelihood of the i-th recorded output (measurement noise std 0.01)
        pm.Normal(f'y_meas{i}', mu=y_rand, sigma=0.01, observed=[y[i]])
        # State update: x+ = f(x; a) + w_i
        x_rand = pm.Deterministic(f'x{i+1}', f(x_rand, a) + w[i])
# Create graph of the model (optional step)
#gv = pm.model_to_graphviz(toy_model)
#gv.format = 'png'
#gv.render(filename='sin_system')
# Run MCMC
with toy_model:
    # Metropolis sampling: 2 chains on a single core, each with 1200 tuning
    # steps followed by 4000 draws; the result is an ArviZ InferenceData.
    metropolis_step = pm.Metropolis()
    trace = pm.sample(
        draws=4000,
        tune=1200,
        chains=2,
        cores=1,
        step=metropolis_step,
        return_inferencedata=True,
    )
# Plot results
# az.plot_trace(data=trace)
# plt.show()
# Trace plot of the posterior samples of 'alpha'
alpha_posterior = trace.posterior["alpha"]
az.plot_trace(alpha_posterior)
plt.show()
# %%
# Summary table for all variables in the trace
summary_table = az.summary(trace)
print(summary_table)
# %%
# Draw the 'alpha' trace plot again (same plot as in the first cell)
alpha_posterior = trace.posterior["alpha"]
az.plot_trace(alpha_posterior)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment