Skip to content

Instantly share code, notes, and snippets.

@cycyyy
Created October 20, 2022 04:14
Show Gist options
  • Save cycyyy/db2cd361bb96f275fca4ad11595060e6 to your computer and use it in GitHub Desktop.
Save cycyyy/db2cd361bb96f275fca4ad11595060e6 to your computer and use it in GitHub Desktop.
tiling_test.py
import numpy as np
from jax import random
from neural_tangents import stax
# Seed for the fixed PRNG key, so every run draws identical samples.
_PRNG_SEED = 42
random_key = random.PRNGKey(_PRNG_SEED)

# Number of samples generated for each test case.
SAMPLE_SIZE = 100
# Tile edge length; the tests use it for both x1_batch and x2_batch.
BATCH_SIZE = 25
def get_mlp_kernel_fn():
    """Build the kernel function of a small fully-connected ReLU network.

    Layers: Dense(512) -> Relu -> Dense(512) -> Relu -> Dense(1).

    Returns:
        The third element (``kernel_fn``) of the triple returned by
        ``stax.serial``; the init and apply functions are not used here.
    """
    layers = [
        stax.Dense(512), stax.Relu(),
        stax.Dense(512), stax.Relu(),
        stax.Dense(1),
    ]
    _init_fn, _apply_fn, kernel_fn = stax.serial(*layers)
    return kernel_fn
def get_conv_kernel_fn():
    """Build the kernel function of a small convolutional network.

    Layers: Conv(128, 3x3) -> Relu -> Conv(256, 3x3) -> Relu -> Conv(512, 3x3)
    -> Flatten -> Dense(2). The third convolution feeds Flatten directly with
    no activation in between.
    NOTE(review): confirm the missing final Relu is intentional.

    Returns:
        The ``kernel_fn`` element of the triple returned by ``stax.serial``.
    """
    conv_stack = (
        stax.Conv(128, (3, 3)), stax.Relu(),
        stax.Conv(256, (3, 3)), stax.Relu(),
        stax.Conv(512, (3, 3)),
    )
    head = (stax.Flatten(), stax.Dense(2))
    # stax.serial returns (init_fn, apply_fn, kernel_fn); only the last is needed.
    return stax.serial(*conv_stack, *head)[2]
def get_kernel(x1, x2, kernel_fn):
    """Evaluate ``kernel_fn`` on the full inputs in a single call.

    This is the untiled reference that ``get_kernel_batch`` is compared against.

    Returns:
        Whatever ``kernel_fn(x1, x2, ('nngp', 'ntk'))`` returns; the tests
        unpack it as an ``(nngp, ntk)`` pair.
    """
    requested = ('nngp', 'ntk')
    return kernel_fn(x1, x2, requested)
def get_kernel_batch(x1, x2, kernel_fn, x1_batch, x2_batch):
    """Compute NNGP and NTK kernels tile by tile and stitch the tiles together.

    ``x1`` is split into row blocks of ``x1_batch`` samples and ``x2`` into
    column blocks of ``x2_batch`` samples. ``kernel_fn`` is called once per
    (row block, column block) pair with ``('nngp', 'ntk')``. Each result is
    written into the matching slice of the output matrices, so no single call
    sees more than ``x1_batch`` x ``x2_batch`` sample pairs.

    The last block along either axis may be shorter than the batch size, so the
    sample counts do not need to be multiples of the batch sizes.

    Args:
        x1: Samples indexed along the first axis (anything supporting ``len``
            and slicing).
        x2: Same, for the second kernel argument.
        kernel_fn: Callable ``kernel_fn(x1_tile, x2_tile, ('nngp', 'ntk'))``
            returning an ``(nngp, ntk)`` pair. Each array must fit a
            ``(len(x1_tile), len(x2_tile))`` slice.
        x1_batch: Positive number of ``x1`` samples per tile.
        x2_batch: Positive number of ``x2`` samples per tile.

    Returns:
        ``(kernel_nngp, kernel_ntk)``: two float64 NumPy arrays of shape
        ``(len(x1), len(x2))``.

    Raises:
        ValueError: If either batch size is not positive.
    """
    for name, size in (("x1_batch", x1_batch), ("x2_batch", x2_batch)):
        # A zero step makes range() raise, and a negative step makes it empty.
        # That would silently return all-zero kernels, so reject both up front.
        if size <= 0:
            raise ValueError("%s must be positive, got %r" % (name, size))

    n1, n2 = len(x1), len(x2)
    kernel_nngp = np.zeros((n1, n2))
    kernel_ntk = np.zeros((n1, n2))
    for i in range(0, n1, x1_batch):
        # Slices past the end are clipped, so the final tile may be shorter;
        # the target slice below is clipped identically, so shapes still match.
        x1_tile = x1[i:i + x1_batch]
        for j in range(0, n2, x2_batch):
            x2_tile = x2[j:j + x2_batch]
            nngp_tile, ntk_tile = kernel_fn(x1_tile, x2_tile, ('nngp', 'ntk'))
            kernel_nngp[i:i + x1_batch, j:j + x2_batch] = nngp_tile
            kernel_ntk[i:i + x1_batch, j:j + x2_batch] = ntk_tile
    return kernel_nngp, kernel_ntk
def _assert_tiling_matches(samples, kernel_fn):
    """Assert that tiled kernel computation reproduces the one-shot result.

    Computes the NNGP/NTK kernels of ``samples`` against themselves once with
    ``get_kernel``, and once with ``get_kernel_batch`` using BATCH_SIZE x
    BATCH_SIZE tiles. Then checks that both agree under ``np.allclose``.
    """
    nngp_gt, ntk_gt = get_kernel(samples, samples, kernel_fn)
    nngp_tile, ntk_tile = get_kernel_batch(
        samples, samples, kernel_fn, BATCH_SIZE, BATCH_SIZE)
    # Plain `assert expr, msg` (not `assert(expr)`): a message placed inside
    # parentheses would form a non-empty tuple that is always truthy.
    assert np.allclose(nngp_gt, nngp_tile), \
        "tiled NNGP kernel differs from full computation"
    assert np.allclose(ntk_gt, ntk_tile), \
        "tiled NTK kernel differs from full computation"


def test_mlp():
    """Check tiled vs. full kernels for the MLP on (SAMPLE_SIZE, 4) inputs."""
    print("test mlp")
    samples = random.normal(random_key, (SAMPLE_SIZE, 4))
    _assert_tiling_matches(samples, get_mlp_kernel_fn())


def test_conv():
    """Check tiled vs. full kernels for the conv net on (SAMPLE_SIZE, 16, 16, 3) inputs."""
    print("test conv")
    samples = random.normal(random_key, (SAMPLE_SIZE, 16, 16, 3))
    _assert_tiling_matches(samples, get_conv_kernel_fn())
if __name__ == "__main__":
    # Run the checks only when executed as a script. Importing the module, e.g.
    # from a runner that collects test_* functions itself (this file name
    # matches pytest's *_test.py pattern), no longer runs them as a side
    # effect, which would otherwise execute every check twice.
    test_mlp()
    test_conv()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment