Skip to content

Instantly share code, notes, and snippets.

@matt-graham
Created December 13, 2023 16:16
Show Gist options
  • Save matt-graham/c471d88505570c85159aeee3dbbeb8a0 to your computer and use it in GitHub Desktop.
Save matt-graham/c471d88505570c85159aeee3dbbeb8a0 to your computer and use it in GitHub Desktop.
Sketch of a block-diagonal kernel implementation in GPJax for the Kennedy & O'Hagan calibration framework
from collections.abc import Iterator
from dataclasses import dataclass

import jax.numpy as jnp
from cola import PSD
from cola.ops import BlockDiag, LinearOperator
from gpjax.base import static_field
from gpjax.kernels import AbstractKernel
from gpjax.kernels.computations import DenseKernelComputation
from gpjax.kernels.computations.base import AbstractKernelComputation, Kernel
from gpjax.typing import Array, ScalarFloat
from jaxtyping import (
    Float,
    Num,
)
def partition(array: Array, block_sizes: list[int]) -> Iterator[Array]:
    """Split ``array`` along its leading axis into consecutive blocks.

    One slice is yielded per entry of ``block_sizes``, in order. If the sizes
    sum to less than ``array.shape[0]``, one extra block holding all remaining
    entries is yielded last, so the size of the final block may be left
    implicit.

    Args:
        array: Array to split along axis 0.
        block_sizes: Number of leading-axis entries in each explicit block.

    Yields:
        Consecutive, non-overlapping slices of ``array``.

    Raises:
        ValueError: If any block size is negative, or if the block sizes sum
            to more than ``array.shape[0]``. Because this is a generator, the
            error surfaces on the first ``next()`` call, before any block is
            yielded.
    """
    sizes = list(block_sizes)
    # A negative size would become a negative slice bound (e.g. ``[0:-1]``)
    # and silently select the wrong rows, so reject it explicitly.
    if any(size < 0 for size in sizes):
        raise ValueError(f"Block sizes must be non-negative, got {sizes}.")
    total = array.shape[0]
    # Overshooting sizes would otherwise yield truncated or empty blocks that
    # no longer match the requested block structure.
    if sum(sizes) > total:
        raise ValueError(
            f"Block sizes {sizes} sum to {sum(sizes)}, which exceeds the "
            f"{total} entries along the leading axis of the array."
        )
    start = 0
    for size in sizes:
        stop = start + size
        yield array[start:stop]
        start = stop
    # Whatever the explicit sizes did not cover forms one trailing block.
    if start < total:
        yield array[start:]
class BlockDiagonalKernelComputation(AbstractKernelComputation):
    """Kernel computation that produces block-diagonal Gram matrices.

    Used as the ``compute_engine`` of ``BlockDiagonalKernel``. Inputs are split
    along their leading axis with ``partition(x, kernel.block_sizes)``, and each
    slice is handled by the matching entry of ``kernel.block_kernels``.
    """

    def gram(
        self,
        kernel: "BlockDiagonalKernel",
        x: Num[Array, "N D"],
    ) -> LinearOperator:
        """Return the block-diagonal Gram matrix of ``x``.

        Args:
            kernel: Block-diagonal kernel supplying ``block_kernels`` and
                ``block_sizes``.
            x: Inputs, stacked block by block along the leading axis.

        Returns:
            A PSD-annotated CoLA ``BlockDiag`` operator whose ``i``-th diagonal
            block is ``kernel.block_kernels[i].gram`` of the ``i``-th slice of
            ``x``.

        Raises:
            ValueError: From ``zip(strict=True)`` when the number of slices
                produced by ``partition`` differs from the number of block
                kernels.
        """
        x_blocks = partition(x, kernel.block_sizes)
        # strict=True makes a kernel/block count mismatch an error rather than
        # silently dropping the unmatched blocks.
        block_grams = (
            block_kernel.gram(x_block)
            for block_kernel, x_block in zip(
                kernel.block_kernels, x_blocks, strict=True
            )
        )
        return PSD(BlockDiag(*block_grams))

    def cross_covariance(
        self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
    ) -> Float[Array, "N M"]:
        """Return the cross-covariance between ``x`` and ``y``.

        ``y`` is assumed to be correlated only with the inputs in the first
        block of ``x``. Rows of the result belonging to the first block are
        computed with ``kernel.block_kernels[0]``. All rows belonging to later
        blocks are zero, consistent with the block-diagonal structure of
        ``gram``.

        Args:
            kernel: Block-diagonal kernel supplying ``block_kernels`` and
                ``block_sizes``.
            x: Inputs, stacked block by block along the leading axis.
            y: Inputs assumed to interact only with the first block of ``x``.

        Returns:
            Array of shape ``(N, M)``.
        """
        sizes = kernel.block_sizes
        # With no explicit sizes, partition() treats all of x as one block.
        n_first = sizes[0] if sizes else x.shape[0]
        x_first = x[:n_first]
        k_first = kernel.block_kernels[0].cross_covariance(x_first, y)
        # Rows of x outside the first block are uncorrelated with y.
        n_rest = x.shape[0] - x_first.shape[0]
        return jnp.pad(k_first, ((0, n_rest), (0, 0)))
@dataclass
class BlockDiagonalKernel(AbstractKernel):
    """Kernel whose Gram matrix is block diagonal, one block per sub-kernel.

    The default ``compute_engine`` splits inputs along their leading axis by
    ``block_sizes`` and evaluates ``block_kernels[i]`` on the ``i``-th slice.
    The size of the final block may be omitted from ``block_sizes``, in which
    case it takes all remaining rows.
    """

    # Sub-kernels, one per diagonal block, in block order.
    block_kernels: list[AbstractKernel] | None = None
    # Leading-axis size of each block; the final size may be omitted.
    block_sizes: list[int] | None = static_field(None)
    # Routes Gram / cross-covariance computation through the block-diagonal
    # implementation instead of pointwise kernel evaluation.
    compute_engine: AbstractKernelComputation = static_field(
        BlockDiagonalKernelComputation(),
    )

    def __call__(
        self,
        x: Num[Array, " D"],
        y: Num[Array, " D"],
    ) -> ScalarFloat:
        """Pointwise evaluation is not supported.

        A single pair ``(x, y)`` carries no information about which block it
        belongs to, so there is no way to choose the sub-kernel to apply.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "BlockDiagonalKernel does not support pointwise evaluation; "
            "use gram() or cross_covariance() instead."
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment