@eytan
Created March 25, 2020 17:33
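# due to balandat
from botorch.models import SingleTaskGP
from gpytorch.means import LinearMean
from torch import Tensor


class LinearSingleTaskGP(SingleTaskGP):
    """SingleTaskGP that uses a linear mean instead of the default mean module."""

    def __init__(self, train_X: Tensor, train_Y: Tensor, **kwargs):
        super().__init__(train_X, train_Y, **kwargs)
        # Swap in a linear mean sized to the input dimension of train_X.
        self.mean_module = LinearMean(batch_shape=train_X.shape[:-2], input_size=train_X.size(-1))
        # The new module is created with default dtype/device; match it to the training data.
        self.to(train_X)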