Skip to content

Instantly share code, notes, and snippets.

@maciejkorzepa
Last active April 17, 2020 16:38
Show Gist options
  • Save maciejkorzepa/23bbb4443c2fcb93927eb36d9ef2091c to your computer and use it in GitHub Desktop.
Save maciejkorzepa/23bbb4443c2fcb93927eb36d9ef2091c to your computer and use it in GitHub Desktop.
Code to reproduce a slow empirical NTK computation in neural-tangents.
import time

import jax
import neural_tangents as nt
import numpy as np
from jax.experimental import stax
# Side length of the square input images. Defined before the model because the
# final global-pooling window below is derived from it.
d = 512
num_base_out_chan = 32

# The network halves the spatial size five times: once in the stride-2 stem
# conv and once in each of the four stride-2 MaxPools. With 'SAME' padding a
# stride-2 op maps a side of length s to ceil(s / 2), so the trunk ends in a
# final_hw x final_hw feature map (16 x 16 for d = 512). Deriving this instead
# of hard-coding (16, 16) keeps the head's pooling global if `d` is changed.
num_stride2_stages = 5
final_hw = d
for _ in range(num_stride2_stages):
    final_hw = -(-final_hw // 2)  # integer ceil(final_hw / 2)
global_pool_window = (final_hw, final_hw)

# 3x3, stride-2, 'SAME' max pool placed between stages. stax layers are plain
# (init_fn, apply_fn) pairs (stax.Relu is already reused), so one instance can
# appear several times in stax.serial.
_downsample = stax.MaxPool(window_shape=(3, 3), strides=(2, 2), padding='SAME')


def _conv_relu(out_chan, strides=(1, 1)):
    """Return a 3x3 'SAME'-padded stax.Conv followed by stax.Relu.

    The two layers are returned as a tuple so they can be star-unpacked
    directly into stax.serial(...).
    """
    return (
        stax.Conv(out_chan, filter_shape=(3, 3), strides=strides, padding='SAME'),
        stax.Relu,
    )


init_fn, apply_fn = stax.serial(
    # Stem: stride-2 conv, then pool.
    *_conv_relu(num_base_out_chan, strides=(2, 2)),
    _downsample,
    # Stage 1: two convs at the base width.
    *_conv_relu(num_base_out_chan),
    *_conv_relu(num_base_out_chan),
    _downsample,
    # Stage 2: two convs at 2x width.
    *_conv_relu(num_base_out_chan * 2),
    *_conv_relu(num_base_out_chan * 2),
    _downsample,
    # Stage 3: four convs at 4x width.
    *_conv_relu(num_base_out_chan * 4),
    *_conv_relu(num_base_out_chan * 4),
    *_conv_relu(num_base_out_chan * 4),
    *_conv_relu(num_base_out_chan * 4),
    _downsample,
    # Stage 4: four convs at 8x width.
    *_conv_relu(num_base_out_chan * 8),
    *_conv_relu(num_base_out_chan * 8),
    *_conv_relu(num_base_out_chan * 8),
    *_conv_relu(num_base_out_chan * 8),
    # Head: concatenate max- and average-pooled features over the whole final
    # feature map, then a single scalar Dense readout. Relies on stax pooling
    # defaults (stride 1, 'VALID' padding) so a window equal to the feature-map
    # size yields a 1x1 output.
    stax.FanOut(2),
    stax.parallel(
        stax.serial(stax.MaxPool(window_shape=global_pool_window), stax.Flatten),
        stax.serial(stax.AvgPool(window_shape=global_pool_window), stax.Flatten),
    ),
    stax.FanInConcat(),
    stax.Dense(1),
)

key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, d, d, 3))

num_train = 100
batch_size = 10
# Seeded so the repro is reproducible; drawing directly in float32 avoids the
# ~630 MB float64 temporary that np.random.randn(...).astype(np.float32) makes.
x_train = np.random.default_rng(0).standard_normal(
    (num_train, d, d, 3), dtype=np.float32)

ntk = nt.batch(jax.jit(nt.empirical_ntk_fn(apply_fn)),
               batch_size=batch_size, device_count=1)

start = time.perf_counter()
kernel = ntk(x_train, None, params)
# JAX dispatches work asynchronously; block so the elapsed time covers the
# actual computation. The first call also includes jit compilation.
# NOTE(review): assumes nt.batch returns a device array (its default
# store_on_device behaviour) — confirm for the installed version.
kernel.block_until_ready()
elapsed = time.perf_counter() - start
print(f'empirical NTK {kernel.shape} computed in {elapsed:.1f} s')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment