
@pablo2909
Created December 19, 2022 14:38
Failing and passing MNIST training scripts
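First script: an Equinox/Optax MLP trained on MNIST, fed by a NumPy-collating DataLoader and a single full-size batch (BATCH_SIZE = 60000).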
import logging
import time
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import optax
from jax import random
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
)
logger = logging.getLogger(__name__)
EPOCHS = 10
LR = 0.1
MOMENTUM = 0.9
BATCH_SIZE = 60000
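# Note: MNIST's training split has 60,000 examples, so BATCH_SIZE = 60000 means full-batch training.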
# Single-hidden-layer MLP: 784 -> 30 -> 10 with ReLU activation.
class MLP(eqx.Module):
    mlp: eqx.Module

    def __init__(self, *, key: random.PRNGKey) -> None:
        self.mlp = eqx.nn.MLP(
            in_size=28 * 28,
            out_size=10,
            width_size=30,
            depth=1,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, x):
        return self.mlp(x)
# Collate function that turns PyTorch batches into JAX arrays,
# recursing into tuples/lists of samples.
def numpy_collate(batch):
    if isinstance(batch[0], jnp.ndarray):
        return jnp.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return jnp.array(batch)
# DataLoader that uses numpy_collate, so batches come out as JAX arrays
# instead of torch tensors.
class NumpyLoader(DataLoader):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
    ):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
# Transform that flattens each 28x28 image into a float32 vector of length 784.
class FlattenAndCast(object):
    def __call__(self, pic):
        return jnp.ravel(jnp.array(pic, dtype=jnp.float32))
mnist_ds_jax = MNIST(
    "/home/pauje/Datasets",
    download=True,
    transform=FlattenAndCast(),
)
raise NotImplementedError("change dataset path")
train_dl_jax = NumpyLoader(mnist_ds_jax, batch_size=BATCH_SIZE)
optim = optax.sgd(learning_rate=LR, momentum=MOMENTUM)
def accuracy_jax(pred_Y, Y):
    # softmax is monotonic, so the argmax is the same with or without it;
    # it is kept here to mirror the original script.
    pred_Y = jnn.softmax(pred_Y)
    pred_Y = jnp.argmax(pred_Y, axis=1)
    acc = jnp.reshape(jnp.mean((Y == pred_Y) * 1.0), ())
    return acc


def compute_loss(model, X, Y):
    # vmap the single-example model over the batch dimension.
    pred_Y = jax.vmap(model)(X)
    loss = optax.softmax_cross_entropy_with_integer_labels(pred_Y, Y)
    return jnp.mean(loss), pred_Y
compute_loss = eqx.filter_value_and_grad(compute_loss, has_aux=True)
@eqx.filter_jit
def make_step(model, X, Y, opt_state):
    # One optimisation step: loss and grads via the filtered value_and_grad,
    # then an optax update applied to the model.
    (loss, pred_Y), grads = compute_loss(model, X, Y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, pred_Y, model, opt_state
def main(model_jax, opt_state):
    start_time_train_jax = time.time()
    for epoch in range(EPOCHS):
        start = time.time()
        for (X, Y) in train_dl_jax:
            loss, pred_Y, model_jax, opt_state = make_step(model_jax, X, Y, opt_state)
            acc = accuracy_jax(pred_Y, Y)
            logger.info(f"Loss = {loss.item()}, Accuracy = {acc.item()}")
        logger.info(f"Epoch : {epoch} took {time.time() - start}")
    logger.info(f"Training took {time.time() - start_time_train_jax}")
if __name__ == "__main__":
    logger.info(jax.default_backend())
    model_jax = MLP(key=random.PRNGKey(42))
    opt_state = optim.init(eqx.filter(model_jax, eqx.is_array))
    main(model_jax, opt_state)
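Second script: the same model and training step, but using the standard PyTorch DataLoader with transforms.ToTensor() and BATCH_SIZE = 10000; each batch is converted to NumPy inside the training loop.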
import logging
import time
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import optax
from jax import random
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
)
logger = logging.getLogger(__name__)
EPOCHS = 10
LR = 0.1
MOMENTUM = 0.9
BATCH_SIZE = 10000
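# 10,000 examples per batch, i.e. 6 batches per epoch over MNIST's 60,000 training examples.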
class MLP(eqx.Module):
    mlp: eqx.Module

    def __init__(self, *, key: random.PRNGKey) -> None:
        self.mlp = eqx.nn.MLP(
            in_size=28 * 28,
            out_size=10,
            width_size=30,
            depth=1,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, x):
        return self.mlp(x)
mnist_ds_jax = MNIST(
    "/home/pauje/Datasets",
    download=True,
    transform=transforms.ToTensor(),
)
raise NotImplementedError("Change dataset path")
train_dl_jax = DataLoader(mnist_ds_jax, batch_size=BATCH_SIZE)
optim = optax.sgd(learning_rate=LR, momentum=MOMENTUM)
def accuracy_jax(pred_Y, Y):
    pred_Y = jnn.softmax(pred_Y)
    pred_Y = jnp.argmax(pred_Y, axis=1)
    acc = jnp.reshape(jnp.mean((Y == pred_Y) * 1.0), ())
    return acc


def compute_loss(model, X, Y):
    pred_Y = jax.vmap(model)(X)
    loss = optax.softmax_cross_entropy_with_integer_labels(pred_Y, Y)
    return jnp.mean(loss), pred_Y
compute_loss = eqx.filter_value_and_grad(compute_loss, has_aux=True)
@eqx.filter_jit
def make_step(model, X, Y, opt_state):
    (loss, pred_Y), grads = compute_loss(model, X, Y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, pred_Y, model, opt_state
def main(model_jax, opt_state):
    start_time_train_jax = time.time()
    for epoch in range(EPOCHS):
        start = time.time()
        for (X, Y) in train_dl_jax:
            # Batches arrive as torch tensors; flatten the images and convert
            # to NumPy before handing them to the jitted step.
            X = X.numpy().reshape(X.shape[0], -1)
            Y = Y.numpy()
            loss, pred_Y, model_jax, opt_state = make_step(model_jax, X, Y, opt_state)
            acc = accuracy_jax(pred_Y, Y)
            logger.info(f"Loss = {loss.item()}, Accuracy = {acc.item()}")
        logger.info(f"Epoch : {epoch} took {time.time() - start}")
    logger.info(f"Training took {time.time() - start_time_train_jax}")
if __name__ == "__main__":
    logger.info(jax.default_backend())
    model_jax = MLP(key=random.PRNGKey(42))
    opt_state = optim.init(eqx.filter(model_jax, eqx.is_array))
    main(model_jax, opt_state)