@daskol · Created April 14, 2022 15:15
Common routines for deep learning with JAX/FLAX.
"""Module common defines common routines and training utils for JAX/Flax
environment.
"""
import logging
from os import urandom
from struct import unpack
from typing import Callable, Optional

import flax
import jax
import jax.numpy as jnp
from flax.training import train_state


def make_rng(seed: Optional[int] = None) -> jax.random.PRNGKey:
    """Function make_rng initializes the state of a PRNG either with the
    specified seed or with random bits sampled from the OS entropy source,
    so calling it without a seed has a side effect.

    :param seed: PRNG seed.
    :return: State of PRNG.
    """
    if seed is None:
        seed, = unpack('Q', urandom(8))
        logging.info('sample from /dev/urandom seed %d', seed)
    return jax.random.PRNGKey(seed)
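
# A quick usage sketch (only make_rng itself and jax.random.split are
# assumed): pass a seed for reproducible runs or omit it to seed from the
# OS entropy source.
#
#     rng = make_rng(42)                   # deterministic
#     rng = make_rng()                     # seeded from os.urandom
#     rng, subkey = jax.random.split(rng)  # derive per-step keys as usual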


@jax.jit
def classify(y_pred: jnp.ndarray):
    return y_pred.argmax(-1)


@jax.jit
def forecast(y_pred: jnp.ndarray):
    return y_pred[..., 0]
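
# For illustration: classify and forecast map raw model outputs to
# predictions. Given a batch of logits of shape (batch, classes), classify
# returns integer labels while forecast takes the first output channel.
#
#     logits = jnp.array([[0.1, 2.0, -1.0]])
#     classify(logits)  # -> array([1])
#     forecast(logits)  # -> array([0.1])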


@jax.jit
def loss_entropy(y_true: jnp.ndarray, y_pred: jnp.ndarray):
    assert y_pred.ndim == 2, 'Logits are expected.'
    ps = jax.nn.log_softmax(y_pred)
    ts = jnp.take_along_axis(ps, y_true[..., None], axis=-1)
    return -ts.mean()
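
# A worked sketch of loss_entropy: for integer labels y_true of shape
# (batch,) and logits y_pred of shape (batch, classes), it computes the mean
# negative log-likelihood of the true class.
#
#     y_true = jnp.array([1])
#     y_pred = jnp.array([[0.0, 0.0]])  # uniform logits over two classes
#     loss_entropy(y_true, y_pred)      # -> log(2) ~= 0.6931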


@jax.jit
def loss_l2(y_true: jnp.ndarray, y_pred: jnp.ndarray):
    return jnp.mean((y_true[..., 0] - y_pred) ** 2)
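
# Note that loss_l2 mirrors forecast above: y_true carries targets in a
# trailing channel while y_pred is already squeezed.
#
#     y_true = jnp.array([[1.0], [3.0]])  # shape (batch, 1)
#     y_pred = jnp.array([0.0, 3.0])      # shape (batch,)
#     loss_l2(y_true, y_pred)             # -> ((1 - 0)**2 + 0**2) / 2 = 0.5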


class TrainState(train_state.TrainState):
    """Class TrainState is derived from flax's TrainState in order to allow
    initialization as late as possible. The main idea is that the state is
    initialized eagerly if params are specified explicitly and lazily via
    the init_fn method otherwise.

    >>> class Model(nn.Module):
    ...     @nn.compact
    ...     def __call__(self, x):
    ...         return nn.Dense(10)(x)
    >>> tx = sgd()
    >>> model = Model()
    >>> state = TrainState.create(apply_fn=model.apply,
    ...                           init_fn=model.init,
    ...                           tx=tx)
    >>> # Load data, prepare dataset or collate records.
    >>> state = state.init(rng, input)
    >>> # Iterate over dataset and train model.
    """

    init_fn: Callable = flax.struct.field(pytree_node=False)
    loss_fn: Callable = flax.struct.field(pytree_node=False)
    predict_fn: Callable = flax.struct.field(pytree_node=False)

    def init(self, rng, inp, reset=False):
        """Method init initializes model params and optimizer state lazily
        with init_fn unless params are already set or reset is forced.
        """
        # If model params are not empty and reset is not forced then exit.
        if not reset and self.params:
            return self
        step = 0 if reset else self.step
        params = self.init_fn(rng, inp)
        opt_state = self.tx.init(params)
        return self.replace(step=step, params=params, opt_state=opt_state)

    @classmethod
    def create(cls, *, apply_fn, tx, init_fn=None, params=None, **kwargs):
        """Class method create creates a new instance with `step=0` and
        initialized `opt_state` if parameters are not empty.
        """
        # At least one way to obtain parameters must be given.
        if init_fn is None and not params:
            raise ValueError('Either init_fn or params should be specified.')
        # Initialize optimizer state if there are parameters. Otherwise,
        # proceed with sentinel values.
        if params:
            opt_state = tx.init(params)
        else:
            params = {}
            opt_state = {}
        if init_fn is not None:
            init_fn = jax.jit(init_fn)  # Force JAX to trace function.
        return cls(step=0,
                   apply_fn=apply_fn,
                   init_fn=init_fn,
                   params=params,
                   tx=tx,
                   opt_state=opt_state,
                   **kwargs)
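

# A minimal end-to-end sketch of the intended workflow. It assumes optax for
# the optimizer; Model and the dummy batch below are illustrative and not
# part of this module.
if __name__ == '__main__':
    import flax.linen as nn
    import optax

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(10)(x)

    model = Model()
    state = TrainState.create(apply_fn=model.apply,
                              init_fn=model.init,
                              tx=optax.sgd(1e-2),
                              loss_fn=loss_entropy,
                              predict_fn=classify)

    # Initialize params lazily once a sample input is available.
    rng = make_rng(42)
    batch = jnp.ones((4, 16))
    state = state.init(rng, batch)

    # One gradient step against dummy labels.
    labels = jnp.zeros(4, dtype=jnp.int32)

    def objective(params):
        logits = state.apply_fn(params, batch)
        return state.loss_fn(labels, logits)

    grads = jax.grad(objective)(state.params)
    state = state.apply_gradients(grads=grads)
    print('step:', int(state.step))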