Last active
September 22, 2023 17:11
-
-
Save kunalghosh/bcc5b446ebc516f060090184a4c77db9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
from pathlib import Path

import gpytorch
import joblib
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled (RBF + linear) kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # Sum of a smooth (RBF) and a linear component, with one shared output scale.
        summed_kernel = gpytorch.kernels.RBFKernel() + gpytorch.kernels.LinearKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(summed_kernel)

    def forward(self, x):
        """Return the MultivariateNormal given by the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
class GenericGP:
    """Train/predict wrapper around a GPyTorch exact GP.

    The likelihood and model are given as factories (classes or callables), so
    that :meth:`initialise` can build a fresh likelihood/model pair each time a
    new training set is supplied.

    Args:
        likelihood: zero-argument callable returning a likelihood,
            e.g. ``gpytorch.likelihoods.GaussianLikelihood``.
        model: callable ``model(x_train, y_train, likelihood)`` returning a GP
            model, e.g. ``ExactGPModel``.
        n_iters: number of Adam steps performed by every :meth:`train` call.
    """

    def __init__(self, likelihood, model, n_iters):
        self.likelihood_fn = likelihood
        self.Model = model
        self.model = None
        self.n_iters = n_iters
        self.initialised = False
        # Detached loss of the last optimisation step; None until train() ran.
        self.loss = None

    def check_initialisation(self):
        """Raise RuntimeError unless :meth:`initialise` completed successfully."""
        if not self.initialised:
            raise RuntimeError('Please initialise GP with X_train, y_train first')

    def initialise(self, x_train, y_train):
        """(Re)build the likelihood and model for a new training set.

        Any previously built model and recorded loss are discarded.
        """
        # Only flip to True once construction has succeeded, so a failing
        # likelihood/model factory does not leave a half-built "initialised" GP.
        self.initialised = False
        self.x_train = x_train
        self.y_train = y_train
        self.likelihood = self.likelihood_fn()
        self.model = self.Model(self.x_train, self.y_train, self.likelihood)
        self.loss = None
        self.initialised = True

    def train(self):
        """Fit hyperparameters by minimising the negative exact marginal log likelihood.

        Runs ``n_iters`` Adam steps (lr=0.1) every time it is called, continuing
        from the current hyperparameters, and prints the loss at each step.
        """
        self.check_initialisation()
        self.model.train()
        self.likelihood.train()
        # model.parameters() includes the likelihood's parameters (e.g. the
        # Gaussian noise), since the likelihood is part of the ExactGP model.
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.1)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)
        for i in range(self.n_iters):
            optimizer.zero_grad()
            output = self.model(self.x_train)
            loss = -mll(output, self.y_train)
            loss.backward()
            print('Iter %d/%d - Loss: %.3f' % (i + 1, self.n_iters, loss.item()))
            optimizer.step()
            # Keep a detached copy so the autograd graph is not held alive.
            self.loss = loss.detach()

    def predict(self, data):
        """Return the predictive distribution for ``data`` (model passed through the likelihood).

        Runs in eval mode without gradient tracking, using GPyTorch's
        ``fast_pred_var`` setting for the predictive variances.
        """
        self.check_initialisation()
        self.model.eval()
        self.likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            return self.likelihood(self.model(data))

    def get_loss(self):
        """Return the loss of the last training step as a Python float.

        Raises:
            RuntimeError: if the GP is not initialised or has not been trained.
        """
        self.check_initialisation()
        if self.loss is None:
            raise RuntimeError('Please train the GP before requesting its loss')
        return self.loss.item()

    def get_model(self):
        """Return the underlying GP model (requires :meth:`initialise`)."""
        self.check_initialisation()
        return self.model

    def get_likelihood(self):
        """Return the likelihood object (requires :meth:`initialise`)."""
        self.check_initialisation()
        return self.likelihood
class ALCheckpoint:
    """Container for active-learning state, persisted to disk with joblib.

    Every field defaults to None, so an empty instance can be created and then
    filled from disk with :meth:`load_checkpoint`.
    """

    def __init__(self, acq_func=None, batch_size_list=None, test_set_size=None,
                 gp_models=None, x_train=None, y_train=None, indices_dict=None,
                 iteration=None, random_seed=None, checkpoint_file='ALCheckpoint.joblib'):
        self.acq_func = acq_func
        self.batch_size_list = batch_size_list  # batch_size_list[0] is the initial batch
        self.gp_models = gp_models
        self.x_train = x_train
        self.y_train = y_train
        self.indices = indices_dict
        self.iteration = iteration
        self.random_seed = random_seed
        self.test_set_size = test_set_size
        # Path used by save_checkpoint and the default for load_checkpoint.
        self.checkpoint_file = checkpoint_file

    def save_checkpoint(self):
        """Write the current state to ``self.checkpoint_file``, overwriting it."""
        joblib.dump({'acq_func': self.acq_func,
                     'batch_size_list': self.batch_size_list,
                     'gp_models': self.gp_models,
                     'x_train': self.x_train,
                     'y_train': self.y_train,
                     'indices_dict': self.indices,
                     'iteration': self.iteration,
                     'random_seed': self.random_seed,
                     'test_set_size': self.test_set_size}, self.checkpoint_file)

    def load_checkpoint(self, checkpoint_file=None):
        """Restore state from ``checkpoint_file`` (default: ``self.checkpoint_file``).

        ``self.checkpoint_file`` is kept unchanged, so later saves still go to
        the configured path.

        NOTE(security): joblib.load unpickles the file and can execute
        arbitrary code; only load checkpoints you created yourself.

        Returns:
            True if the file existed and was loaded, False otherwise.
        """
        if checkpoint_file is None:
            checkpoint_file = self.checkpoint_file
        if not Path(checkpoint_file).is_file():
            print(f'Could not load {checkpoint_file}')
            return False
        state = dict(joblib.load(checkpoint_file))
        # Re-running __init__ would otherwise reset checkpoint_file to its default.
        state['checkpoint_file'] = self.checkpoint_file
        self.__init__(**state)
        print(f'Loaded {checkpoint_file}')
        return True
class GenericAL:
    """Pool-based active-learning loop over per-iteration GenericGP models.

    Iteration 0 (:meth:`initialise`) splits the data into a fixed test set, a
    random initial training batch of ``batch_sizes[0]`` points and a heldout
    pool. Each later iteration (:meth:`execute_al_loop`) predicts on the heldout
    pool with the previous iteration's GP. ``acq_func`` then picks
    ``batch_sizes[i]`` of those points, which move into the training set, and a
    new GP is trained. State is checkpointed to ``checkpoint_file`` after every
    completed iteration and restored from it on construction if the file exists.

    Args:
        acq_func: callable ``acq_func(num_toselect, heldout_indices, predictions_heldout)``
            returning a subset of ``heldout_indices``.
        batch_sizes: number of points added per iteration; ``batch_sizes[0]`` is
            the initial random batch.
        test_set_size: number of points reserved for the fixed test set.
        gp: template GenericGP; every iteration gets its own deep copy.
        x_train, y_train: full dataset, indexed with integer index arrays.
        random_seed: ``random_state`` used for the test / initial-batch splits.
        checkpoint_file: path used both to resume from and to save to.

    Raises:
        ValueError: if X and Y lengths differ or the dataset is too small.
    """

    def __init__(self, acq_func, batch_sizes, test_set_size, gp, x_train, y_train,
                 random_seed=42, checkpoint_file='ALCheckpoint.joblib'):
        if len(x_train) != len(y_train):
            raise ValueError('Number of datatapoints in X and Y must be same')
        if len(y_train) < np.sum(batch_sizes) + test_set_size:
            raise ValueError('Total number of molecules in dataset must be greater or equal to number of molecules picked in all batches + test set size.')
        self.acq_func = acq_func
        self.batch_size_list = batch_sizes  # batch_size_list[0] is the initial batch
        # One independent GP per iteration. With a single shared instance,
        # re-initialising it for iteration i silently replaced the model stored
        # (and checkpointed) for every earlier iteration.
        self.gp_models = [copy.deepcopy(gp) for _ in batch_sizes]
        self.x_train = x_train
        self.y_train = y_train
        self.indices = [{'train': None, 'test': None, 'heldout': None} for _ in batch_sizes]
        self.iteration = None
        self.random_seed = random_seed
        self.test_set_size = test_set_size
        self.checkpoint_file = checkpoint_file
        if Path(checkpoint_file).is_file():
            self._restore_from_checkpoint()
        else:
            self.checkpoint = ALCheckpoint(
                acq_func=self.acq_func, batch_size_list=self.batch_size_list,
                test_set_size=self.test_set_size, gp_models=self.gp_models,
                x_train=self.x_train, y_train=self.y_train,
                indices_dict=self.indices, iteration=self.iteration,
                random_seed=self.random_seed, checkpoint_file=checkpoint_file)

    def _restore_from_checkpoint(self):
        """Replace the AL state with the one saved in ``self.checkpoint_file``."""
        self.checkpoint = ALCheckpoint(checkpoint_file=self.checkpoint_file)
        self.checkpoint.load_checkpoint(self.checkpoint_file)
        # Ensure subsequent saves go back to the file we resumed from.
        self.checkpoint.checkpoint_file = self.checkpoint_file
        self.acq_func = self.checkpoint.acq_func
        self.batch_size_list = self.checkpoint.batch_size_list
        self.gp_models = self.checkpoint.gp_models
        self.x_train = self.checkpoint.x_train
        self.y_train = self.checkpoint.y_train
        self.indices = self.checkpoint.indices
        self.random_seed = self.checkpoint.random_seed
        self.test_set_size = self.checkpoint.test_set_size
        # A checkpoint is only written in _finish_iteration, i.e. after an
        # iteration fully completed, so the saved value is the last *finished*
        # iteration and execute_al_loop() continues with the next one.
        # Re-running the saved iteration would be wrong: its 'train' entry
        # already contains the points it selected, so it would pick a 2nd batch.
        self.iteration = self.checkpoint.iteration

    def _split_initial_batch(self, pool_indices):
        """Split the non-test pool into the initial training batch and the heldout rest."""
        initial_batch_size = self.batch_size_list[0]
        if initial_batch_size >= len(pool_indices):
            # train_test_split rejects train_size == n_samples; take the whole pool.
            return pool_indices, pool_indices[:0]
        return train_test_split(pool_indices, train_size=initial_batch_size,
                                random_state=self.random_seed)

    def _train_current_gp(self, train_indices):
        """Initialise and train the current iteration's GP on ``train_indices``."""
        gp = self.gp_models[self.iteration]
        gp.initialise(x_train=self.x_train[train_indices], y_train=self.y_train[train_indices])
        gp.train()

    def _finish_iteration(self):
        """Seed the next iteration's indices from the current ones, then checkpoint."""
        next_iteration = self.iteration + 1
        if next_iteration < len(self.indices):
            for key in ('train', 'heldout', 'test'):
                self.indices[next_iteration][key] = self.indices[self.iteration][key]
        self.checkpoint.iteration = self.iteration
        self.checkpoint.gp_models = self.gp_models
        self.checkpoint.indices = self.indices
        self.checkpoint.save_checkpoint()

    def initialise(self):
        """Run iteration 0: build the splits, train the first GP and checkpoint."""
        self.iteration = 0
        n_points = len(self.y_train)
        # Fixed test set first; the remainder is the pool for train/heldout.
        pool_indices, test_indices = train_test_split(
            np.arange(n_points), test_size=self.test_set_size, random_state=self.random_seed)
        train_indices, heldout_indices = self._split_initial_batch(pool_indices)
        self.indices[0]['train'] = train_indices
        self.indices[0]['test'] = test_indices
        self.indices[0]['heldout'] = heldout_indices

        # Sanity checks: the three sets partition the whole dataset.
        train_set, test_set, heldout_set = set(train_indices), set(test_indices), set(heldout_indices)
        assert set(range(n_points)) == train_set | test_set | heldout_set
        assert not (test_set & train_set), 'test and train sets must be mutually exclusive'
        assert not (test_set & heldout_set), 'test and heldout sets must be mutually exclusive'
        assert not (train_set & heldout_set), 'train and heldout sets must be mutually exclusive'

        self._train_current_gp(train_indices)
        self.make_test_predictions()
        self._finish_iteration()

    def execute_al_loop(self):
        """Run the next AL iteration, or report that all iterations are done.

        Raises:
            RuntimeError: if neither initialise() ran nor a checkpoint was loaded.
            ValueError: if ``acq_func`` returns indices outside the heldout pool.
        """
        if self.iteration is None:
            raise RuntimeError('Please call initialise() before execute_al_loop()')
        n_iterations = len(self.batch_size_list)
        # Clamp so repeated calls after the last iteration keep reporting FINISHED.
        self.iteration = min(self.iteration + 1, n_iterations)
        if self.iteration == n_iterations:
            print("FINISHED: all AL Iterations over !")
            return

        current = self.indices[self.iteration]
        train_indices, heldout_indices = current['train'], current['heldout']
        # Score the heldout pool with the model trained in the previous iteration.
        predictions_heldout = self.gp_models[self.iteration - 1].predict(self.x_train[heldout_indices])
        num_toselect_curr_iter = self.batch_size_list[self.iteration]
        new_trainset_indices = np.asarray(
            self.acq_func(num_toselect_curr_iter, heldout_indices, predictions_heldout), dtype=int)
        if not np.isin(new_trainset_indices, heldout_indices).all():
            raise ValueError('acq_func must return a subset of the heldout indices')

        # Move the selected points from the heldout pool into the training set.
        # setdiff1d/union1d keep an integer dtype (also when empty) and return
        # sorted indices, so the training order is deterministic.
        heldout_indices = np.setdiff1d(heldout_indices, new_trainset_indices)
        train_indices = np.union1d(train_indices, new_trainset_indices)
        current['train'], current['heldout'] = train_indices, heldout_indices

        self._train_current_gp(train_indices)
        self.make_test_predictions()
        self._finish_iteration()

    def make_test_predictions(self):
        """Print and return the test-set MSE of the current iteration's GP."""
        test_indices = self.indices[self.iteration]['test']
        test_predictions = self.gp_models[self.iteration].predict(self.x_train[test_indices])
        mse = mean_squared_error(self.y_train[test_indices], test_predictions.mean)
        # NOTE(review): MSE is in squared units of y; the 'eV' label may be meant as eV^2.
        print(self.iteration, 'MSE', mse, 'eV')
        return mse
def acq_func_f(num_toselect, heldout_indices, predictions_heldout, rng=None):
    """Randomly pick ``num_toselect`` heldout points among the most uncertain ones.

    The candidates are the most uncertain half of the heldout pool (by
    predictive standard deviation, ceil(n/2) points). If that half is smaller
    than ``num_toselect``, the pool is widened to exactly ``num_toselect``
    points. ``num_toselect`` candidates are then drawn without replacement.

    Args:
        num_toselect: number of points to select.
        heldout_indices: dataset indices of the heldout pool, aligned with
            ``predictions_heldout``.
        predictions_heldout: object with a 1-D ``stddev`` (torch tensor or
            array-like), e.g. the distribution returned by GenericGP.predict.
        rng: optional ``numpy.random.Generator`` for reproducible draws; when
            None the global ``numpy.random`` state is used.

    Returns:
        numpy array of ``num_toselect`` distinct entries of ``heldout_indices``.

    Raises:
        ValueError: if ``num_toselect`` exceeds the size of the heldout pool.
    """
    heldout_indices = np.asarray(heldout_indices)
    stddev = predictions_heldout.stddev
    if hasattr(stddev, 'detach'):
        # torch tensor: drop autograd history and move to CPU before converting.
        stddev = stddev.detach().cpu()
    order = np.argsort(np.asarray(stddev))  # ascending uncertainty
    n_heldout = len(heldout_indices)
    # Previously only the top ceil(n/2) were eligible, which made choice()
    # without replacement fail whenever num_toselect exceeded that half.
    pool_size = min(n_heldout, max((n_heldout + 1) // 2, num_toselect))
    candidates = order[n_heldout - pool_size:]
    sampler = np.random if rng is None else rng
    selected = sampler.choice(candidates, size=num_toselect, replace=False)
    return heldout_indices[selected]
# gp = GenericGP(likelihood = gpytorch.likelihoods.GaussianLikelihood, | |
# model = ExactGPModel, | |
# n_iters=1000) | |
# al = GenericAL(acq_func_f, [10,10,10,10], 10, gp, X_train_, y_train_) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment