TensorFlow PyTorch comparison using the neuralprocesses library
"""Script to compare the tensorflow and pytorch implementations of neuralprocesses. | |
The training losses printed at the end should be very similar, the differences are | |
thought to be due to different random seeds. | |
This script currently works for gnp, agnp, and convgnp architectures (defined | |
on L93 & 94). | |
""" | |
import tensorflow as tf
import torch
import numpy as np

import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf
def x_to_y(x):
    """Dummy function to make learnable y data from random x data."""
    shape = x.shape
    y = torch.randn(shape[0], 2, shape[2])
    y[:, 0, :] *= 2
    y[:, 1, :] *= 3
    y = y + torch.randn_like(y) * 0.1
    return y
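# Illustrative sanity check (not in the original script): x_to_y should map a
# (batch, 1, n) input to a (batch, 2, n) output, matching dim_y=2 below.
assert x_to_y(torch.randn(4, 1, 7)).shape == (4, 2, 7)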
def copy_weights_and_biases(model_torch, model_tf):
    """Copy weights from the torch model to the tf model.

    Assumes both models yield their parameters in the same order, with dense
    kernels possibly transposed between the two frameworks.
    """
    weights_tf = model_tf.get_weights()
    weights_torch = [param.detach().numpy() for param in model_torch.parameters()]
    for i in range(len(weights_torch)):
        if weights_tf[i].shape == weights_torch[i].shape:
            weights_tf[i] = weights_torch[i]
        elif weights_tf[i].shape == weights_torch[i].T.shape:
            weights_tf[i] = weights_torch[i].T
        else:
            raise ValueError(
                f"Parameter {i} has incompatible shapes: "
                f"{weights_tf[i].shape} (tf) vs {weights_torch[i].shape} (torch)"
            )
    model_tf.set_weights(weights_tf)
    print("Weights and biases copied from the PyTorch model to the TensorFlow model")
def compare_models(model_torch, model_tf):
    """Check that the weights and biases in a tf and torch model are all the same."""
    # Convert PyTorch model to state_dict (dictionary object)
    pytorch_state_dict = model_torch.state_dict()
    # Get TensorFlow model variables
    tensorflow_variables = model_tf.trainable_variables
    # Check that the number of parameter tensors is the same
    if len(pytorch_state_dict) != len(tensorflow_variables):
        print("The models have a different number of parameter tensors.")
        return False
    # Iterate over the paired parameters
    for (name, param), tf_var in zip(pytorch_state_dict.items(), tensorflow_variables):
        # Convert PyTorch tensor to numpy array
        pytorch_param = param.detach().numpy()
        # Get corresponding TensorFlow variable
        tensorflow_param = tf_var.numpy()
        # Check if the shapes are the same, allowing for transposed kernels
        if pytorch_param.shape != tensorflow_param.shape:
            pytorch_param = pytorch_param.transpose()
            if pytorch_param.shape != tensorflow_param.shape:
                print(f"Difference found in layer: {name}. Different shapes.")
                return False
        # Check if the weights are the same
        if not np.allclose(pytorch_param, tensorflow_param, atol=1e-6):
            print(f"Difference found in layer: {name}. Weights are not the same.")
            return False
    print("All layers have the same shape and weights.")
    return True
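def max_weight_diff(model_torch, model_tf):
    """Optional diagnostic (a sketch, not in the original script): the largest
    absolute difference between matched parameters. Assumes the same parameter
    pairing and transpose convention as compare_models above; useful when the
    comparison fails and you want to see how far apart the weights actually are.
    """
    diffs = []
    for param, tf_var in zip(model_torch.parameters(), model_tf.trainable_variables):
        p = param.detach().numpy()
        t = tf_var.numpy()
        if p.shape != t.shape:
            p = p.T
        diffs.append(np.max(np.abs(p - t)))
    return max(diffs)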
# %%
# Make some data
num_batches = 8
xc_list_torch, yc_list_torch, xt_list_torch, yt_list_torch = [], [], [], []
for batch in range(num_batches):
    xc = torch.randn(16, 1, 10)
    xt = torch.randn(16, 1, 15)
    xc_list_torch.append(xc)  # Context inputs
    xt_list_torch.append(xt)  # Target inputs
    yc_list_torch.append(x_to_y(xc))  # Context outputs
    yt_list_torch.append(x_to_y(xt))  # Target outputs
# Construct models
gnp_torch = nps_torch.construct_agnp(dim_x=1, dim_y=2, likelihood="het")
gnp_tf = nps_tf.construct_agnp(dim_x=1, dim_y=2, likelihood="het")

# Put some data through the tf model first so that its weights get built
# (Keras creates weights lazily on the first call)
_ = gnp_tf(
    tf.convert_to_tensor(xc_list_torch[0].numpy()),
    tf.convert_to_tensor(yc_list_torch[0].numpy()),
    tf.convert_to_tensor(xt_list_torch[0].numpy()),
)
# SGD optimizers that I have tested to be equivalent
opt_torch = torch.optim.SGD(gnp_torch.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
opt_tf = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0, nesterov=False)
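# One-step equivalence check for the optimizers (illustrative, not part of the
# original script): with no momentum, plain SGD is just w -= lr * grad in both
# frameworks, so a toy variable should end up with identical values.
w_check_torch = torch.nn.Parameter(torch.ones(3))
opt_check_torch = torch.optim.SGD([w_check_torch], lr=0.01)
(w_check_torch.sum() ** 2).backward()
opt_check_torch.step()

w_check_tf = tf.Variable(tf.ones(3))
opt_check_tf = tf.keras.optimizers.SGD(learning_rate=0.01)
with tf.GradientTape() as tape_check:
    loss_check = tf.reduce_sum(w_check_tf) ** 2
grads_check = tape_check.gradient(loss_check, [w_check_tf])
opt_check_tf.apply_gradients(zip(grads_check, [w_check_tf]))

assert np.allclose(w_check_torch.detach().numpy(), w_check_tf.numpy())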
# %%
# Copy weights and biases
copy_weights_and_biases(gnp_torch, gnp_tf)
assert compare_models(gnp_torch, gnp_tf)
# %%
# Training loop of 5 actual epochs, plus one warmup epoch at the start that
# only evaluates the untrained models, to check that the losses are similar
# from the initial weights
num_epochs = 5
epochs_loss_torch = []
epochs_loss_tf = []
for epoch in range(num_epochs + 1):
    this_epoch_loss_torch = []
    this_epoch_loss_tf = []
    for batch in range(num_batches):
        # Torch version
        xc_torch = xc_list_torch[batch]
        yc_torch = yc_list_torch[batch]
        xt_torch = xt_list_torch[batch]
        yt_torch = yt_list_torch[batch]
        loss_torch = -torch.mean(nps_torch.loglik(gnp_torch, xc_torch, yc_torch, xt_torch, yt_torch, normalise=True))
        if epoch > 0:
            opt_torch.zero_grad(set_to_none=True)
            loss_torch.backward()
            opt_torch.step()
        this_epoch_loss_torch.append(loss_torch.detach().numpy())

        # Tensorflow version with the same data
        xc_tf = tf.convert_to_tensor(xc_torch.numpy())
        yc_tf = tf.convert_to_tensor(yc_torch.numpy())
        xt_tf = tf.convert_to_tensor(xt_torch.numpy())
        yt_tf = tf.convert_to_tensor(yt_torch.numpy())
        with tf.GradientTape() as tape:
            # Compute the loss inside the tape so gradients can be taken
            loss_tf = -tf.reduce_mean(nps_tf.loglik(gnp_tf, xc_tf, yc_tf, xt_tf, yt_tf, normalise=True))
        if epoch > 0:
            gradients = tape.gradient(loss_tf, gnp_tf.trainable_variables)
            opt_tf.apply_gradients(zip(gradients, gnp_tf.trainable_variables))
        this_epoch_loss_tf.append(loss_tf.numpy())
    # Collate the mean loss per epoch
    epochs_loss_torch.append(np.mean(this_epoch_loss_torch).round(3))
    epochs_loss_tf.append(np.mean(this_epoch_loss_tf).round(3))
print(f"Torch losses:\n{epochs_loss_torch}")
print(f"TF losses:\n{epochs_loss_tf}")