
@jessegrabowski
Last active May 14, 2023 11:46
ARMA-GARCH in PyMC
@jessegrabowski
Author

jessegrabowski commented May 11, 2023

Ah, OK! Yeah, it definitely makes sense to have some kind of COVID adjustment. I saw a paper on LinkedIn that did this kind of adjustment in a Bayesian VAR (I dug around for the post but couldn't find it). Basically, they used an indicator variable for COVID/not COVID, then modeled the residuals of that regression with a normal VAR. So in a GARCH framework, you would be doing basically the same thing by having your exog_data be a COVID indicator?
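A minimal two-stage sketch of that idea, assuming a 0/1 COVID indicator: regress the series on an intercept plus the indicator, then hand the residuals to the GARCH (or VAR) part. In the PyMC model the indicator would more naturally go into the mean equation (e.g. as a column of exog_data) so everything is estimated jointly; the function name covid_adjusted_residuals and the toy series below are hypothetical and only for illustration.

```python
import statistics


def covid_adjusted_residuals(y, covid):
    """Residuals of y after regressing on an intercept and a 0/1 COVID indicator.

    With a single binary regressor, OLS reduces to group means:
    intercept = mean of the non-COVID periods,
    coefficient = (mean of the COVID periods) - intercept.
    """
    normal = [v for v, d in zip(y, covid) if d == 0]
    during = [v for v, d in zip(y, covid) if d == 1]
    intercept = statistics.mean(normal)
    coef = statistics.mean(during) - intercept
    # Whatever the indicator does not explain is left for the variance model
    resid = [v - (intercept + coef * d) for v, d in zip(y, covid)]
    return intercept, coef, resid


# Toy, made-up series: the level shifts up while the (hypothetical) indicator is on
y = [1.0, 1.2, 0.8, 1.0, 3.0, 2.6, 3.4, 3.0]
covid = [0, 0, 0, 0, 1, 1, 1, 1]
intercept, coef, resid = covid_adjusted_residuals(y, covid)
```

Here the indicator absorbs the level shift (coef is 2.0), so the residuals no longer carry the jump that would otherwise show up as a burst of conditional variance.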

@mathDR

mathDR commented May 11, 2023

So we had no indicator originally. COVID is a really simple example (we knew exactly what was going on), but more generally, if you see a large standard deviation in certain time regimes, that is a motivation to investigate why it is growing (i.e., why aren't the features capturing the variance?). Then you can add the COVID indicator (for example), retrain, and (hopefully) see the variance shrink over the same time period. At that point you can be confident that the uncaptured variance was in fact due to COVID (and not a bunch of other things).
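A rough sketch of that diagnostic, assuming sigma is a per-timestep conditional standard deviation (for example the posterior mean of sig): flag runs where it stays well above its typical level, then compare the same window after refitting with the indicator. The threshold rule (ratio times the median, held for at least min_len steps) and both helper names are assumptions chosen for illustration, not a standard test.

```python
import statistics


def flag_high_variance_regimes(sigma, ratio=2.0, min_len=3):
    """Return (start, end) index pairs, end exclusive, for runs of at least
    min_len consecutive steps where sigma exceeds ratio * median(sigma)."""
    threshold = ratio * statistics.median(sigma)
    regimes, start = [], None
    for t, s in enumerate(sigma):
        if s > threshold:
            if start is None:
                start = t  # a high-variance run begins
        elif start is not None:
            if t - start >= min_len:
                regimes.append((start, t))
            start = None  # the run ended; it was kept only if long enough
    # Close a run that lasts until the end of the series
    if start is not None and len(sigma) - start >= min_len:
        regimes.append((start, len(sigma)))
    return regimes


def window_mean(sigma, start, end):
    """Average sigma over [start, end), e.g. to compare fits with and without an indicator."""
    return statistics.mean(sigma[start:end])
```

If the indicator really explains the regime, window_mean over a flagged window should drop noticeably in the refit model; if it barely moves, something else is driving the variance.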

@mathDR

mathDR commented May 12, 2023

Alright, sorry to keep pestering you @jessegrabowski, but one last question (maybe):
Here is my implementation, but I don't think things are correct, for two reasons. First, when I try to run with the JAX backend, it complains about NumPy generators. Second, if I don't use JAX, PyMC assigns a bunch of samplers other than NUTS to different parameters (which should all be sampleable by NUTS, since they are continuous).

Anyway, here is the relevant part of the model (I dropped the part that computes obs_mean):

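```python
import pymc as pm
import pytensor
import pytensor.tensor as at
from pymc.pytensorf import collect_default_updates
from pytensor.compile.mode import get_mode

with pm.Model() as model:
    # ... mean model that computes obs_mean (omitted) ...

    # GARCH parameters
    # Initial values for scan
    sigma_sq_init = pm.Exponential('sigma_sq_init', 1)
    epsilon_init = pm.Normal('epsilon_init')

    omega = pm.Uniform('omega', 0, 1)
    alpha1 = pm.Uniform('alpha1', 0, 1)
    beta1 = pm.Uniform('beta1', 0, (1 - alpha1))  # keeps alpha1 + beta1 < 1

    def step(*args):
        # Order: sequence tap, previous output, then non_sequences
        epsilon_tm1, sigma_sq_tm1, omega, alpha1, beta1 = args

        # GARCH(1,1) process
        sig2 = omega + alpha1 * at.square(epsilon_tm1) + beta1 * sigma_sq_tm1

        return sig2, collect_default_updates(args, [sig2])

    # taps=[-1]: step t receives the residual from t-1,
    # with epsilon_init standing in for the residual before the first observation
    sigmas, updates = pytensor.scan(
        fn=step,
        sequences=[{'input': at.concatenate([epsilon_init[None], (obs_data - obs_mean)]), 'taps': [-1]}],
        outputs_info=[sigma_sq_init],
        non_sequences=[omega, alpha1, beta1],
        strict=True,
        mode=get_mode(None)  # use get_mode("JAX") if you try to sample with numpyro
    )
    sig = pm.Deterministic('sig', at.sqrt(sigmas))
    lik = pm.Normal(name='lik', mu=obs_mean, sigma=sig, observed=obs_data, shape=obs_data.shape)
```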

where obs_data is the observed data and obs_mean is the calculated mean from the model (same shape as obs_data).
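For reference, the step function is meant to implement the standard GARCH(1,1) variance recursion. Writing it out makes the role of each scan input explicit (the time indexing is my own notation, with $y_t$ = obs_data and $\mu_t$ = obs_mean):

$$
\begin{aligned}
\varepsilon_t &= y_t - \mu_t, \\
\sigma_t^2 &= \omega + \alpha_1 \varepsilon_{t-1}^2 + \beta_1 \sigma_{t-1}^2, \qquad t = 1, \dots, T, \\
y_t &\sim \mathcal{N}(\mu_t, \sigma_t^2),
\end{aligned}
$$

with $\varepsilon_0$ = epsilon_init and $\sigma_0^2$ = sigma_sq_init. Bounding $\beta_1$ above by $1 - \alpha_1$ (as the beta1 prior is meant to do) keeps $\alpha_1 + \beta_1 < 1$, the usual covariance-stationarity condition for GARCH(1,1).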

Can you see anything amiss here?

@jessegrabowski
Author

The way you've written it, I don't think you need to collect updates in the scan function at all. You're not actually creating any random variables inside the scan, so you don't need to do anything special with the underlying random number generator (which is what collect_default_updates does).
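One way to see this: the step function only does arithmetic on its inputs, so the whole variance path is a deterministic function of omega, alpha1, beta1, epsilon_init, sigma_sq_init and the residuals. A plain-Python mirror of the same recursion makes that explicit. It is a hypothetical helper (not part of PyMC or PyTensor) that could be used to spot-check the squared sig values at fixed parameter values.

```python
def garch11_variances(resid, omega, alpha1, beta1, sigma_sq_init, epsilon_init):
    """Plain-Python mirror of the scan:
    sigma2[t] = omega + alpha1 * eps[t-1]**2 + beta1 * sigma2[t-1].

    Step t sees the residual from t-1 (epsilon_init for the first step),
    matching the taps=[-1] sequence built from
    concatenate([epsilon_init[None], resid]). Nothing here draws random
    numbers, so there is no generator state to carry or update.
    """
    sigma_sq = []
    eps_prev, s_prev = epsilon_init, sigma_sq_init
    for eps in resid:
        s = omega + alpha1 * eps_prev**2 + beta1 * s_prev
        sigma_sq.append(s)
        eps_prev, s_prev = eps, s  # carry this step's residual and variance forward
    return sigma_sq


# Example with made-up parameter values
path = garch11_variances([0.0, 1.0, 0.0], omega=0.1, alpha1=0.2, beta1=0.5,
                         sigma_sq_init=1.0, epsilon_init=0.0)
```

Note that, as in the scan, the last residual never enters a variance: it would only matter for the one-step-ahead forecast.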
