Skip to content

Instantly share code, notes, and snippets.

@lwiklendt
Last active November 11, 2020 18:58
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lwiklendt/9c7099288f85b59edc903a5aed2d2d64 to your computer and use it in GitHub Desktop.
Save lwiklendt/9c7099288f85b59edc903a5aed2d2d64 to your computer and use it in GitHub Desktop.
Extract parameter samples from PyStan's vb (variational Bayes) method so that the result resembles the output of extract() from the sampling method.
import numpy as np
from collections import OrderedDict
def _parse_param_spec(param_spec):
    """Split a flattened parameter name into its base name and index tuple.

    ``'theta'`` -> ``('theta', ())`` and ``'M[2,3]'`` -> ``('M', (2, 3))``.
    Indices are returned exactly as written in the name, i.e. still 1-based.
    """
    name, bracket, rest = param_spec.partition('[')
    if not bracket:
        return name, ()
    # rest looks like "2,3]"; strip the closing bracket before splitting.
    return name, tuple(int(i) for i in rest.rstrip(']').split(','))


def pystan_vb_extract(results):
    """Reshape PyStan ``vb`` output into per-parameter sample arrays.

    The ``vb`` result stores one flat list of draws per scalar element, named
    like ``'theta[2,3]'``. This function regroups those lists into one array
    per parameter, so the output resembles ``extract()`` from sampling.

    Args:
        results: mapping with ``'sampler_param_names'`` (flattened names with
            1-based indices) and ``'sampler_params'`` (one sequence of draws
            per name, in the same order).

    Returns:
        OrderedDict mapping each base parameter name, in first-seen order, to
        a float array of shape ``(n_draws, *dims)``. Each dimension's size is
        the largest index seen for it. Elements that never appear in
        ``sampler_param_names`` are NaN. Returns an empty OrderedDict when
        there are no parameters.

    Raises:
        ValueError: if the two lists in ``results`` have different lengths.
    """
    param_specs = results['sampler_param_names']
    samples = results['sampler_params']
    if len(param_specs) != len(samples):
        raise ValueError(
            'sampler_param_names has %d entries but sampler_params has %d'
            % (len(param_specs), len(samples)))
    if not param_specs:
        return OrderedDict()
    n = len(samples[0])

    # Parse each flattened name once; the result is reused for sizing and filling.
    parsed = [_parse_param_spec(spec) for spec in param_specs]

    # Indices are 1-based, so the largest index seen along an axis is its size.
    param_shapes = OrderedDict()
    for name, idxs in parsed:
        widest = np.maximum(idxs, param_shapes.get(name, idxs))
        param_shapes[name] = tuple(int(d) for d in widest)

    # Pre-fill with NaN so elements that were never reported stay visibly missing.
    params = OrderedDict(
        (name, np.full((n,) + shape, np.nan))
        for name, shape in param_shapes.items())

    for (name, idxs), param_samples in zip(parsed, samples):
        # Shift to 0-based indices. The leading Ellipsis keeps the whole draw
        # axis, so each name's draws fill one column of its array.
        params[name][(...,) + tuple(i - 1 for i in idxs)] = param_samples

    return params
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment