@daskol
Created December 22, 2022 09:23
Non-jittable initialization of a model in JAX/Flax
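A minimal usage sketch follows, assuming the files below are saved as cnn.py and phase_diagram.py; the filenames and the input shape are assumptions, similar to test_cnn further down. Initialization calls the host-side phase_boundary from setup, so the SciPy quadrature runs on the host.

import jax
import jax.numpy as jnp

from cnn import CNN  # assumed filename for the module below

model = CNN(depth=16)
key = jax.random.PRNGKey(0)
batch = jnp.empty((5, 28, 28, 1))  # dummy NHWC input, MNIST-like shape
variables = model.init(key, batch, train=False)  # runs phase_boundary on the host
logits = model.apply(variables, batch, train=False)  # shape (5, 10)
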
from typing import Any, Callable
import flax.linen as nn
import jax
import jax.experimental.host_callback
import jax.numpy as jnp
from flax.linen.initializers import delta_orthogonal, variance_scaling
from phase_diagram import phase_boundary
InitializerTy = Callable[[Any, tuple[int, ...], Any], jax.Array]
class CNN(nn.Module):
    """Class CNN implements a model consisting of plain convolution layers.

    The only difference is that this model is designed to be ultra-deep with a
    specific initialization procedure: delta-orthogonal initialization.
    """

    depth: int = 16
    features: int = 10
    channels: int = 32
    kernel_size: int = 3
    dtype: Any = jnp.float32
    kernel_init: InitializerTy = delta_orthogonal
    bias_init: InitializerTy = delta_orthogonal

    def setup(self):
        # Make a fake (empty) `batch_stats` subtree in order to reuse the
        # common pipeline.
        self.variable('batch_stats', 'mean', lambda _: (), ())
        self.variable('batch_stats', 'var', lambda _: (), ())

        # Estimate critical scales for weights and biases.
        # NOTE(@daskol): We use an experimental feature here.
        # self.varw, self.varb = jax.experimental.host_callback.call(
        #     callback_func=phase_boundary,
        #     arg=1 / self.depth,
        #     result_shape=jax.ShapeDtypeStruct((2, ), jnp.float32))
        self.varw, self.varb = phase_boundary(1 / self.depth)

        # We need LeCun initialization for kernels and biases.
        def default_init(var):
            return variance_scaling(var, 'fan_in', 'normal')

        def bias_init(var):
            def init_fn(key, shape, dtype):
                return jnp.sqrt(var) * jax.random.normal(key, shape, dtype)
            return init_fn

        # Network input: increase channels.
        layer = nn.Conv(kernel_size=(self.kernel_size, self.kernel_size),
                        features=self.channels,
                        padding='SAME',
                        use_bias=True,
                        kernel_init=default_init(self.varw),
                        bias_init=bias_init(self.varb),
                        dtype=self.dtype)

        # Network input: reduce spatial dims.
        layers = [layer]
        for _ in range(2):
            layer = nn.Conv(features=self.channels,
                            strides=2,
                            kernel_size=(self.kernel_size, self.kernel_size),
                            padding='SAME',
                            use_bias=True,
                            kernel_init=default_init(self.varw),
                            bias_init=bias_init(self.varb),
                            dtype=self.dtype)
            layers.append(layer)
        self.input = layers

        # Network body.
        layers = []
        for _ in range(self.depth):
            # TODO(@daskol): Fix padding. Implement circular_padding func?
            layer = nn.Conv(kernel_size=(self.kernel_size, self.kernel_size),
                            features=self.channels,
                            padding='CIRCULAR',
                            use_bias=True,
                            kernel_init=self.kernel_init(self.varw),
                            bias_init=bias_init(self.varb),
                            dtype=self.dtype)
            layers.append(layer)
        self.conv = layers

        # Neural network head.
        self.head = nn.Dense(features=self.features)

    def __call__(self, xs: jax.Array, train: bool = True) -> jax.Array:
        return self.predict(xs)

    def predict(self, xs: jax.Array, train: bool = True) -> jax.Array:
        # Subsample and increase the number of feature maps on input.
        for layer in self.input:
            xs = layer(xs)
            xs = nn.tanh(xs)

        # Apply convolutions in the body of the network.
        for layer in self.conv:
            xs = layer(xs)
            xs = nn.tanh(xs)

        # Finally, apply the head of the network (pooling + dense).
        ys = xs.mean(axis=(1, 2))
        ys = self.head(ys)
        return ys

    def predict_proba(self, xs: jax.Array) -> jax.Array:
        return nn.softmax(self.predict(xs))
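
Since predict_proba is not the module's __call__, it has to be invoked through Flax's apply with the method argument. A brief sketch, reusing model, variables, and batch from the usage sketch above:

probs = model.apply(variables, batch, method=CNN.predict_proba)  # softmax over 10 classes
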
import jax
import jax.experimental.host_callback
import jax.numpy as jnp
from numpy.testing import assert_array_almost_equal

from cnn import CNN, phase_boundary


def test_variance_estimation():
    depth = 16
    expected = phase_boundary(1 / depth)
    actual = jax.experimental.host_callback.call(
        callback_func=phase_boundary,
        arg=1 / depth,
        result_shape=jax.ShapeDtypeStruct((2, ), jnp.float32))
    assert_array_almost_equal(expected, actual)


def test_cnn():
    model = CNN(depth=1)
    key = jax.random.PRNGKey(42)
    inp = jnp.empty((5, 28, 28, 1))
    out, state = jax.jit(model.init_with_output)(key, inp, train=False)
    assert out.shape == (5, 10)
    assert 'batch_stats' in state
#!/usr/bin/env python3
from typing import Callable
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import matplotlib.pyplot as plt
from scipy.integrate import quad


def phase_boundary(qs: jax.Array,
                   act_fn: Callable[..., jax.Array] = jax.nn.tanh,
                   chi: float = 1.0) -> tuple[jax.Array, jax.Array]:
    """Function phase_boundary calculates points on the boundary between the
    stable and unstable phases. The points are the variance of kernel weights
    and the variance of bias weights at initialization.
    """
    # TODO(@daskol): We tried to implement everything in pure JAX with
    # vectorization out of the box, but it turns out that there is no adaptive
    # integration routine (like scipy.integrate.quad) in JAX. So we fall back
    # to SciPy.
    def int_magnitude(qs: jax.Array) -> jax.Array:
        @jax.jit
        def fn(xs, qs):
            ps = jsp.stats.norm.pdf(xs)
            ys = act_fn(jnp.sqrt(qs) * xs)
            return ps * ys**2
        return jnp.array([quad(fn, -jnp.inf, jnp.inf, args=q)[0] for q in qs])

    def int_correlation(qs: jax.Array) -> jax.Array:
        @jax.jit
        def fn(xs, qs):
            ys = jax.grad(act_fn)(jnp.sqrt(qs) * xs)
            ps = jsp.stats.norm.pdf(xs)
            res = ps * ys**2
            return res
        return jnp.array([quad(fn, -jnp.inf, jnp.inf, args=q)[0] for q in qs])

    qs = jnp.asarray(qs)
    if qs.ndim == 0:
        return phase_boundary(qs[None], act_fn, chi).squeeze()
    points = jnp.empty((2, ) + qs.shape)
    points = points.at[0, ...].set(chi / int_correlation(qs))
    points = points.at[1, ...].set(qs - points[0, ...] * int_magnitude(qs))
    return points


if __name__ == '__main__':
    # Calculate boundary.
    jax.config.update('jax_enable_x64', True)
    qs = jnp.linspace(0, 100, 200)
    xs, ys = phase_boundary(qs)

    # Plot boundary and phases.
    fig, ax = plt.subplots()
    ax.plot(xs, ys, color='black', label='boundary')
    ax.fill_between(xs, ys, ys.max(), facecolor='red', alpha=0.2,
                    label='stable')
    ax.fill_between(xs, ys, 0, facecolor='blue', alpha=0.2, label='unstable')
    ax.grid()
    ax.legend()
    ax.set_xlabel(r'$\sigma^2_w$')
    ax.set_ylabel(r'$\sigma^2_b$')
    plt.savefig('phase-diagram.png')
    plt.show()