Haiku MLP with custom data
#%%
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A minimal MNIST classifier example modified.""" | |
# from pydiamonds.python_helpers import activate_ipython_magics
# activate_ipython_magics()
import tensorflow as tf
# Ensure TF does not see the GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
from typing import NamedTuple, Iterator
from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from pytwin.solving_td_x import load_raw_data_xy

NUM_CLASSES = 10  # MNIST has 10 classes (hand-written digits).

class Batch(NamedTuple):
  image: np.ndarray  # [B, H, W, 1]
  label: np.ndarray  # [B]


class TrainingState(NamedTuple):
  params: hk.Params       # live parameters, updated every step
  avg_params: hk.Params   # EMA of params, used only for evaluation
  opt_state: optax.OptState

def create_datapipe_from_xy(
    x, y,
    *,
    shuffle: bool,
    batch_size: int,
) -> Iterator[Batch]:
  """Builds an infinite, batched pipeline over in-memory (x, y) arrays."""
  # ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
  ds = tf.data.Dataset.from_tensor_slices((x, y)).repeat()
  if shuffle:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  ds = ds.map(lambda x, y: Batch(x, y))
  return iter(tfds.as_numpy(ds))

def load_my_dataset(dataname: str, **kw):
  x, y = load_raw_data_xy(dataname)
  return create_datapipe_from_xy(x, y, **kw)
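
# A minimal sketch of the pipeline with hypothetical synthetic arrays (MNIST
# shapes assumed here for illustration; the real data comes from
# load_raw_data_xy):
#   x = np.random.rand(256, 28, 28, 1).astype(np.float32)
#   y = np.random.randint(0, NUM_CLASSES, size=(256,)).astype(np.int32)
#   pipe = create_datapipe_from_xy(x, y, shuffle=True, batch_size=32)
#   batch = next(pipe)  # batch.image: (32, 28, 28, 1), batch.label: (32,)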

def net_fn(images: jnp.ndarray) -> jnp.ndarray:
  """MLP with 1000-500-100 hidden units and mixed activations."""
  x = images.astype(jnp.float32)
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(1000), jax.nn.leaky_relu,
      hk.Linear(500), jax.nn.sigmoid,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(NUM_CLASSES),
  ])
  return mlp(x)
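
# A sketch of the pure (init, apply) pair that hk.transform derives from
# net_fn; the input shape below is illustrative (MNIST-sized), not taken from
# the data files used later:
#   net = hk.without_apply_rng(hk.transform(net_fn))
#   params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))
#   out = net.apply(params, jnp.zeros((1, 28, 28, 1)))  # shape (1, NUM_CLASSES)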

def main(_):
  # First, make the network and optimiser.
  network = hk.without_apply_rng(hk.transform(net_fn))
  optimiser = optax.adam(1e-2)

  def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Mean-squared-error loss, regularised by L2 weight decay."""
    y = network.apply(params, batch.image)
    Y = batch.label
    l2_regulariser = sum(jnp.sum(jnp.square(p))
                         for p in jax.tree_util.tree_leaves(params))
    # jnp.mean already averages over the batch, so no extra division by
    # batch_size is needed.
    return jnp.mean((y - Y) ** 2) + 1e-5 * l2_regulariser

  @jax.jit
  def evaluate(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Evaluation metric (mean squared error on a batch)."""
    y = network.apply(params, batch.image)
    Y = batch.label
    return jnp.mean((y - Y) ** 2)

  @jax.jit
  def update(state: TrainingState, batch: Batch) -> TrainingState:
    """Learning rule: one Adam step plus an EMA of the parameters."""
    grads = jax.grad(loss)(state.params, batch)
    updates, opt_state = optimiser.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    # Compute avg_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    avg_params = optax.incremental_update(
        params, state.avg_params, step_size=0.001)
    return TrainingState(params, avg_params, opt_state)
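
  # For reference, optax.incremental_update performs a Polyak/EMA step:
  #   avg_params <- step_size * params + (1 - step_size) * avg_params
  # so with step_size=0.001 the evaluation parameters trail the live
  # parameters slowly and smooth out step-to-step noise.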

  # dataset = 'raw_data_3_XY.jld2'  # alternative data file
  dataset = 'raw_data.4.jld2'
  # Make datasets.
  train_dataset = load_my_dataset(dataset, shuffle=True, batch_size=100)
  eval_datasets = {
      split: load_my_dataset(split, shuffle=False, batch_size=1_000)
      for split in (dataset,)
  }

  # Report which backend JAX is running on (cpu / gpu / tpu).
  from jax.lib import xla_bridge
  print(xla_bridge.get_backend().platform)

  # Initialise network and optimiser; note we draw an input to get shapes.
  print(next(train_dataset).image.shape)
  initial_params = network.init(
      jax.random.PRNGKey(seed=0), next(train_dataset).image)
  initial_opt_state = optimiser.init(initial_params)
  state = TrainingState(initial_params, initial_params, initial_opt_state)

  # Training & evaluation loop.
  for step in range(6110):
    if step % 100 == 0 or (step < 200 and step % 10 == 1):
      # Periodically evaluate the MSE on each evaluation set.
      # Note that each evaluation is only on a (large) batch.
      for split, eval_ds in eval_datasets.items():
        mse = np.array(evaluate(state.avg_params, next(eval_ds))).item()
        print(f"{step}: {split} MSE {mse:.3f}")
    # Do a gradient step on a batch of training examples.
    state = update(state, next(train_dataset))

if __name__ == "__main__":
  main(None)  # Called directly rather than via app.run(main); no absl flags.
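
To adapt this gist to other data, replace load_raw_data_xy with any function that returns a pair of numpy arrays (x, y); the rest of the pipeline only assumes array inputs. The pytwin import and the .jld2 filenames are specific to the author's setup.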