
@Sinestro38

Sinestro38/mnew.py Secret

Created May 22, 2021
Trying to calculate gradients of a joint model
# # Calculating gradients of a joint model
import tensorflow as tf
import pennylane as qml
from pennylane import numpy as np
wires = 1
dev = qml.device('default.qubit', wires=wires)
# ### Parameterized quantum circuit with one RY gate
@qml.qnode(dev, interface='tf')
def gen_circuit(w):
    qml.RY(w[0], wires=0)
    return qml.expval(qml.PauliZ(0))
# ### Simple single neuron NN
def make_nn():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(units=1, activation='sigmoid', input_shape=[1]))
    return model
nn = make_nn()
"""Initializing a random weight to feed into gen_circuit"""
init_gen_weights = np.random.normal(0, 1, 1)
init_gen_weights = tf.convert_to_tensor(init_gen_weights)
# Input tensors to `nn` must be of shape (1, 1), but the output tensor of `gen_circuit` has shape (1,). So to test `nn`, we first reshape the expectation-value tensor returned by `gen_circuit()` to (1, 1).
nn(tf.reshape(gen_circuit(init_gen_weights), (1, 1)))
# ### Evaluating gradient of joint model with respect to gen_weights
with tf.GradientTape() as tape:
    gen_exp_val = gen_circuit(init_gen_weights)
    gen_exp_val = tf.reshape(gen_exp_val, (1, 1))
    output = nn(gen_exp_val)
gen_grad = tape.gradient(output, init_gen_weights)
print(gen_grad)
# ^^ returns `None` :(

@Sinestro38 Sinestro38 commented May 22, 2021

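Fixed! What I learned here is that the second positional argument to tape.gradient (the sources) must be something the tape is watching. A trainable tf.Variable is watched automatically; a plain tf.Tensor is not, unless you call tape.watch on it inside the tape context. Here, it returned None because init_gen_weights was converted to a tf.Tensor with tf.convert_to_tensor when it should have been converted to a tf.Variable. With that fix, it returns valid gradients!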

Before: init_gen_weights = tf.convert_to_tensor(init_gen_weights)
After: init_gen_weights = tf.Variable(init_gen_weights)
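
For anyone who wants to see the difference in isolation, here is a minimal sketch that uses TensorFlow only. The function circuit_stand_in is hypothetical and not part of the gist: it replaces gen_circuit with tf.cos, on the assumption that RY(w) applied to |0> and measured in PauliZ gives cos(w). It compares an unwatched tensor, a tensor registered with tape.watch, and a tf.Variable.

```python
import tensorflow as tf


def circuit_stand_in(w):
    # Hypothetical classical stand-in for gen_circuit (assumption: <Z> = cos(w)).
    return tf.cos(w)


w_tensor = tf.constant([0.5], dtype=tf.float64)
w_variable = tf.Variable([0.5], dtype=tf.float64)

# Case 1: a plain tensor that the tape never watches.
with tf.GradientTape() as tape:
    out = circuit_stand_in(w_tensor)
grad_unwatched = tape.gradient(out, w_tensor)

# Case 2: the same tensor, explicitly registered with tape.watch.
with tf.GradientTape() as tape:
    tape.watch(w_tensor)
    out = circuit_stand_in(w_tensor)
grad_watched = tape.gradient(out, w_tensor)

# Case 3: a trainable tf.Variable, which the tape watches automatically.
with tf.GradientTape() as tape:
    out = circuit_stand_in(w_variable)
grad_variable = tape.gradient(out, w_variable)

print(grad_unwatched)  # None, because the tensor was never watched
print(grad_watched)
print(grad_variable)
```

Cases 2 and 3 should both give -sin(0.5). In the gist itself tf.Variable is the more natural choice, since the weights are meant to be updated by an optimizer afterwards.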
