Skip to content

Instantly share code, notes, and snippets.

@twiecki
Created January 4, 2022 03:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save twiecki/86b02349c60385eb6d77793d37bd96a9 to your computer and use it in GitHub Desktop.
Save twiecki/86b02349c60385eb6d77793d37bd96a9 to your computer and use it in GitHub Desktop.
class ModelBuilder:
    """Build, sample and store a PyMC3 model for tabular field data.

    Several methods here are stubs (``load_default_config`` has no body,
    ``_clean_data`` only copies the data, ``build`` returns ``self.model``);
    presumably they are meant to be overridden by subclasses — confirm.
    """

    @classmethod
    def load_default_config(cls):
        """Return the default configuration dictionary.

        Stub: currently returns ``None``. Presumably meant to return a dict of
        prior/variable parameters, possibly including a ``"sampler"`` entry
        that :meth:`sample` forwards to ``pm.sample`` (TODO confirm).
        """
        # return dict of parameters
        ...

    def __init__(self, data: pd.DataFrame, trait: str, config: Dict):
        """
        Initialize the model builder.

        Parameters
        ----------
        data: pd.DataFrame
            Dataframe containing the raw, uncleaned field data.
        trait: str
            Identifier of the trait to model; stored as ``self.trait``.
            (Presumably the target column/variable — confirm with callers.)
        config: Dict
            Dictionary of configuration for the models' priors and variables.
            It is deep-copied, so later mutation by the caller has no effect.
            An optional ``"sampler"`` entry holds keyword arguments for
            ``pm.sample`` (see :meth:`sample`).
            See :func:`ModelBuilder.load_default_config` for an example.
        """
        self.config = copy.deepcopy(config)
        # Stored before cleaning so that `_clean_data` overrides can use it.
        self.trait = trait
        self.data, self.coords = self._clean_data(data)
        self.model_type = None  # Attribute for the type of bayesian model
        self.model = None  # Attribute for the pymc3 model
        self.idata = None  # Attribute for the az.InferenceData result
        self.run_id = -1  # Attribute for the bayesian run id

    def _clean_data(
        self,
        data: pd.DataFrame,
    ) -> Tuple[pd.DataFrame, Dict[str, Union[pd.CategoricalIndex, pd.Series]]]:
        """
        Clean the data passed to the model.

        Parameters
        ----------
        data: pd.DataFrame
            Dataframe containing the raw, uncleaned field data.

        Returns
        -------
        Tuple[pd.DataFrame, Dict[str, Union[pd.CategoricalIndex, pd.Series]]]
            The clean field data, as well as the dictionary of coordinates
            for the model.

        Notes
        -----
        Intended steps: filter out zero counts and define the model coords.
        Neither is implemented here yet. This version only returns a copy of
        ``data`` (so the caller's frame is never mutated) and empty coords.
        """
        data = data.copy()
        coords = {}
        return data, coords

    def build(self) -> pm.Model:
        """
        Build the single field model for the given data, trait and config.

        Stub: returns ``self.model`` unchanged, which is ``None`` until it is set.
        """
        return self.model

    def _sampler_kwargs(self, overrides: Dict) -> Dict:
        """Merge the configured sampler options with per-call overrides.

        ``self.config["sampler"]`` supplies the defaults; a missing or ``None``
        entry counts as empty. Keys in ``overrides`` win on conflict.
        ``return_inferencedata`` is always False because :meth:`sample`
        converts the raw trace itself with ``az.from_pymc3``. Neither
        ``self.config`` nor ``overrides`` is mutated.
        """
        sampler_config = (self.config or {}).get("sampler") or {}
        merged = {**sampler_config, **overrides}
        merged["return_inferencedata"] = False
        return merged

    def sample(self, *, model: pm.Model = None, **kwargs) -> az.InferenceData:
        """Sample the model and return the combined inference data.

        The method draws posterior samples, a posterior predictive and a prior
        predictive. It bundles them into one ``az.InferenceData``, which is
        also stored on ``self.idata``.

        Parameters
        ----------
        model : pm.Model, optional
            A model previously created using `self.build()`. If None,
            ``self.model`` is used. If that is also None, a new model is
            built via ``self.build()``.
        **kwargs : dict
            Additional arguments to `pm.sample`. They override matching keys
            of ``self.config["sampler"]``. ``return_inferencedata`` is always
            forced to False because the raw trace is converted here.

        Returns
        -------
        az.InferenceData
            Posterior, prior and posterior-predictive groups for ``model``.
        """
        if model is None:
            model = self.build() if self.model is None else self.model
        with model:
            trace = pm.sample(**self._sampler_kwargs(kwargs))
            ppc = pm.sample_posterior_predictive(trace)
            prior = pm.sample_prior_predictive()
            idata = az.from_pymc3(
                trace=trace,
                prior=prior,
                posterior_predictive=ppc,
                model=model,
            )
        self.idata = idata
        return idata
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment