Last active
July 8, 2022 06:11
-
-
Save 5hv5hvnk/edec05d59f1d912d4c23df90ea51193f to your computer and use it in GitHub Desktop.
Example for New class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import sys
import time
import trace
import warnings
from collections import defaultdict
from copy import copy
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)

import arviz as az
import cloudpickle
import numpy as np
import pandas as pd
import pymc as pm
from pymc.initial_point import (
    PointType,
    StartDict,
    filter_rvs_to_jitter,
    make_initial_point_fns_per_chain,
)
from scipy.special import expit
# A seed specification: one integer, a sequence of integers, a NumPy array,
# or None (Optional adds NoneType to the union).
RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
# Any accepted source of randomness: a seed specification as above, or an
# already-constructed NumPy RandomState / Generator instance.
RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator]
class base:
    """Thin wrapper that owns a PyMC model and the result of sampling it.

    Instance attributes:
        model: the model object; ``None`` until ``build_model()`` or
            ``load()`` assigns one.
        trace: the value returned by the most recent ``fit()`` call, or
            ``None`` if ``fit()`` has not run.
        build: initialised to ``False``; nothing in this class updates it.
        saved: ``False`` until ``save()`` completes successfully.
    """

    def __init__(self):
        self.model = None
        self.trace = None
        self.build = False
        # Initialised here so the attribute exists before the first save().
        self.saved = False

    def build_model(self):
        """Replace ``self.model`` with a fresh ``pm.Model()``."""
        self.model = pm.Model()

    def save(self, file_prefix, filepath, save_format=None):
        """Serialise ``self.model`` with cloudpickle and write it to disk.

        The destination is the plain concatenation
        ``filepath + file_prefix + extension``; no path separator is inserted,
        so ``filepath`` must already end with one if it names a directory.

        Args:
            file_prefix: stem of the file name (converted with ``str``).
            filepath: leading part of the destination path.
            save_format: ``'h5'`` selects the ``.hdf5`` extension; any other
                value selects ``.pickle``.  Only the extension changes - the
                bytes written are always the cloudpickle payload.
        """
        extension = ".hdf5" if save_format == "h5" else ".pickle"
        target = Path(f"{filepath}{file_prefix}{extension}")
        # Serialise before opening so a pickling failure does not leave an
        # empty file behind.
        payload = cloudpickle.dumps(self.model)
        with open(target, "wb") as handle:
            handle.write(payload)
        self.saved = True
        print("Model Saved")

    def load(self, filepath):
        """Read a cloudpickle payload from ``filepath`` into ``self.model``.

        Args:
            filepath: path of a file whose contents ``cloudpickle.loads`` can
                deserialise (for example one written by ``save()``).
        """
        # SECURITY: unpickling can execute arbitrary code.  Only load files
        # that come from a trusted source.
        with open(Path(filepath), "rb") as handle:
            self.model = cloudpickle.loads(handle.read())

    def fit(
        self,
        draws: int = 1000,
        step=None,
        init: str = "auto",
        n_init: int = 200_000,
        initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
        trace=None,
        chain_idx: int = 0,
        chains: Optional[int] = None,
        cores: Optional[int] = None,
        tune: int = 1000,
        progressbar: bool = True,
        model=None,
        random_seed: RandomState = None,
        discard_tuned_samples: bool = True,
        compute_convergence_checks: bool = True,
        callback=None,
        jitter_max_retries: int = 10,
        return_inferencedata: bool = True,
        idata_kwargs: dict = None,
        mp_ctx=None,
        **kwargs,
    ):
        """Run ``pm.sample`` inside ``self.model`` and remember the result.

        Every parameter is forwarded to ``pm.sample`` under the same name;
        see the PyMC documentation for their meaning.  Extra keyword
        arguments in ``kwargs`` are forwarded as well.

        Returns:
            Whatever ``pm.sample`` returns; the same object is stored in
            ``self.trace``.
        """
        with self.model:
            # Forward by keyword so each value reaches the parameter of the
            # same name regardless of pm.sample's positional ordering.
            result = pm.sample(
                draws=draws,
                step=step,
                init=init,
                n_init=n_init,
                initvals=initvals,
                trace=trace,
                chain_idx=chain_idx,
                chains=chains,
                cores=cores,
                tune=tune,
                progressbar=progressbar,
                model=model,
                random_seed=random_seed,
                discard_tuned_samples=discard_tuned_samples,
                compute_convergence_checks=compute_convergence_checks,
                callback=callback,
                jitter_max_retries=jitter_max_retries,
                return_inferencedata=return_inferencedata,
                idata_kwargs=idata_kwargs,
                mp_ctx=mp_ctx,
                **kwargs,
            )
        self.trace = result
        return result

    def predict(
        self,
        X,
        samples: int = 500,
        var_names: Optional[Iterable[str]] = None,
        random_seed=None,
        return_inferencedata: bool = True,
        idata_kwargs: dict = None,
        compile_kwargs: dict = None,
    ):
        """Update the model data with ``X`` and draw posterior predictive samples.

        Args:
            X: passed unchanged to ``pm.set_data``.  Presumably a mapping from
                data-variable names to new values - confirm against the PyMC
                documentation.
            samples: accepted for interface compatibility; not used here.
            var_names, random_seed, return_inferencedata, idata_kwargs,
            compile_kwargs: forwarded by keyword to
                ``pm.sample_posterior_predictive``.

        Returns:
            The value returned by ``pm.sample_posterior_predictive``.

        Raises:
            RuntimeError: if no trace is stored on the instance.
        """
        # getattr: subclasses that skip base.__init__ may lack the attribute.
        fitted = getattr(self, "trace", None)
        if fitted is None:
            raise RuntimeError(
                "predict() needs a trace from fit(); call fit() first"
            )
        with self.model:
            pm.set_data(X)
            return pm.sample_posterior_predictive(
                fitted,
                var_names=var_names,
                random_seed=random_seed,
                return_inferencedata=return_inferencedata,
                idata_kwargs=idata_kwargs,
                compile_kwargs=compile_kwargs,
            )
class BRM_MyModel(base):
    """Binomial regression of ``y`` on ``x`` with an inverse-logit link.

    The model has two Normal(0, 1) coefficients, ``beta0`` and ``beta1``; the
    success probability is ``invlogit(beta0 + beta1 * x)`` and ``y`` is
    Binomial with ``n`` trials.
    """

    def __init__(self, n=0, data=None):
        """Build the model from ``data`` or from a simulated data set.

        Args:
            n: number of Binomial trials.  Replaced by 20 when the simulated
                data set is used.
            data: a ``pandas.DataFrame`` with columns ``"x"`` and ``"y"``.
                Any value that is not a DataFrame (including the default
                ``None``) triggers the simulated data set.
        """
        # Run the parent initialiser so model/trace/build exist on the
        # instance before anything else touches them.
        super().__init__()
        super().build_model()

        # NOTE(review): when a DataFrame is supplied and `n` keeps its default
        # of 0, the likelihood below is built with n=0 - confirm callers
        # always pass `n` together with `data`.
        if not isinstance(data, pd.DataFrame):
            n, data = self._simulated_dataset()

        with self.model:
            # NOTE(review): `x` is registered with pm.ConstantData; confirm
            # that pm.set_data (used by base.predict) is able to update it.
            x = pm.ConstantData("x", data["x"], dims="observation")
            beta0 = pm.Normal("beta0", mu=0, sigma=1)
            beta1 = pm.Normal("beta1", mu=0, sigma=1)
            mu = beta0 + beta1 * x
            p = pm.Deterministic("p", pm.math.invlogit(mu), dims="observation")
            pm.Binomial("y", n=n, p=p, observed=data["y"], dims="observation")

    @staticmethod
    def _simulated_dataset():
        """Return ``(trials, frame)`` for the built-in example data.

        ``frame`` has 30 rows: ``x`` evenly spaced on [-10, 20] and ``y``
        drawn from Binomial(20, expit(0.7 + 0.4 * x)) with a generator seeded
        at 1234, so the data are identical on every call.
        """
        rng = np.random.default_rng(1234)
        beta0_true = 0.7
        beta1_true = 0.4
        trials = 20
        sample_size = 30
        x = np.linspace(-10, 20, sample_size)
        p_true = expit(beta0_true + beta1_true * x)
        y = rng.binomial(trials, p_true)
        return trials, pd.DataFrame({"x": x, "y": y})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment