Skip to content

Instantly share code, notes, and snippets.

@jakeyeung
Created August 9, 2023 15:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jakeyeung/d0e0daad0ea10ebf674976df329085f4 to your computer and use it in GitHub Desktop.
Save jakeyeung/d0e0daad0ea10ebf674976df329085f4 to your computer and use it in GitHub Desktop.
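Check the `obs_noise` estimate from a GP fit using a constant kernel (note: in GPJax 0.6 the Gaussian likelihood's `obs_noise` is a variance, not a standard deviation, so compare it with `var1` rather than `sqrt(var1)`).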
# import blackjax # terrible to install
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.config import config

from jaxtyping import install_import_hook
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
import gpjax as gpx

import numpy as np

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
# Force CPU to save (GPU) memory. This has to happen before JAX initialises its
# backends, i.e. before the first array is created -- jr.PRNGKey below already
# creates one -- otherwise the default backend has been chosen and the setting
# comes too late.
jax.config.update('jax_platform_name', 'cpu')

tfd = tfp.distributions
key = jr.PRNGKey(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]


import datetime
jdate = datetime.datetime.now().date()

print("Checking GPU or CPU:")
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# Bare expressions are only echoed when they are the last line of a notebook
# cell (and never in a script), so print them explicitly.
print(jax.default_backend())
print(jax.devices())
print("Checking done")




/nfs/scistore12/hpcgrp/jyeung/miniconda3/envs/gpjax-0.6.7/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm


Checking GPU or CPU:
cpu
Checking done
# Simulate N i.i.d. observations y ~ Normal(mean1, var1) on an evenly spaced
# grid over [0, tmax). y does not depend on x, so a constant mean plus Gaussian
# noise is the true generating model.
mean1 = 6
var1 = 9
obs_noise1 = jnp.sqrt(var1)  # noise *standard deviation* (sqrt of var1)
N = 500

tmax = 1
xvec = jnp.linspace(0, tmax, num=N, endpoint=False)[:, None]
yvec1 = mean1 + obs_noise1 * jr.normal(jr.PRNGKey(1), shape=(N, 1))

# Histogram of the simulated observations (flattened to 1-D).
fig, axs = plt.subplots(nrows=1)
axs.hist(yvec1.ravel(), bins=50, alpha=0.5, label="yvec1")
(array([ 2.,  0.,  1.,  0.,  7.,  7.,  4.,  1.,  8.,  4., 12.,  8., 11.,
        10., 12., 18., 12., 15., 23., 21., 32., 18., 24., 12., 22., 27.,
        22., 20., 17., 19., 17., 14., 15., 14., 16.,  8.,  7.,  5.,  1.,
         5.,  1.,  4.,  2.,  0.,  1.,  0.,  0.,  0.,  0.,  1.]),
 array([-2.590723  , -2.22452536, -1.85832773, -1.4921301 , -1.12593247,
        -0.75973484, -0.39353721, -0.02733957,  0.33885806,  0.70505569,
         1.07125332,  1.43745095,  1.80364858,  2.16984622,  2.53604385,
         2.90224148,  3.26843911,  3.63463674,  4.00083437,  4.36703201,
         4.73322964,  5.09942727,  5.4656249 ,  5.83182253,  6.19802016,
         6.5642178 ,  6.93041543,  7.29661306,  7.66281069,  8.02900832,
         8.39520595,  8.76140359,  9.12760122,  9.49379885,  9.85999648,
        10.22619411, 10.59239174, 10.95858938, 11.32478701, 11.69098464,
        12.05718227, 12.4233799 , 12.78957753, 13.15577517, 13.5219728 ,
        13.88817043, 14.25436806, 14.62056569, 14.98676332, 15.35296096,
        15.71915859]),
 <BarContainer object of 50 artists>)

png

# Unbiased (ddof=1) sample variance of the simulated data, to compare with var1.
var1check = yvec1.var(ddof=1)
print(var1check)
9.792602878664741
# fit constant kernel, check obs_noise
from dataclasses import dataclass
from gpjax.base.param import param_field
import tensorflow_probability.substrates.jax.bijectors as tfb
from gpjax.typing import (
    Array,
    ScalarFloat,
)
from jaxtyping import Float, Integer

@dataclass
class ConstantKernelToy(gpx.kernels.AbstractKernel):
    r"""Toy constant (bias) kernel: ``k(x, y) = sigma2`` for every input pair.

    Adapted from GPflow's Constant kernel
    (https://github.com/GPflow/GPflow/blob/develop/gpflow/kernels/statics.py).
    Functions drawn from a GP with this kernel are constant, ``f(x) = c`` with
    ``c ~ N(0, sigma2)``.

    Attributes:
        sigma2: variance of the constant offset; trainable and constrained to
            be positive through a Softplus bijector.
        name: human-readable kernel name.
    """

    sigma2: ScalarFloat = param_field(
        jnp.array(0.01), bijector=tfb.Softplus(), trainable=True
    )
    name: str = "Constant"

    def __call__(self, x1: Float[Array, " D"], x2: Float[Array, " D"]) -> ScalarFloat:
        r"""Return the kernel value ``sigma2`` as a scalar.

        ``x1`` and ``x2`` are accepted to satisfy the kernel interface but are
        not used: the covariance is the same for every pair of inputs.
        """
        variance = self.sigma2
        return variance.squeeze()
        
Data = gpx.Dataset(X=xvec, y=yvec1)


# mean1prior = jnp.float64(4)
meanf = gpx.mean_functions.Constant()

# Initial value of the constant kernel's variance (optimised later by gpx.fit).
sigma2prior = jnp.float64(0.0000001)
# kernelwhite = gpx.kernels.White()
kernelconstant = ConstantKernelToy(sigma2 = sigma2prior)

# draw samples
prior = gpx.Prior(mean_function=meanf, kernel=kernelconstant)

xtest = jnp.float64(Data.X[:, 0]).reshape(-1, 1)
prior_dist = prior.predict(xtest)
prior_mean = prior_dist.mean()
# Band half-width must be a standard deviation: mean +/- variance mixes units
# (variance is in y^2) and disagrees with the +/-2 sigma band used for the
# predictive plot further down.
prior_std = prior_dist.stddev()
samples = prior_dist.sample(seed=key, sample_shape=(20,))
fig, ax = plt.subplots()
ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], marker='o')
ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
ax.fill_between(
    xtest.flatten(),
    prior_mean - prior_std,
    prior_mean + prior_std,
    alpha=0.3,
    color=cols[1],
    label="One sigma",
)
ax.legend(loc="best")
<matplotlib.legend.Legend at 0x14e16c91b550>

png

import optax as ox

# Conjugate posterior: constant-kernel prior combined with a Gaussian likelihood.
likelihood = gpx.Gaussian(num_datapoints=Data.n)
posterior = prior * likelihood

# Objective to minimise: jitted negative conjugate marginal log-likelihood.
negative_mll = jax.jit(gpx.objectives.ConjugateMLL(negative=True))
# Evaluate once at the initial parameters (sanity check; also compiles the
# jitted function). The returned value is not kept.
negative_mll(posterior, train_data=Data)

# Optimise the trainable parameters with Adam.
opt_posterior, history = gpx.fit(
    model=posterior,
    train_data=Data,
    objective=negative_mll,
    optim=ox.adam(learning_rate=0.01),
    num_iters=5000,
    key=key,
    safe=True,
)
Running: 100%|██████████| 5000/5000 [00:46<00:00, 106.86it/s, Value=1295.02]
# Training curve: objective value per optimisation step.
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set_xlabel("Training iteration")
ax.set_ylabel("Negative marginal log likelihood")
[Text(0.5, 0, 'Training iteration'),
 Text(0, 0.5, 'Negative marginal log likelihood')]

png

# Posterior over the latent function, then push it through the likelihood to
# get the predictive distribution of noisy observations.
latent_dist = opt_posterior.predict(xtest, train_data=Data)
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(7.5, 2.5))
ax.plot(xvec, yvec1, "x", label="Observations", color=cols[0], alpha=0.5)
# Shaded +/-2 sigma predictive band ...
ax.fill_between(
    xtest.squeeze(),
    predictive_mean - 2 * predictive_std,
    predictive_mean + 2 * predictive_std,
    alpha=0.2,
    label="Two sigma",
    color=cols[1],
)
# ... outlined by a dashed line on each edge (lower first, then upper).
for band_edge in (
    predictive_mean - 2 * predictive_std,
    predictive_mean + 2 * predictive_std,
):
    ax.plot(xtest, band_edge, linestyle="--", linewidth=1, color=cols[1])

ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))
<matplotlib.legend.Legend at 0x14e13a170b90>

png

# obs_noise1 is the true noise *standard deviation* (sqrt(var1)). In GPJax 0.6
# the Gaussian likelihood's `obs_noise` is added to the diagonal of the kernel
# Gram matrix, i.e. it is a *variance* (renamed to obs_stddev semantics in
# later GPJax releases -- confirm against the installed version). Report both
# on the same scales so the comparison is like-for-like.
est_obs_var = opt_posterior.likelihood.obs_noise
print('Real obs noise variance: %s (std: %s)' % (var1, obs_noise1))
print('Estimated obs noise variance: %s (std: %s)' % (est_obs_var, jnp.sqrt(est_obs_var)))
Real obs noise: 3.0
Estimated obs noise: 7.2693954
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment