Common routines for deep learning with JAX/Flax.
"""Module common defines common routines and training utils for JAX/Flax | |
environment. | |
""" | |
import logging
from os import urandom
from struct import unpack
from typing import Callable, Optional

import flax
import jax
import jax.numpy as jnp
from flax.training import train_state


def make_rng(seed: Optional[int] = None) -> jax.random.PRNGKey:
    """Function make_rng initializes PRNG state either with the specified
    seed or with random bits read from the system entropy source, so calling
    it without a seed has a side effect.

    :param seed: PRNG seed.
    :return: State of PRNG.
    """
    if seed is None:
        seed, = unpack('Q', urandom(8))
        logging.info('sample from /dev/urandom seed %d', seed)
    return jax.random.PRNGKey(seed)
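

# A minimal usage sketch (illustrative only; the seed 42 is arbitrary):
#
#   rng = make_rng()                     # non-deterministic, /dev/urandom
#   rng = make_rng(42)                   # reproducible runs
#   rng, subkey = jax.random.split(rng)  # derive keys for init/dropout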


@jax.jit
def classify(y_pred: jnp.ndarray):
    """Pick the most probable class from logits."""
    return y_pred.argmax(-1)


@jax.jit
def forecast(y_pred: jnp.ndarray):
    """Take the first output channel as a point forecast."""
    return y_pred[..., 0]


@jax.jit
def loss_entropy(y_true: jnp.ndarray, y_pred: jnp.ndarray):
    """Compute cross-entropy over integer class labels and raw logits."""
    assert y_pred.ndim == 2, 'Logits are expected.'
    ps = jax.nn.log_softmax(y_pred)
    ts = jnp.take_along_axis(ps, y_true[..., None], axis=-1)
    return -ts.mean()


@jax.jit
def loss_l2(y_true: jnp.ndarray, y_pred: jnp.ndarray):
    """Compute mean squared error between targets and point predictions."""
    return jnp.mean((y_true[..., 0] - y_pred) ** 2)
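

# Shape conventions assumed by the losses above (a sketch; names are
# illustrative, not part of the module):
#
#   logits = jnp.zeros((batch_size, num_classes))  # loss_entropy(labels, logits)
#   labels = jnp.zeros((batch_size,), jnp.int32)   # integer class labels
#   target = jnp.zeros((batch_size, 1))            # loss_l2(target, preds);
#   preds = jnp.zeros((batch_size,))               # note trailing dim of target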


class TrainState(train_state.TrainState):
    """Class TrainState is derived from Flax's TrainState in order to allow
    initialization as late as possible. The main idea is that state is either
    initialized eagerly with explicitly specified params or lazily with the
    specified init_fn method.

    >>> class Model(nn.Module):
    ...     @nn.compact
    ...     def __call__(self, x):
    ...         return nn.Dense(10)(x)
    >>>
    >>> tx = sgd()
    >>> model = Model()
    >>> state = TrainState.create(apply_fn=model.apply,
    ...                           init_fn=model.init,
    ...                           tx=tx)
    >>> # Load data, prepare dataset or collate records.
    >>> state = state.init(rng, input)
    >>> # Iterate over dataset and train model.
    """

    init_fn: Callable = flax.struct.field(pytree_node=False)
    loss_fn: Callable = flax.struct.field(pytree_node=False)
    predict_fn: Callable = flax.struct.field(pytree_node=False)

    def init(self, rng, inp, reset=False):
        # If model params are not empty and reset is not forced then exit.
        if not reset and self.params:
            return self
        step = 0 if reset else self.step
        # The init_fn is expected to behave like Module.init, i.e. to return
        # the parameter pytree for a sample input.
        params = self.init_fn(rng, inp)
        opt_state = self.tx.init(params)
        return self.replace(step=step, params=params, opt_state=opt_state)

    @classmethod
    def create(cls, *, apply_fn, tx, init_fn=None, params=None, **kwargs):
        """Class method create creates a new instance with `step=0` and
        initialized `opt_state` if parameters are not empty.
        """
        if init_fn is None and not params:
            raise ValueError('Either init_fn or params should be specified.')
        # Initialize optimizer state if there are parameters. Otherwise,
        # proceed with sentinel values.
        if params:
            opt_state = tx.init(params)
        else:
            params = {}
            opt_state = {}
        if init_fn is not None:
            init_fn = jax.jit(init_fn)  # Force JAX to trace function.
        return cls(step=0,
                   apply_fn=apply_fn,
                   init_fn=init_fn,
                   params=params,
                   tx=tx,
                   opt_state=opt_state,
                   **kwargs)
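

if __name__ == '__main__':
    # A minimal end-to-end sketch, not part of the original module: it assumes
    # optax and flax.linen are available, and optax.sgd stands in for any
    # optimizer; the model, batch, and seed are arbitrary illustrations.
    import flax.linen as nn
    import optax

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(10)(x)

    model = Model()
    state = TrainState.create(apply_fn=model.apply,
                              init_fn=model.init,
                              tx=optax.sgd(1e-2),
                              loss_fn=loss_entropy,
                              predict_fn=classify)

    rng = make_rng(42)
    xs = jnp.ones((8, 4))
    ys = jnp.zeros((8,), dtype=jnp.int32)
    state = state.init(rng, xs)  # Lazy initialization with a sample batch.

    @jax.jit
    def train_step(state, xs, ys):
        def objective(params):
            return state.loss_fn(ys, state.apply_fn(params, xs))
        grads = jax.grad(objective)(state.params)
        return state.apply_gradients(grads=grads)

    state = train_step(state, xs, ys)
    print('predictions:', state.predict_fn(state.apply_fn(state.params, xs)))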