Skip to content

Instantly share code, notes, and snippets.

@PhilipVinc
Last active October 25, 2022 15:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhilipVinc/4f852ccbd1355d488d9b23e9ca595a19 to your computer and use it in GitHub Desktop.
Save PhilipVinc/4f852ccbd1355d488d9b23e9ca595a19 to your computer and use it in GitHub Desktop.
crash pennylane
import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
# Number of qubits in the toy circuit; minimal_circ sizes its device wires
# from this value.
phys_qubits = 2
# Circuit parameters (minimal_circ reads params[0] and params[1] as RY angles).
# A seeded Generator, rather than the unseeded global np.random state, makes
# the reproducer deterministic: every run differentiates at the same point,
# so the crash (or its absence) reproduces across runs and machines.
pars_q = np.random.default_rng(0).random(phys_qubits)
def minimal_circ(params, prng_key=None):
    """Evaluate <Z0 Z1> of a two-qubit RY circuit on a freshly built JAX device.

    A new ``default.qubit.jax`` device (1000 shots, wires ``0..phys_qubits-1``)
    is constructed on every call. ``prng_key`` is forwarded to it only when
    given.

    Args:
        params: indexable holding at least two angles; ``params[0]`` and
            ``params[1]`` are the RY rotations applied to wires 0 and 1.
        prng_key: optional JAX PRNG key passed to the device (presumably used
            for shot sampling — see the device docs). When ``None``, the
            ``prng_key`` keyword is not passed at all.

    Returns:
        The QNode's ``qml.expval`` of the single-term Hamiltonian
        ``1.0 * Z(0) @ Z(1)``. The value and its type are also printed,
        which shows tracers when called under ``jax.grad`` / ``jax.jit``.
    """
    # Omit the keyword entirely (rather than passing None) when no key is given.
    device_kwargs = {} if prng_key is None else {"prng_key": prng_key}
    dev = qml.device(
        "default.qubit.jax",
        wires=tuple(range(phys_qubits)),
        shots=1000,
        **device_kwargs,
    )

    @qml.qnode(dev, interface="jax", diff_method="parameter-shift")
    def _measure_operator():
        qml.RY(params[0], wires=0)
        qml.RY(params[1], wires=1)
        zz_observable = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(zz_observable)

    result = _measure_operator()
    print("res is: ", result, type(result))
    return result
# jax.grad differentiates w.r.t. the first positional argument (the circuit
# parameters) by default; the optional PRNG key is passed through as-is.
grad_fun = jax.grad(minimal_circ)
key = jax.random.PRNGKey(0)

# Eager (not jitted): both variants work.
grad_fun(pars_q)
grad_fun(pars_q, key)

# Jitted without a PRNGKey: works, but is useless because the result will be
# a constant.
jax.jit(grad_fun)(pars_q)

# Jitted with a PRNGKey: this is the call that crashes.
jax.jit(grad_fun)(pars_q, key)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment