@pablo2909
Created December 20, 2022 15:36

# helpers_datasets.py
import jax.numpy as jnp
from jax import random


class Dataset:
    """Minimal map-style dataset over a pair of in-memory arrays."""

    def __init__(self, x, y) -> None:
        self.x, self.y = x, y

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class Sampler:
    """Yields batches of indices, optionally shuffled with a JAX PRNG key."""

    def __init__(self, dataset, batch_size, shuffle=False, *, key):
        self.n = len(dataset)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.key = key

    def __iter__(self):
        if self.shuffle:
            # Split the key so each epoch sees a fresh permutation instead of
            # replaying the same order every time.
            self.key, subkey = random.split(self.key)
            self.idxs = random.permutation(subkey, self.n)
        else:
            self.idxs = jnp.arange(self.n)
        for i in range(0, self.n, self.batch_size):
            # The final batch may be smaller than batch_size.
            yield self.idxs[i : i + self.batch_size]


def collate(b):
    # Turn a list of (x, y) pairs into a pair of stacked batch arrays.
    xs, ys = zip(*b)
    return jnp.stack(xs), jnp.stack(ys)
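
# A quick sketch of what collate does (hypothetical shapes, not from the
# original gist): a list of (x, y) pairs becomes a pair of stacked arrays.
#   batch = [(jnp.ones((28, 28)), 0), (jnp.zeros((28, 28)), 1)]
#   xs, ys = collate(batch)  # xs.shape == (2, 28, 28), ys.shape == (2,)
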
class DataLoader:
    """Composes a Sampler with per-item and per-batch transforms."""

    def __init__(
        self,
        dataset,
        batch_size,
        collate_fn=collate,
        shuffle=False,
        transform=lambda x: x,
        transform_batch=lambda x: x,
        *,
        key,
    ):
        self.sampler = Sampler(dataset, batch_size, shuffle, key=key)
        self.collate_fn = collate_fn
        self.dataset = dataset
        self.transform = transform
        self.transform_batch = transform_batch

    def __iter__(self):
        for s in self.sampler:
            # Fetch items, apply the per-item transform, collate into a
            # batch, then apply the per-batch transform.
            yield self.transform_batch(
                self.collate_fn([self.transform(self.dataset[i]) for i in s])
            )
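

# A minimal usage sketch (synthetic data with assumed shapes; not part of the
# original gist): iterating a DataLoader yields collated (x, y) batches.
if __name__ == "__main__":
    demo_key = random.PRNGKey(0)
    xs = jnp.ones((100, 28, 28))
    ys = jnp.zeros((100,), dtype=jnp.int32)
    dl = DataLoader(Dataset(xs, ys), batch_size=32, shuffle=True, key=demo_key)
    for xb, yb in dl:
        print(xb.shape, yb.shape)  # (32, 28, 28) (32,); the last batch is (4, ...)
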
# --- training script (a separate file; it imports helpers_datasets above) ---
import logging
import time

import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import optax
from jax import random
from torch.utils import data
from torchvision import transforms

# Swap in the torch data pipeline by uncommenting these:
# from torch.utils.data import DataLoader
# from torchvision import transforms
from torchvision.datasets import MNIST

# Relative import: run this file as a package module (python -m pkg.script),
# or drop the leading dot when running it as a standalone script.
from .helpers_datasets import DataLoader, Dataset

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
)
logger = logging.getLogger(__name__)
EPOCHS = 10
LR = 0.1
MOMENTUM = 0.9
BATCH_SIZE = 10000


class MLP(eqx.Module):
    mlp: eqx.Module

    def __init__(self, *, key: random.PRNGKey) -> None:
        # 784 -> 30 -> 10 classifier for flattened 28x28 MNIST digits.
        self.mlp = eqx.nn.MLP(
            in_size=28 * 28,
            out_size=10,
            width_size=30,
            depth=1,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, x):
        return self.mlp(x)
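

# A quick forward-pass sketch (hypothetical input, not from the gist): the
# model maps one flattened image to 10 logits; jax.vmap adds the batch axis.
#   model = MLP(key=random.PRNGKey(0))
#   logits = model(jnp.ones(28 * 28))                   # shape (10,)
#   batched = jax.vmap(model)(jnp.ones((4, 28 * 28)))   # shape (4, 10)
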
raise NotImplementedError("Replace the dataset root below, then remove this line")
mnist_ds_jax = MNIST(
    "/home/pauje/Datasets",
    download=True,
    transform=transforms.ToTensor(),
)
# Materialize the whole training set as one torch batch, then convert to JAX
# arrays. ToTensor scales pixels to [0, 1].
Xtrain, Ytrain = next(iter(data.DataLoader(mnist_ds_jax, batch_size=60000)))
Xtrain, Ytrain = jnp.array(Xtrain.numpy(), dtype=jnp.float32), jnp.array(Ytrain)
# Uncomment to fail training: this bypasses ToTensor, so pixels stay as raw
# uint8 values in [0, 255], and the loss diverges at this learning rate.
# Xtrain, Ytrain = jnp.array(mnist_ds_jax.data.numpy(), dtype=jnp.float32), jnp.array(
#     mnist_ds_jax.targets
# )
ds = Dataset(Xtrain, Ytrain)
train_dl_jax = DataLoader(ds, BATCH_SIZE, key=random.PRNGKey(42))
optim = optax.sgd(learning_rate=LR, momentum=MOMENTUM)
def accuracy_jax(pred_Y, Y):
    # softmax is monotonic, so taking argmax of the raw logits would give the
    # same predictions; the softmax is kept to match the original.
    pred_Y = jnn.softmax(pred_Y)
    pred_Y = jnp.argmax(pred_Y, axis=1)
    return jnp.mean(Y == pred_Y)
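
# Sanity check (hypothetical values, not from the gist): rows argmax to
# classes 0, 1, 0 against labels 0, 1, 1, so two of three match.
#   logits = jnp.array([[2.0, 0.1], [0.1, 2.0], [2.0, 0.1]])
#   accuracy_jax(logits, jnp.array([0, 1, 1]))  # -> 0.666...
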
def compute_loss(model, X, Y):
    # Flatten images to vectors and run the model over the batch axis.
    X = X.reshape(X.shape[0], -1)
    pred_Y = jax.vmap(model)(X)
    loss = optax.softmax_cross_entropy_with_integer_labels(pred_Y, Y)
    return jnp.mean(loss), pred_Y


# Differentiate w.r.t. the model's array leaves; pred_Y rides along as aux.
compute_loss = eqx.filter_value_and_grad(compute_loss, has_aux=True)


@eqx.filter_jit
def make_step(model, X, Y, opt_state):
    (loss, pred_Y), grads = compute_loss(model, X, Y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, pred_Y, model, opt_state
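
# Note: the first make_step call triggers JIT compilation, and any new batch
# shape recompiles; with BATCH_SIZE = 10000 all six MNIST batches share one
# shape, so compilation happens once.
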
def main(model_jax, opt_state):
    start_time_train_jax = time.time()
    for epoch in range(EPOCHS):
        start = time.time()
        for X, Y in train_dl_jax:
            loss, pred_Y, model_jax, opt_state = make_step(model_jax, X, Y, opt_state)
            acc = accuracy_jax(pred_Y, Y)
            logger.info(f"Loss = {loss.item()}, Accuracy = {acc.item()}")
        logger.info(f"Epoch {epoch} took {time.time() - start:.2f}s")
    logger.info(f"Training took {time.time() - start_time_train_jax:.2f}s")


if __name__ == "__main__":
    logger.info(jax.default_backend())
    model_jax = MLP(key=random.PRNGKey(42))
    # Optimizer state is initialized over the model's array leaves only.
    opt_state = optim.init(eqx.filter(model_jax, eqx.is_array))
    main(model_jax, opt_state)