@jessegrabowski
Last active May 14, 2023 11:46
ARMA-GARCH in PyMC
@mathDR

mathDR commented May 12, 2023

Alright, sorry to keep pestering you @jessegrabowski but one last question (maybe):
Here is my implementation, but I don't think it is correct, for two reasons. When I try to run with the JAX backend, it complains about numpy generators. When I don't run with JAX, it uses a bunch of samplers other than NUTS for different parameters, even though all of them are continuous and should be sampleable by NUTS.

Anyway, here is the relevant part of the model (I dropped the part that computes obs_mean):

# GARCH parameters
    # Initial values for scan
    sigma_sq_init = pm.Exponential('sigma_sq_init', 1)
    epsilon_init = pm.Normal('epsilon_init')

    omega = pm.Uniform('omega',0,1)
    alpha1 = pm.Uniform('alpha1',0,1)
    beta1 = pm.Uniform('beta1',0, (1-alpha1))
    
    def step(*args):
        epsilon_tm1, sigma_sq_tm1, omega, alpha1, beta1 = args
        
        # GARCH process
        sig2 = omega + alpha1 * at.square(epsilon_tm1) + beta1 * sigma_sq_tm1
        
        return sig2, collect_default_updates(args, [sig2])

    sigmas, updates = pytensor.scan(
        fn=step,
        sequences=[{'input':at.concatenate([epsilon_init[None], (obs_data-obs_mean)]), 'taps':[-1]}],
        outputs_info=[sigma_sq_init],
        non_sequences=[omega, alpha1, beta1],
        strict=True,
        mode = get_mode(None) # use get_mode("JAX") if you try to sample with numpyro
    )
    sig = pm.Deterministic('sig',at.sqrt(sigmas))
    lik = pm.Normal(name='lik', mu=obs_mean, sigma=sig, observed=obs_data, shape=obs_data.shape)

where obs_data is the observed data and obs_mean is the calculated mean from the model (same shape as obs_data).
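
For reference, the recursion that step is meant to implement can be written out as below. This is a sketch read off the code above: the symbols y_t (for obs_data) and \mu_t (for obs_mean) are notation introduced here, with sigma_sq_init playing the role of \sigma_0^2 and epsilon_init the role of \epsilon_0.

```latex
\sigma_t^2 = \omega + \alpha_1 \, \epsilon_{t-1}^2 + \beta_1 \, \sigma_{t-1}^2,
\qquad \epsilon_t = y_t - \mu_t, \qquad t = 1, \dots, T
```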

Can you see anything amiss here?

@jessegrabowski
Author

The way you've written it, I don't think you need to collect updates at all in the scan function. This is because you're not actually making any random variables inside the scan, so you don't need to do anything special with the underlying random number generator (this is what collect_default_updates does).
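
To make the point concrete, here is a minimal plain-Python sketch of the same recursion. The function name garch_variances and the toy numbers are hypothetical, and the lag alignment (epsilon_init followed by all but the last residual) is an assumption about how the taps=[-1] sequence in the scan lines up. Nothing in the loop draws a random number, so there is no generator state to update; the sketch may also be handy for checking the scan output against known values.

```python
import math


def garch_variances(resid, omega, alpha1, beta1, sigma_sq_init, epsilon_init):
    """Deterministic GARCH(1,1) variance recursion (hypothetical helper).

    resid is the sequence obs_data - obs_mean. Each step uses the
    previous residual and the previous variance, nothing random.
    """
    # Lagged residuals: the initial value, then resid[0], ..., resid[-2]
    lagged = [epsilon_init] + list(resid[:-1])
    sig2 = []
    prev = sigma_sq_init
    for eps_tm1 in lagged:
        prev = omega + alpha1 * eps_tm1 ** 2 + beta1 * prev
        sig2.append(prev)
    return sig2


# Toy inputs, made up for illustration only
resid = [0.5, -1.0, 0.25]
sig2 = garch_variances(
    resid, omega=0.1, alpha1=0.2, beta1=0.7, sigma_sq_init=1.0, epsilon_init=0.0
)
sig = [math.sqrt(v) for v in sig2]  # analogue of the 'sig' Deterministic
```

The returned list has the same length as resid, matching the shape the likelihood expects for sigma.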
