@DrJonnyT
Last active April 29, 2024 08:48
Tensorflow Pytorch comparison using neuralprocesses library
"""Script to compare the tensorflow and pytorch implementations of neuralprocesses.
The training losses printed at the end should be very similar, the differences are
thought to be due to different random seeds.
This script currently works for gnp, agnp, and convgnp architectures (defined
on L93 & 94).
"""
import tensorflow as tf
import torch
import numpy as np
import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf
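# Added for repeatability (not in the original gist): fixing the seeds makes
# reruns deterministic. Note torch and tf still draw from different random
# streams, so the two losses are expected to stay close rather than identical.
torch.manual_seed(0)
tf.random.set_seed(0)
np.random.seed(0)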
def x_to_y(x):
    """Dummy function to make learnable y data from random x data"""
    shape = x.shape
    y = torch.randn(shape[0], 2, shape[2])
    y[:, 0, :] *= 2
    y[:, 1, :] *= 3
    y = y + torch.randn_like(y) * 0.1
    return y
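# Quick illustrative shape check (an addition to the original script): x_to_y
# keeps the batch and point dimensions of x but returns two output channels.
assert x_to_y(torch.randn(4, 1, 7)).shape == (4, 2, 7)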
def copy_weights_and_biases(model_torch, model_tf):
    """Copy weights from torch model to tf model"""
    weights_tf = model_tf.get_weights()
    weights_torch = [param.detach().numpy() for param in model_torch.parameters()]
    for i in range(len(weights_torch)):
        if weights_tf[i].shape == weights_torch[i].shape:
            weights_tf[i] = weights_torch[i]
        elif weights_tf[i].shape == weights_torch[i].T.shape:
            weights_tf[i] = weights_torch[i].T
    model_tf.set_weights(weights_tf)
    print("Weights and biases copied successfully from PyTorch model to TensorFlow model")
def compare_models(model_torch, model_tf):
    """Check that the weights and biases in a tf and torch model are all the same"""
    # Convert PyTorch model to state_dict (dictionary object)
    pytorch_state_dict = model_torch.state_dict()
    # Get TensorFlow model variables
    tensorflow_variables = model_tf.trainable_variables
    # Check if the number of layers is the same
    if len(pytorch_state_dict) != len(tensorflow_variables):
        print("The models have a different number of layers.")
        return False
    # Iterate over PyTorch model parameters
    for (name, param), tf_var in zip(pytorch_state_dict.items(), tensorflow_variables):
        # Convert PyTorch tensor to numpy array
        pytorch_param = param.detach().numpy()
        # Get corresponding TensorFlow variable
        tensorflow_param = tf_var.numpy()
        # Check if the shapes are the same, allowing for transposed dense kernels
        if pytorch_param.shape != tensorflow_param.shape:
            pytorch_param = pytorch_param.transpose()
            if pytorch_param.shape != tensorflow_param.shape:
                print(f'Difference found in layer: {name}. Different shapes.')
                return False
        # Check if the weights are the same
        if not np.allclose(pytorch_param, tensorflow_param, atol=1e-6):
            print(f'Difference found in layer: {name}. Weights are not the same.')
            return False
    print('All layers have the same shape and weights.')
    return True
# %%
# Make some data
num_batches = 8
xc_list_torch, yc_list_torch, xt_list_torch, yt_list_torch = [], [], [], []
for batch in range(num_batches):
    xc = torch.randn(16, 1, 10)
    xt = torch.randn(16, 1, 15)
    xc_list_torch.append(xc)  # Context inputs
    xt_list_torch.append(xt)  # Target inputs
    yc_list_torch.append(x_to_y(xc))  # Context outputs
    yt_list_torch.append(x_to_y(xt))  # Target outputs
# Construct models
gnp_torch = nps_torch.construct_agnp(dim_x=1, dim_y=2, likelihood="het")
gnp_tf = nps_tf.construct_agnp(dim_x=1, dim_y=2, likelihood="het")
# Put some data through the tf model first, so its weights are built before
# we try to copy into them
_ = gnp_tf(
    tf.convert_to_tensor(xc_list_torch[0].numpy()),
    tf.convert_to_tensor(yc_list_torch[0].numpy()),
    tf.convert_to_tensor(xt_list_torch[0].numpy()),
)
# SGD Optimizers that I have tested to be equivalent
opt_torch = torch.optim.SGD(gnp_torch.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
opt_tf = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0, nesterov=False)
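# Minimal sketch (an addition) of why these settings match: with momentum and
# weight decay disabled, both optimizers reduce to plain w -= lr * grad, so a
# single step on a toy parameter should agree across frameworks.
_w_torch = torch.nn.Parameter(torch.tensor([1.0]))
_opt_torch = torch.optim.SGD([_w_torch], lr=0.01)
(_w_torch ** 2).sum().backward()
_opt_torch.step()  # 1.0 - 0.01 * 2.0 = 0.98
_w_tf = tf.Variable([1.0])
with tf.GradientTape() as _tape:
    _loss_tf = tf.reduce_sum(_w_tf ** 2)
_grads = _tape.gradient(_loss_tf, [_w_tf])
tf.keras.optimizers.SGD(learning_rate=0.01).apply_gradients(zip(_grads, [_w_tf]))
assert np.allclose(_w_torch.detach().numpy(), _w_tf.numpy())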
# %%
# Copy weights and biases
copy_weights_and_biases(gnp_torch,gnp_tf)
assert compare_models(gnp_torch,gnp_tf)
# %%
# Training loop of 5 actual epochs, plus one warmup epoch at the start that
# only evaluates the untrained models, to check the losses are similar from the initial weights
num_epochs = 5
epochs_loss_torch = []
epochs_loss_tf = []
for epoch in range(num_epochs + 1):
    this_epoch_loss_torch = []
    this_epoch_loss_tf = []
    for batch in range(num_batches):
        # Torch version
        xc_torch = xc_list_torch[batch]
        yc_torch = yc_list_torch[batch]
        xt_torch = xt_list_torch[batch]
        yt_torch = yt_list_torch[batch]
        loss_torch = -torch.mean(nps_torch.loglik(gnp_torch, xc_torch, yc_torch, xt_torch, yt_torch, normalise=True))
        if epoch > 0:
            opt_torch.zero_grad(set_to_none=True)
            loss_torch.backward()
            opt_torch.step()
        this_epoch_loss_torch.append(loss_torch.detach().numpy())
        # Tensorflow version with the same data
        xc_tf = tf.convert_to_tensor(xc_torch.numpy())
        yc_tf = tf.convert_to_tensor(yc_torch.numpy())
        xt_tf = tf.convert_to_tensor(xt_torch.numpy())
        yt_tf = tf.convert_to_tensor(yt_torch.numpy())
        with tf.GradientTape() as tape:
            # Compute the loss inside the tape so gradients can be taken
            loss_tf = -tf.reduce_mean(nps_tf.loglik(gnp_tf, xc_tf, yc_tf, xt_tf, yt_tf, normalise=True))
        if epoch > 0:
            gradients = tape.gradient(loss_tf, gnp_tf.trainable_variables)
            opt_tf.apply_gradients(zip(gradients, gnp_tf.trainable_variables))
        this_epoch_loss_tf.append(loss_tf.numpy())
    # Collate the mean losses per epoch
    epochs_loss_torch.append(np.mean(this_epoch_loss_torch).round(3))
    epochs_loss_tf.append(np.mean(this_epoch_loss_tf).round(3))
print(f"Torch losses:\n{epochs_loss_torch}")
print(f"TF losses:\n{epochs_loss_tf}")
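# Optional programmatic check (an addition): the warmup-epoch losses come from
# identical weights and identical data, so they should agree to a loose tolerance.
assert np.isclose(epochs_loss_torch[0], epochs_loss_tf[0], atol=1e-2)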