Non-jittable initialization of a model in JAX/Flax
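The gist below works around a common problem: part of model initialization (here, a SciPy quadrature) cannot be traced by jit. One escape hatch is jax.experimental.host_callback.call, which invokes a host-side function from inside a jitted computation. Here is a minimal sketch of that trick, with a hypothetical host_area routine standing in for any non-jittable function (host_callback is an experimental API):

import jax
import jax.experimental.host_callback
import jax.numpy as jnp
import numpy as np
from scipy.integrate import quad


def host_area(radius) -> np.float32:
    # Adaptive quadrature lives in SciPy only, so it must run on the host.
    # Integrate the circle of the given radius; the result is cast to match
    # the declared result_shape dtype.
    value = quad(lambda x: 2 * (radius**2 - x**2)**0.5, -radius, radius)[0]
    return np.float32(value)


@jax.jit
def normalized(r):
    # Call back to the host; shape and dtype of the result must be declared.
    area = jax.experimental.host_callback.call(
        callback_func=host_area,
        arg=r,
        result_shape=jax.ShapeDtypeStruct((), jnp.float32))
    return area / (r * r)


print(normalized(jnp.float32(2.0)))  # Approximately pi.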
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.experimental.host_callback
import jax.numpy as jnp
from flax.linen.initializers import delta_orthogonal, variance_scaling

from phase_diagram import phase_boundary

InitializerTy = Callable[[Any, tuple[int, ...], Any], jax.Array]
class CNN(nn.Module):
    """Class CNN implements a model consisting of plain convolution layers.
    The only difference is that this model is designed to be ultra-deep with
    a specific initialization procedure: delta-orthogonal initialization.
    """

    depth: int = 16
    features: int = 10
    channels: int = 32
    kernel_size: int = 3
    dtype: Any = jnp.float32
    kernel_init: InitializerTy = delta_orthogonal
    bias_init: InitializerTy = delta_orthogonal
    def setup(self):
        # Make fake (empty) `batch_stats` subtree in order to reuse common
        # pipeline.
        self.variable('batch_stats', 'mean', lambda _: (), ())
        self.variable('batch_stats', 'var', lambda _: (), ())

        # Estimate critical scales for weights and biases.
        # NOTE(@daskol): We use experimental feature here.
        # self.varw, self.varb = jax.experimental.host_callback.call(
        #     callback_func=phase_boundary,
        #     arg=1 / self.depth,
        #     result_shape=jax.ShapeDtypeStruct((2, ), jnp.float32))
        self.varw, self.varb = phase_boundary(1 / self.depth)

        # LeCun-style (fan-in) initialization for kernels and plain Gaussian
        # initialization for biases.
        def default_init(var):
            return variance_scaling(var, 'fan_in', 'normal')

        def bias_init(var):
            def init_fn(key, shape, dtype):
                return jnp.sqrt(var) * jax.random.normal(key, shape, dtype)
            return init_fn
        # Network input: increase channels.
        layer = nn.Conv(kernel_size=(self.kernel_size, self.kernel_size),
                        features=self.channels,
                        padding='SAME',
                        use_bias=True,
                        kernel_init=default_init(self.varw),
                        bias_init=bias_init(self.varb),
                        dtype=self.dtype)

        # Network input: reduce spatial dims.
        layers = [layer]
        for _ in range(2):
            layer = nn.Conv(features=self.channels,
                            strides=2,
                            kernel_size=(self.kernel_size, self.kernel_size),
                            padding='SAME',
                            use_bias=True,
                            kernel_init=default_init(self.varw),
                            bias_init=bias_init(self.varb),
                            dtype=self.dtype)
            layers.append(layer)
        self.input = layers
        # Network body.
        layers = []
        for i in range(self.depth):
            # TODO(@daskol): Fix padding. Implement circular_padding func?
            layer = nn.Conv(kernel_size=(self.kernel_size, self.kernel_size),
                            features=self.channels,
                            padding='CIRCULAR',
                            use_bias=True,
                            kernel_init=self.kernel_init(self.varw),
                            bias_init=bias_init(self.varb),
                            dtype=self.dtype)
            layers.append(layer)
        self.conv = layers

        # Neural network head.
        self.head = nn.Dense(features=self.features)
    def __call__(self, xs: jax.Array, train: bool = True) -> jax.Array:
        return self.predict(xs, train)

    def predict(self, xs: jax.Array, train: bool = True) -> jax.Array:
        # Subsample and increase number of feature maps on input.
        for layer in self.input:
            xs = layer(xs)
            xs = nn.tanh(xs)
        # Apply convolutions in the body of network.
        for layer in self.conv:
            xs = layer(xs)
            xs = nn.tanh(xs)
        # Finally, apply head of neural network (pooling + dense).
        ys = xs.mean(axis=(1, 2))
        ys = self.head(ys)
        return ys

    def predict_proba(self, xs: jax.Array) -> jax.Array:
        return nn.softmax(self.predict(xs))
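For reference, a minimal usage sketch (assuming the listing above is saved as cnn.py next to phase_diagram.py, as the test below implies). Since phase_boundary runs on concrete Python values inside setup(), initialization also works eagerly, without jit:

import jax
import jax.numpy as jnp
from cnn import CNN

model = CNN(depth=4)
key = jax.random.PRNGKey(0)
xs = jnp.empty((1, 28, 28, 1))
variables = model.init(key, xs, train=False)      # Eager, non-jitted init.
logits = model.apply(variables, xs, train=False)
print(logits.shape)                               # (1, 10)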
import jax
import jax.experimental.host_callback
import jax.numpy as jnp
from numpy.testing import assert_array_almost_equal

from cnn import CNN, phase_boundary


def test_variance_estimation():
    depth = 16
    expected = phase_boundary(1 / depth)
    actual = jax.experimental.host_callback.call(
        callback_func=phase_boundary,
        arg=1 / depth,
        result_shape=jax.ShapeDtypeStruct((2, ), jnp.float32))
    assert_array_almost_equal(expected, actual)


def test_cnn():
    model = CNN(depth=1)
    key = jax.random.PRNGKey(42)
    inp = jnp.empty((5, 28, 28, 1))
    out, state = jax.jit(model.init_with_output)(key, inp, train=False)
    assert out.shape == (5, 10)
    assert 'batch_stats' in state
#!/usr/bin/env python3

from typing import Callable

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import matplotlib.pyplot as plt
from scipy.integrate import quad


def phase_boundary(qs: jax.Array,
                   act_fn: Callable[..., jax.Array] = jax.nn.tanh,
                   chi: float = 1.0) -> tuple[jax.Array, jax.Array]:
"""Function phase_boundary calculates points on boundary between stable and | |
unstable phases. Points are variance of kernel weights and variance of bias | |
weights on initialization. | |
""" | |
# TODO(@daskol): We try to implement everything in pure JAX with | |
# vectorization out of box but it turns out that there is not adaptive | |
# integration routined (scipy.integrate.quad) in JAX. So, we fall back to | |
# SciPy. | |
    def int_magnitude(qs: jax.Array) -> jax.Array:
        @jax.jit
        def fn(xs, qs):
            ps = jsp.stats.norm.pdf(xs)
            ys = act_fn(jnp.sqrt(qs) * xs)
            return ps * ys**2
        return jnp.array([quad(fn, -jnp.inf, jnp.inf, args=q)[0] for q in qs])

    def int_correlation(qs: jax.Array) -> jax.Array:
        @jax.jit
        def fn(xs, qs):
            ys = jax.grad(act_fn)(jnp.sqrt(qs) * xs)
            ps = jsp.stats.norm.pdf(xs)
            return ps * ys**2
        return jnp.array([quad(fn, -jnp.inf, jnp.inf, args=q)[0] for q in qs])

    qs = jnp.asarray(qs)
    if qs.ndim == 0:
        return phase_boundary(qs[None], act_fn, chi).squeeze()

    points = jnp.empty((2, ) + qs.shape)
    points = points.at[0, ...].set(chi / int_correlation(qs))
    points = points.at[1, ...].set(qs - points[0, ...] * int_magnitude(qs))
    return points
if __name__ == '__main__':
    # Calculate boundary.
    jax.config.update('jax_enable_x64', True)
    qs = jnp.linspace(0, 100, 200)
    xs, ys = phase_boundary(qs)

    # Plot boundary and phases.
    fig, ax = plt.subplots()
    ax.plot(xs, ys, color='black', label='boundary')
    ax.fill_between(xs, ys, ys.max(), facecolor='red', alpha=0.2,
                    label='stable')
    ax.fill_between(xs, ys, 0, facecolor='blue', alpha=0.2, label='unstable')
    ax.grid()
    ax.legend()
    ax.set_xlabel(r'$\sigma^2_w$')
    ax.set_ylabel(r'$\sigma^2_b$')
    plt.savefig('phase-diagram.png')
    plt.show()
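For context, the two quadratures above implement the standard mean-field criticality conditions for wide networks with activation phi (a sketch of the algebra, matching the code's chi for the mean squared slope):

\sigma_w^2 = \frac{\chi}{\int \mathcal{N}(z \mid 0, 1)\,\bigl[\phi'(\sqrt{q}\,z)\bigr]^2\,dz},
\qquad
\sigma_b^2 = q - \sigma_w^2 \int \mathcal{N}(z \mid 0, 1)\,\bigl[\phi(\sqrt{q}\,z)\bigr]^2\,dz.

The second identity rearranges the fixed-point equation for the pre-activation variance, q = \sigma_w^2\,\mathbb{E}[\phi(\sqrt{q}z)^2] + \sigma_b^2, while the first pins the slope statistic \chi = \sigma_w^2\,\mathbb{E}[\phi'(\sqrt{q}z)^2]; with chi = 1 (the default), each q yields one point (\sigma_w^2, \sigma_b^2) on the order-to-chaos boundary.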