Skip to content

Instantly share code, notes, and snippets.

@kunalghosh
Last active September 22, 2023 17:11
Show Gist options
  • Save kunalghosh/bcc5b446ebc516f060090184a4c77db9 to your computer and use it in GitHub Desktop.
Save kunalghosh/bcc5b446ebc516f060090184a4c77db9 to your computer and use it in GitHub Desktop.
import copy
from pathlib import Path

import gpytorch
import joblib
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled (RBF + linear) kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # Sum of an RBF and a linear kernel, wrapped in a single output scale.
        rbf_plus_linear = (
            gpytorch.kernels.RBFKernel() + gpytorch.kernels.LinearKernel()
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_plus_linear)

    def forward(self, x):
        """Return the multivariate normal built from the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
class GenericGP:
    """Thin wrapper that builds, trains and queries a gpytorch exact GP.

    The likelihood and model are passed as factories (classes or callables)
    and are only instantiated in :meth:`initialise`, so that the same wrapper
    can be re-initialised with new training data.

    Args:
        likelihood: Zero-argument callable returning a likelihood, e.g.
            ``gpytorch.likelihoods.GaussianLikelihood``.
        model: Callable ``model(x_train, y_train, likelihood)`` returning the
            GP model, e.g. ``ExactGPModel``.
        n_iters: Number of optimiser steps performed by each :meth:`train` call.
        lr: Learning rate handed to the Adam optimiser (default ``0.1``).
    """

    def __init__(self, likelihood, model, n_iters, lr=0.1):
        self.likelihood_fn = likelihood
        self.Model = model
        self.model = None
        self.likelihood = None
        self.loss = None
        self.n_iters = n_iters
        self.lr = lr
        self.initialised = False

    def check_initialisation(self):
        """Raise ``RuntimeError`` unless :meth:`initialise` has succeeded."""
        if not self.initialised:
            raise RuntimeError('Please initialise GP with X_train, y_train first')

    def initialise(self, x_train, y_train):
        """(Re-)create the likelihood and model for the given training data.

        The new objects are built first and only stored once both
        constructors have succeeded, so a failing constructor never leaves
        the wrapper flagged as initialised with a half-built state.
        """
        likelihood = self.likelihood_fn()
        model = self.Model(x_train, y_train, likelihood)
        self.x_train = x_train
        self.y_train = y_train
        self.likelihood = likelihood
        self.model = model
        self.loss = None  # any loss from a previous model is stale now
        self.initialised = True

    def train(self):
        """Run ``n_iters`` Adam steps on the negative marginal log likelihood.

        Every call performs another ``n_iters`` steps on the current model.

        Raises:
            RuntimeError: If :meth:`initialise` has not been called.
        """
        self.check_initialisation()
        self.model.train()
        self.likelihood.train()
        # getattr keeps objects created before the ``lr`` attribute existed
        # (e.g. unpickled from an old checkpoint) working with the old rate.
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=getattr(self, 'lr', 0.1)
        )
        # "Loss" for GPs - the (negative) marginal log likelihood
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)
        for step in range(1, self.n_iters + 1):
            optimizer.zero_grad()
            output = self.model(self.x_train)
            loss = -mll(output, self.y_train)
            loss.backward()
            # Keep only a detached copy: the stored value then does not hold
            # on to the autograd graph and can be deep-copied.
            self.loss = loss.detach()
            print('Iter %d/%d - Loss: %.3f' % (
                step, self.n_iters, self.loss.item(),
            ))
            optimizer.step()

    def predict(self, data):
        """Return the likelihood-wrapped predictive distribution at ``data``.

        Raises:
            RuntimeError: If :meth:`initialise` has not been called.
        """
        self.check_initialisation()
        # Switch to evaluation (posterior predictive) mode
        self.model.eval()
        self.likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            observed_pred = self.likelihood(self.model(data))
        return observed_pred

    def get_loss(self):
        """Return the loss of the most recent training step as a float.

        Raises:
            RuntimeError: If the GP is not initialised, or no training step
                has been run since the last :meth:`initialise`.
        """
        self.check_initialisation()
        loss = getattr(self, 'loss', None)
        if loss is None:
            raise RuntimeError('No loss available: call train() first')
        return loss.item()

    def get_model(self):
        """Return the current model (requires prior initialisation)."""
        self.check_initialisation()
        return self.model

    def get_likelihood(self):
        """Return the current likelihood (requires prior initialisation)."""
        self.check_initialisation()
        return self.likelihood
class ALCheckpoint:
    """Container for the active-learning state that is persisted with joblib.

    All state fields default to ``None`` so that an empty instance can be
    created and then filled by :meth:`load_checkpoint`.

    Args:
        acq_func: Acquisition function used by the AL loop.
        batch_size_list: Batch sizes; element 0 is the initial batch.
        test_set_size: Size of the test split.
        gp_models: GP wrappers, one entry per batch.
        x_train, y_train: The full dataset.
        indices_dict: Per-iteration index bookkeeping (stored as ``indices``).
        iteration: Iteration the state belongs to.
        random_seed: Seed used for the data splits.
        checkpoint_file: Path the checkpoint is written to / read from.
    """

    def __init__(self, acq_func=None, batch_size_list=None, test_set_size=None,
                 gp_models=None, x_train=None, y_train=None, indices_dict=None,
                 iteration=None, random_seed=None,
                 checkpoint_file='ALCheckpoint.joblib'):
        self.acq_func = acq_func
        self.batch_size_list = batch_size_list  # batch_sizes[0] is initial batch
        self.gp_models = gp_models
        self.x_train = x_train
        self.y_train = y_train
        self.indices = indices_dict
        self.iteration = iteration
        self.random_seed = random_seed
        self.test_set_size = test_set_size
        self.checkpoint_file = checkpoint_file

    def save_checkpoint(self):
        """Write the current state to ``self.checkpoint_file`` via joblib."""
        state = {
            'acq_func': self.acq_func,
            'batch_size_list': self.batch_size_list,
            'gp_models': self.gp_models,
            'x_train': self.x_train,
            'y_train': self.y_train,
            'indices_dict': self.indices,
            'iteration': self.iteration,
            'random_seed': self.random_seed,
            'test_set_size': self.test_set_size,
        }
        joblib.dump(state, self.checkpoint_file)

    def load_checkpoint(self, checkpoint_file=None):
        """Replace this object's state with the content of a checkpoint file.

        Args:
            checkpoint_file: File to read; defaults to ``self.checkpoint_file``.

        If the file does not exist a message is printed and the object is
        left unchanged. ``self.checkpoint_file`` is preserved across a load,
        so later :meth:`save_checkpoint` calls keep writing to the path this
        object was configured with.
        """
        if checkpoint_file is None:
            checkpoint_file = self.checkpoint_file
        if not Path(checkpoint_file).is_file():
            print(f'Could not load {checkpoint_file}')
            return
        # SECURITY NOTE: joblib.load unpickles arbitrary objects; only load
        # checkpoint files that come from a trusted source.
        checkpoint = joblib.load(checkpoint_file)
        print(f'Loaded {checkpoint}')
        state = dict(checkpoint)
        # The saved dict has no 'checkpoint_file' entry; without this line
        # __init__ would silently reset the path to its default value.
        state['checkpoint_file'] = self.checkpoint_file
        self.__init__(**state)
class GenericAL:
    """Batch-mode active-learning loop around a GP wrapper.

    Round 0 (:meth:`initialise`) splits the data into test / train / held-out
    index sets and trains on the initial batch. Every later round
    (:meth:`execute_al_loop`) asks the acquisition function to move
    ``batch_sizes[i]`` points from the held-out set into the training set and
    trains a fresh model. The state is checkpointed after every completed
    round.

    Args:
        acq_func: Callable ``acq_func(num_toselect, heldout_indices,
            predictions_heldout)`` returning the chosen subset of
            ``heldout_indices``.
        batch_sizes: Points added per round; ``batch_sizes[0]`` is the size
            of the initial training set.
        test_set_size: Size of the fixed test split.
        gp: GP wrapper offering ``initialise``, ``train`` and ``predict``.
            It is used as a template: every round gets its own deep copy, so
            the object passed in is not trained in place.
        x_train, y_train: The full dataset, indexable with integer arrays.
        random_seed: Seed for the train/test splits.
        checkpoint_file: Path of the joblib checkpoint. If the file exists,
            the stored state replaces the constructor arguments and the loop
            continues after the last completed round; do not call
            :meth:`initialise` in that case unless a restart is wanted.
    """

    def __init__(self, acq_func, batch_sizes, test_set_size, gp, x_train, y_train,
                 random_seed=42, checkpoint_file='ALCheckpoint.joblib'):
        assert len(x_train) == len(y_train), 'Number of datatapoints in X and Y must be same'
        assert len(y_train) >= np.sum(batch_sizes) + test_set_size, 'Total number of molecules in dataset must be greater or equal to number of molecules picked in all batches + test set size.'
        self.acq_func = acq_func
        self.batch_size_list = batch_sizes  # batch_sizes[0] is initial batch
        # One independent model per round. Repeating the same object would
        # make every slot alias the most recently trained model.
        self.gp_models = [copy.deepcopy(gp) for _ in batch_sizes]
        self.x_train = x_train
        self.y_train = y_train
        self.indices = [{'train': None, 'test': None, 'heldout': None} for _ in batch_sizes]
        self.iteration = None
        self.random_seed = random_seed
        self.test_set_size = test_set_size
        self.checkpoint_file = checkpoint_file

        if Path(checkpoint_file).is_file():
            # Resume: load from (and keep saving to) the requested file.
            self.checkpoint = ALCheckpoint(checkpoint_file=checkpoint_file)
            self.checkpoint.load_checkpoint(checkpoint_file)
            self.checkpoint.checkpoint_file = checkpoint_file
            self.acq_func = self.checkpoint.acq_func
            self.batch_size_list = self.checkpoint.batch_size_list
            self.gp_models = self.checkpoint.gp_models
            self.x_train = self.checkpoint.x_train
            self.y_train = self.checkpoint.y_train
            self.indices = self.checkpoint.indices
            self.random_seed = self.checkpoint.random_seed
            self.test_set_size = self.checkpoint.test_set_size
            print(type(self.checkpoint.iteration), self.checkpoint.iteration)
            # A checkpoint is only written once a round has fully completed
            # (indices updated, model trained, indices copied forward), so we
            # continue from the stored round. Re-running it would select a
            # second batch on top of the already updated indices.
            self.iteration = self.checkpoint.iteration
        else:
            self.checkpoint = ALCheckpoint(
                acq_func=self.acq_func,
                batch_size_list=self.batch_size_list,
                test_set_size=self.test_set_size,
                gp_models=self.gp_models,
                x_train=self.x_train,
                y_train=self.y_train,
                indices_dict=self.indices,
                iteration=self.iteration,
                random_seed=self.random_seed,
                checkpoint_file=checkpoint_file,
            )

    def initialise(self):
        """Run round 0: split the data, train the first model, checkpoint."""
        self.iteration = 0
        # Select test set
        heldout_indices, test_indices = train_test_split(
            np.arange(len(self.y_train)),
            test_size=self.test_set_size,
            random_state=self.random_seed,
        )
        # Select initial training set from what is left
        train_indices, heldout_indices = train_test_split(
            heldout_indices,
            train_size=self.batch_size_list[self.iteration],
            random_state=self.random_seed,
        )
        current = self.indices[self.iteration]
        current['train'] = train_indices
        current['test'] = test_indices
        current['heldout'] = heldout_indices
        # The three sets must partition the dataset
        assert set(range(len(self.y_train))) == set(test_indices).union(set(train_indices)).union(set(heldout_indices))
        assert len(set(test_indices) & set(train_indices)) == 0, 'test and train sets must be mutually exclusive'
        assert len(set(test_indices) & set(heldout_indices)) == 0, 'test and heldout sets must be mutually exclusive'
        assert len(set(train_indices) & set(heldout_indices)) == 0, 'train and heldout sets must be mutually exclusive'
        # Initial GP training and evaluation
        self._fit_current_model(train_indices)
        self.make_test_predictions()
        self._finish_iteration()

    def execute_al_loop(self):
        """Run the next active-learning round, if there is one left.

        Once all rounds are done the method prints a message and returns;
        ``self.iteration`` then equals ``len(self.batch_size_list)`` and
        further calls are harmless no-ops.

        Raises:
            RuntimeError: If neither :meth:`initialise` nor a checkpoint has
                set the current iteration yet.
        """
        if self.iteration is None:
            raise RuntimeError('Please call initialise() before execute_al_loop()')
        n_rounds = len(self.batch_size_list)
        if self.iteration + 1 >= n_rounds:
            self.iteration = n_rounds
            print("FINISHED: all AL Iterations over !")
            return
        self.iteration += 1
        current = self.indices[self.iteration]
        train_indices, heldout_indices = current['train'], current['heldout']
        # Held-out predictions come from the model of the previous round
        predictions_heldout = self.gp_models[self.iteration - 1].predict(self.x_train[heldout_indices])
        # Number of datapoints to pick in the current round
        num_toselect_curr_iter = self.batch_size_list[self.iteration]
        new_trainset_indices = self.acq_func(num_toselect_curr_iter, heldout_indices, predictions_heldout)
        print(new_trainset_indices, heldout_indices)
        assert set(new_trainset_indices).issubset(set(heldout_indices))
        # Move the selected points from held-out to train. The NumPy set
        # routines keep an integer dtype even when the result is empty and
        # give a deterministic (sorted) order.
        heldout_indices = np.setdiff1d(heldout_indices, new_trainset_indices)
        train_indices = np.union1d(train_indices, new_trainset_indices)
        current['train'], current['heldout'] = train_indices, heldout_indices
        # Train this round's model and evaluate it
        self._fit_current_model(train_indices)
        self.make_test_predictions()
        self._finish_iteration(announce_copy=True)

    def make_test_predictions(self):
        """Print the test-set MSE of the current round's model."""
        test_indices = self.indices[self.iteration]['test']
        test_predictions = self.gp_models[self.iteration].predict(self.x_train[test_indices])
        print(self.iteration, 'MSE', mean_squared_error(test_predictions.mean, self.y_train[test_indices]), 'eV')

    def _fit_current_model(self, train_indices):
        """Initialise and train the current round's GP on ``train_indices``."""
        model = self.gp_models[self.iteration]
        model.initialise(x_train=self.x_train[train_indices],
                         y_train=self.y_train[train_indices])
        model.train()

    def _finish_iteration(self, announce_copy=False):
        """Copy the index sets to the next round (if any) and checkpoint."""
        following = self.iteration + 1
        # Guarded so that a single-batch run does not index past the list.
        if following < len(self.batch_size_list):
            if announce_copy:
                print('copying data to next iteration')
            for split in ('train', 'heldout', 'test'):
                self.indices[following][split] = self.indices[self.iteration][split]
        self.checkpoint.iteration = self.iteration
        self.checkpoint.gp_models = self.gp_models
        self.checkpoint.indices = self.indices
        self.checkpoint.save_checkpoint()
def acq_func_f(num_toselect, heldout_indices, predictions_heldout):
    """Randomly pick held-out points among the most uncertain predictions.

    Candidates are the more uncertain half of the held-out points (rounded
    up), ranked by predictive standard deviation. If more points are
    requested than that half contains, the candidate pool is widened to the
    ``num_toselect`` most uncertain points, so that any request up to the
    size of the held-out set can be served.

    Args:
        num_toselect: Number of points to select.
        heldout_indices: Dataset indices of the held-out points (array-like).
        predictions_heldout: Object whose ``stddev`` attribute holds one
            standard deviation per held-out point, in the same order (a
            torch tensor or anything ``np.asarray`` accepts).

    Returns:
        ``numpy.ndarray`` with the ``num_toselect`` chosen entries of
        ``heldout_indices``.

    Raises:
        ValueError: If more points are requested than are held out, or if
            ``stddev`` does not have one entry per held-out point.
    """
    heldout_indices = np.asarray(heldout_indices)
    n_heldout = len(heldout_indices)
    if num_toselect > n_heldout:
        raise ValueError(
            f'Cannot select {num_toselect} points from {n_heldout} held-out points'
        )

    stddev = predictions_heldout.stddev
    if hasattr(stddev, 'detach'):
        # torch tensor: drop any autograd history and move to host memory
        stddev = stddev.detach().cpu().numpy()
    stddev = np.asarray(stddev)
    if stddev.ndim != 1 or stddev.shape[0] != n_heldout:
        raise ValueError(
            'predictions_heldout.stddev must hold one value per held-out point'
        )

    # Ascending sort: the most uncertain points sit at the end
    sort_idxs = np.argsort(stddev)
    half = -(-n_heldout // 2)  # ceil(n / 2)
    pool_size = max(half, num_toselect)
    # Slice from the front offset; a negative-offset slice would return the
    # whole array when pool_size is 0.
    most_uncertain = sort_idxs[n_heldout - pool_size:]
    selected_indices = np.random.choice(most_uncertain, size=num_toselect, replace=False)
    print(type(heldout_indices), type(selected_indices), selected_indices)
    return heldout_indices[selected_indices]
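# Example usage (X_train_ and y_train_ are the full dataset):
# gp = GenericGP(likelihood=gpytorch.likelihoods.GaussianLikelihood,
#                model=ExactGPModel,
#                n_iters=1000)
# al = GenericAL(acq_func_f, [10, 10, 10, 10], 10, gp, X_train_, y_train_)
# al.initialise()       # round 0: split the data and train on the initial batch
# al.execute_al_loop()  # call once for each remaining batch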
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment