Haiku MLP with custom data
#%%
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A minimal MNIST classifier example modified."""
# from pydiamonds.python_helpers import activate_ipython_magics
# activate_ipython_magics()
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
from typing import NamedTuple, Iterator
from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from pytwin.solving_td_x import load_raw_data_xy
NUM_CLASSES = 10 # MNIST has 10 classes (hand-written digits).
class Batch(NamedTuple):
    image: np.ndarray  # [B, H, W, 1]
    label: np.ndarray  # [B]


class TrainingState(NamedTuple):
    params: hk.Params
    avg_params: hk.Params
    opt_state: optax.OptState
def create_datapipe_from_xy(
    x, y,
    *,
    shuffle: bool,
    batch_size: int,
) -> Iterator[Batch]:
    """Builds an infinite, batched pipeline over in-memory (x, y) arrays."""
    # ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
    ds = tf.data.Dataset.from_tensor_slices((x, y)).repeat()
    if shuffle:
        ds = ds.shuffle(10 * batch_size, seed=0)
    ds = ds.batch(batch_size)
    ds = ds.map(lambda x, y: Batch(x, y))
    return iter(tfds.as_numpy(ds))


def load_my_dataset(dataname: str, **kw):
    """Loads (x, y) arrays by name and wraps them in a Batch iterator."""
    x, y = load_raw_data_xy(dataname)
    return create_datapipe_from_xy(x, y, **kw)
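

# `load_raw_data_xy` comes from the author's `pytwin` package, which is not part of
# this gist. As a rough stand-in for trying the script without that package (an
# assumption, not the original loader), randomly generated data with MNIST-like
# image shapes and one-hot labels could be swapped in, e.g.:
def load_random_data_xy(num_examples: int = 10_000):
    """Hypothetical stand-in for load_raw_data_xy; returns (images, one-hot labels)."""
    rng = np.random.default_rng(0)
    x = rng.standard_normal((num_examples, 28, 28, 1)).astype(np.float32)
    labels = rng.integers(0, NUM_CLASSES, size=num_examples)
    y = np.eye(NUM_CLASSES, dtype=np.float32)[labels]  # one-hot targets match the MSE loss below
    return x, y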
def net_fn(images: jnp.ndarray) -> jnp.ndarray:
    """MLP with hidden widths 1000-500-100 and mixed activations."""
    x = images.astype(jnp.float32)
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(1000), jax.nn.leaky_relu,
        hk.Linear(500), jax.nn.sigmoid,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(NUM_CLASSES),
    ])
    return mlp(x)
def main(_):
    # First, make the network and optimiser.
    network = hk.without_apply_rng(hk.transform(net_fn))
    optimiser = optax.adam(1e-2)
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Mean squared error loss, regularised by L2 weight decay."""
        batch_size, *_ = batch.image.shape
        y = network.apply(params, batch.image)
        Y = batch.label
        l2_regulariser = sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
        return jnp.mean((y - Y) ** 2) / batch_size + 1e-5 * l2_regulariser
    @jax.jit
    def evaluate(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Evaluation metric (mean squared error on the batch)."""
        y = network.apply(params, batch.image)
        Y = batch.label
        return jnp.mean((y - Y) ** 2)
    @jax.jit
    def update(state: TrainingState, batch: Batch) -> TrainingState:
        """Learning rule (one Adam optimiser step)."""
        grads = jax.grad(loss)(state.params, batch)
        updates, opt_state = optimiser.update(grads, state.opt_state)
        params = optax.apply_updates(state.params, updates)
        # Compute avg_params, the exponential moving average of the "live" params.
        # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
        avg_params = optax.incremental_update(
            params, state.avg_params, step_size=0.001)
        return TrainingState(params, avg_params, opt_state)
    # dataset = 'raw_data_3_XY.jld2'  # alternative data file
    dataset = 'raw_data.4.jld2'
    # Make datasets.
    train_dataset = load_my_dataset(dataset, shuffle=True, batch_size=100)
    eval_datasets = {
        split: load_my_dataset(split, shuffle=False, batch_size=1_000)
        for split in (dataset,)
    }
    # Report which backend (cpu/gpu/tpu) JAX is running on.
    from jax.lib import xla_bridge
    print(xla_bridge.get_backend().platform)
    # Initialise network and optimiser; note we draw an input to get shapes.
    print(next(train_dataset).image.shape)
    initial_params = network.init(
        jax.random.PRNGKey(seed=0), next(train_dataset).image)
    initial_opt_state = optimiser.init(initial_params)
    state = TrainingState(initial_params, initial_params, initial_opt_state)
    # Training & evaluation loop.
    for step in range(6110):
        if step % 100 == 0 or (step < 200 and step % 10 == 1):
            # Periodically evaluate on a (large) batch from each eval split.
            for split, eval_ds in eval_datasets.items():
                mse = np.array(evaluate(state.avg_params, next(eval_ds))).item()
                print(f"{step}.: {split} MSE {mse:.3f}")
        # Do a gradient update on a batch of training examples.
        state = update(state, next(train_dataset))
if __name__ == "__main__":
    main(None)