# Sketch of a block diagonal kernel implementation in GPJax for the
# Kennedy & O'Hagan calibration framework.
# Source: GitHub gist matt-graham/c471d88505570c85159aeee3dbbeb8a0
# (created December 13, 2023 16:16).
from collections.abc import Iterator, Sequence
from dataclasses import dataclass

import jax.numpy as jnp
from cola import PSD
from cola.ops import BlockDiag, LinearOperator
from gpjax.base import static_field
from gpjax.kernels import AbstractKernel
from gpjax.kernels.computations import DenseKernelComputation
from gpjax.kernels.computations.base import AbstractKernelComputation, Kernel
from gpjax.typing import Array, ScalarFloat
from jaxtyping import (
    Float,
    Num,
)
def partition(array: Array, block_sizes: Sequence[int]) -> Iterator[Array]:
    """Split ``array`` along its first axis into consecutive row blocks.

    Blocks with the requested sizes are yielded in order. If the sizes sum to
    fewer than ``array.shape[0]`` rows, the leftover trailing rows are yielded
    as one final extra block. The size of the last block can therefore be
    omitted from ``block_sizes``.

    Args:
        array: Array to split along axis 0.
        block_sizes: Non-negative number of rows in each leading block.

    Yields:
        ``array[start:start + size]`` for each size in turn, followed by
        ``array[sum(block_sizes):]`` if any rows remain.

    Raises:
        ValueError: (on first iteration, since this is a generator) if any
            block size is negative, or if the sizes sum to more than
            ``array.shape[0]``. Otherwise slicing would silently yield
            truncated or empty blocks.
    """
    num_rows = array.shape[0]
    if any(size < 0 for size in block_sizes):
        raise ValueError(
            f"Block sizes must be non-negative, got {list(block_sizes)}."
        )
    total = sum(block_sizes)
    if total > num_rows:
        raise ValueError(
            f"Block sizes sum to {total} but array has only {num_rows} rows."
        )
    start = 0
    for size in block_sizes:
        yield array[start : start + size]
        start += size
    # Any rows not covered by block_sizes form one final implicit block.
    if start < num_rows:
        yield array[start:]
class BlockDiagonalKernelComputation(AbstractKernelComputation):
    """Kernel computation for ``BlockDiagonalKernel``.

    The rows of the input are split into consecutive blocks with
    ``partition(x, kernel.block_sizes)``. Block ``i`` is handled by
    ``kernel.block_kernels[i]``, and covariances between different blocks
    are zero.
    """

    def gram(
        self,
        kernel: "BlockDiagonalKernel",
        x: Num[Array, "N D"],
    ) -> LinearOperator:
        """Return the block diagonal Gram matrix of ``kernel`` at inputs ``x``.

        Args:
            kernel: Block diagonal kernel supplying ``block_kernels`` and
                ``block_sizes``.
            x: Inputs whose rows are split into consecutive blocks.

        Returns:
            A ``PSD``-annotated ``BlockDiag`` operator. Its diagonal blocks
            are each block kernel's Gram matrix on its own row block.

        Raises:
            ValueError: If the number of row blocks produced by ``partition``
                differs from the number of block kernels (``zip`` is strict).
        """
        x_blocks = partition(x, kernel.block_sizes)
        diagonal_blocks = [
            block_kernel.gram(x_block)
            for block_kernel, x_block in zip(
                kernel.block_kernels, x_blocks, strict=True
            )
        ]
        return PSD(BlockDiag(*diagonal_blocks))

    def cross_covariance(
        self,
        kernel: "BlockDiagonalKernel",
        x: Num[Array, "N D"],
        y: Num[Array, "M D"],
    ) -> Float[Array, "N M"]:
        """Return the ``N x M`` cross-covariance between ``x`` and ``y``.

        Modelling assumption: ``y`` is correlated only with the rows of ``x``
        in the first block. Those rows use ``kernel.block_kernels[0]``. Rows
        of ``x`` belonging to any later block get zero cross-covariance.

        Args:
            kernel: Block diagonal kernel supplying ``block_kernels`` and
                ``block_sizes``.
            x: Inputs partitioned exactly as in ``gram``.
            y: Inputs assumed correlated only with the first block of ``x``.

        Returns:
            An array with the first block kernel's cross-covariance in the
            first-block rows and zeros in all remaining rows.
        """
        # Use the same row partitioning as ``gram``. With no blocks at all
        # (empty x and empty block_sizes), fall back to x itself.
        x_first = next(partition(x, kernel.block_sizes), x)
        cross_first = kernel.block_kernels[0].cross_covariance(x_first, y)
        # Rows outside the first block are uncorrelated with y by assumption.
        num_other_rows = x.shape[0] - x_first.shape[0]
        zeros = jnp.zeros(
            (num_other_rows, cross_first.shape[1]), dtype=cross_first.dtype
        )
        return jnp.concatenate([cross_first, zeros], axis=0)
@dataclass
class BlockDiagonalKernel(AbstractKernel):
    """Kernel whose Gram matrix is block diagonal over consecutive input rows.

    Input rows are split into consecutive blocks of ``block_sizes`` rows. Any
    leftover rows form one final block. Block ``i`` is covered by
    ``block_kernels[i]``, and covariances between blocks are zero. The
    default ``compute_engine`` is ``BlockDiagonalKernelComputation``, which
    builds these matrices.
    """

    # One kernel per diagonal block, in input-row order.
    block_kernels: list[AbstractKernel] | None = None
    # Rows per leading block; the last block's size may be omitted.
    # Declared with static_field (presumably excluded from pytree leaves).
    block_sizes: list[int] | None = static_field(None)
    compute_engine: AbstractKernelComputation = static_field(
        BlockDiagonalKernelComputation(),
    )

    def __call__(
        self,
        x: Num[Array, " D"],
        y: Num[Array, " D"],
    ) -> ScalarFloat:
        """Not supported: pointwise evaluation is undefined for this kernel.

        Which block kernel applies depends on a point's row position within
        the full input array. A single pair of points carries no such
        information.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "BlockDiagonalKernel cannot be evaluated on a single pair of "
            "points: block membership depends on row position in the full "
            "input array. Use the Gram / cross-covariance computations instead."
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment