import os
import numpy as np

# CUDA-related environment variables must be set before jax is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/opt/cuda/cuda-10.2'
import jax
from jax import jacobian, jvp, vjp  # public API; `jax.api` is an internal module
import jax.numpy as jnp
from jax.experimental import stax  # moved to jax.example_libraries.stax in newer JAX
# A small MLP assembled from three stax blocks, so that each block's
# apply function and its parameters can be handled separately below.
f1 = stax.serial(stax.Dense(10), stax.Relu)
f2 = stax.serial(stax.Dense(20), stax.Relu)
f3 = stax.serial(stax.Dense(1))
init_fn, apply_fn = stax.serial(f1, f2, f3)

key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, 10))  # input shape: (batch, 10)
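
# (Added for clarity:) each Dense contributes a (W, b) pair and each Relu an
# empty tuple, so params[0], params[1], params[2] are the per-block parameter
# lists matched up with f1, f2, f3 below.
print(jax.tree_util.tree_map(lambda a: a.shape, params))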
# Pair each block's apply function with its parameters.
layers = (
    (f1[1], params[0]),
    (f2[1], params[1]),
    (f3[1], params[2]),
)
x = jnp.array(np.random.randn(5, 10))

# Forward pass, recording the input seen by every block.
inputs = [x]
for layer, ps in layers:
    inputs.append(layer(ps, inputs[-1]))
# Cotangent seeding the chain rule: ones on the network output.
output_grad = np.ones(inputs[-1].shape)

# Walk the layers backwards. For layer l with parameters ps, input inp and
# output out, compute J_l J_l^T where J_l = d(out)/d(ps).
for (layer, ps), inp, out in zip(layers[::-1], inputs[-2::-1], inputs[:0:-1]):
    f = lambda p: layer(p, inp)

    def delta_vjp_jvp(delta):
        # delta -> J_l^T delta (vjp), then J_l (J_l^T delta) (jvp).
        params_cotangent = vjp(f, ps)[1](delta)
        return jvp(f, (ps,), params_cotangent)[1]

    # Differentiating the linear map delta -> J_l J_l^T delta recovers the
    # matrix itself, with shape out.shape + out.shape.
    dummy = np.zeros(out.shape)
    jjt_layer = jacobian(delta_vjp_jvp)(dummy)
    print(jjt_layer.shape)
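    # --- Sanity check (an addition, not part of the original gist). The
    # vjp-then-jvp construction should agree with an explicit Jacobian
    # product: flatten J_l to (out.size, n_params) and compare its Gram
    # matrix against jjt_layer.
    jac_tree = jacobian(f)(ps)
    jac_flat = jnp.concatenate(
        [j.reshape(out.size, -1) for j in jax.tree_util.tree_leaves(jac_tree)],
        axis=1)
    jjt_explicit = (jac_flat @ jac_flat.T).reshape(jjt_layer.shape)
    assert np.allclose(jjt_layer, jjt_explicit, atol=1e-4)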
    # Accumulate partial(f_L)/partial(f_l) for the layer l handled in the
    # next iteration (not used yet): pull output_grad back through this
    # layer with respect to its *input* rather than its parameters.
    vjp_fun = vjp(lambda a: layer(ps, a), inp)[1]
    output_grad = vjp_fun(output_grad)[0]
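
# --- Hedged sketch (an addition, not part of the original gist). The
# output_grad pullback above builds up partial(f_L)/partial(f_l), which would
# let each per-layer J_l J_l^T be lifted to the network output; summed over
# layers, those lifted blocks form the empirical NTK of the whole network.
# A dense reference computation of that NTK from the full parameter Jacobian:
y = apply_fn(params, x)
jac_net = jacobian(lambda p: apply_fn(p, x))(params)
jac_net_flat = jnp.concatenate(
    [j.reshape(y.size, -1) for j in jax.tree_util.tree_leaves(jac_net)],
    axis=1)
ntk = jac_net_flat @ jac_net_flat.T  # (5, 5): one entry per pair of examples
print(ntk.shape)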