-
-
Save pablo2909/3a2cec869a43421859520750990f263e to your computer and use it in GitHub Desktop.
Failing and passing MNIST training scripts (JAX/Equinox): a full-batch version and a mini-batch version.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import time | |
import equinox as eqx | |
import jax | |
import jax.nn as jnn | |
import jax.numpy as jnp | |
import optax | |
from jax import random | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
from torchvision.datasets import MNIST | |
# Timestamped, level-tagged log lines, e.g.
# "[2023-01-01 12:00:00] [INFO] __main__ message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
)
logger = logging.getLogger(__name__)

# Training hyperparameters.
EPOCHS = 10          # full passes over the training set
LR = 0.1             # SGD learning rate
MOMENTUM = 0.9       # SGD momentum coefficient
BATCH_SIZE = 60000   # one batch = all 60k MNIST training images (full-batch GD)
class MLP(eqx.Module):
    """Thin Equinox wrapper around a 784 -> 30 -> 10 fully-connected network."""

    # Wrapped Equinox MLP; declared as a module field so Equinox tracks its
    # parameters as part of this module's pytree.
    mlp: eqx.Module

    def __init__(self, *, key: random.PRNGKey) -> None:
        # 28x28 flattened MNIST image in, 10 class logits out, a single
        # hidden layer of width 30 with ReLU activation.
        self.mlp = eqx.nn.MLP(
            in_size=28 * 28,
            out_size=10,
            width_size=30,
            depth=1,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, x):
        # Forward one (784,) example; batching is done by jax.vmap at the
        # call site (see compute_loss).
        return self.mlp(x)
def numpy_collate(batch):
    """Collate a list of samples into JAX arrays (DataLoader ``collate_fn``).

    Mirrors torch's default collate, but stacks into ``jnp`` arrays rather
    than torch tensors.  Tuple/list samples (e.g. ``(image, label)`` pairs)
    are collated recursively, element-wise.
    """
    first = batch[0]
    if isinstance(first, jnp.ndarray):
        return jnp.stack(batch)
    if isinstance(first, (tuple, list)):
        # Transpose [(x0, y0), (x1, y1), ...] into columns ([x0, x1, ...],
        # [y0, y1, ...]) and collate each column independently.
        return [numpy_collate(column) for column in zip(*batch)]
    # Scalars (e.g. integer class labels) become a 1-D array.
    return jnp.array(batch)
class NumpyLoader(DataLoader):
    """``DataLoader`` that yields JAX arrays instead of torch tensors.

    Identical to torch's ``DataLoader`` except that ``collate_fn`` is pinned
    to ``numpy_collate``, so iterating a ``NumpyLoader`` produces
    ``jnp.ndarray`` batches.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
    ):
        # BUG FIX: the original called `super(self.__class__, self)`, which
        # resolves `__class__` dynamically and recurses forever as soon as
        # this class is subclassed.  The zero-argument `super()` form is the
        # correct Python 3 spelling.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
class FlattenAndCast(object):
    """Torchvision-style transform: image -> flat float32 ``jnp`` array."""

    def __call__(self, pic):
        # Convert to a float32 JAX array, then collapse to one dimension so
        # the MLP receives a (784,) vector per MNIST image.
        arr = jnp.array(pic, dtype=jnp.float32)
        return arr.reshape(-1)
# Guard left in on purpose: the dataset root below is machine-specific.
# BUG FIX: raise *before* constructing the dataset — the original raised
# after `MNIST(...)`, so `download=True` had already fetched data into the
# wrong location before the guard fired.  Delete this line once the path
# below points at your local dataset directory.
raise NotImplementedError("change dataset path")

mnist_ds_jax = MNIST(
    "/home/pauje/Datasets",  # TODO: set to your local dataset root
    download=True,
    transform=FlattenAndCast(),
)
train_dl_jax = NumpyLoader(mnist_ds_jax, batch_size=BATCH_SIZE)
# Plain SGD with momentum over the model's array leaves.
optim = optax.sgd(learning_rate=LR, momentum=MOMENTUM)
def accuracy_jax(pred_Y, Y):
    """Fraction of predictions whose argmax matches the integer label.

    Args:
        pred_Y: ``(batch, classes)`` array of logits.
        Y: ``(batch,)`` array of integer class labels.

    Returns:
        Scalar accuracy in ``[0, 1]``.
    """
    # Softmax is strictly monotone within each row, so argmax over the raw
    # logits equals argmax over softmax(logits); the softmax call the
    # original made here was redundant work and has been removed.
    pred_labels = jnp.argmax(pred_Y, axis=1)
    return jnp.mean((Y == pred_labels) * 1.0)
def compute_loss(model, X, Y):
    """Mean cross-entropy of `model` on a batch, with logits as aux output.

    Args:
        model: callable mapping one flattened image to class logits.
        X: ``(batch, 784)`` array of inputs.
        Y: ``(batch,)`` array of integer labels.

    Returns:
        ``(mean_loss, logits)`` tuple.
    """
    logits = jax.vmap(model)(X)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, Y)
    return jnp.mean(per_example), logits


# Rebind so one call yields ((value, aux), grads), differentiating with
# respect to the model's array leaves only.
compute_loss = eqx.filter_value_and_grad(compute_loss, has_aux=True)
@eqx.filter_jit
def make_step(model, X, Y, opt_state):
    """One training step: gradients on (X, Y), optimizer update, apply.

    Returns the batch loss, the batch logits, the updated model, and the
    updated optimizer state.
    """
    (loss, logits), grads = compute_loss(model, X, Y)
    updates, new_opt_state = optim.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return loss, logits, new_model, new_opt_state
def main(model_jax, opt_state):
    """Run EPOCHS training passes over `train_dl_jax`, logging progress."""
    overall_start = time.time()
    for epoch in range(EPOCHS):
        epoch_start = time.time()
        for X, Y in train_dl_jax:
            loss, logits, model_jax, opt_state = make_step(
                model_jax, X, Y, opt_state
            )
            acc = accuracy_jax(logits, Y)
            logger.info(f"Loss = {loss.item()}, Accuracy = {acc.item()}")
        logger.info(f"Epoch : {epoch} took {time.time() - epoch_start}")
    logger.info(f"Training took {time.time() - overall_start}")
if __name__ == "__main__":
    # Report which backend JAX selected (e.g. cpu/gpu/tpu).
    logger.info(jax.default_backend())
    # Fixed seed for reproducible weight initialisation.
    model_jax = MLP(key=random.PRNGKey(42))
    # Optimizer state is initialised over the array leaves of the model only.
    opt_state = optim.init(eqx.filter(model_jax, eqx.is_array))
    main(model_jax, opt_state)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import time | |
import equinox as eqx | |
import jax | |
import jax.nn as jnn | |
import jax.numpy as jnp | |
import optax | |
from jax import random | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
from torchvision.datasets import MNIST | |
# Timestamped, level-tagged log lines, e.g.
# "[2023-01-01 12:00:00] [INFO] __main__ message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
)
logger = logging.getLogger(__name__)

# Training hyperparameters.
EPOCHS = 10          # full passes over the training set
LR = 0.1             # SGD learning rate
MOMENTUM = 0.9       # SGD momentum coefficient
BATCH_SIZE = 10000   # examples per mini-batch
class MLP(eqx.Module):
    """Thin Equinox wrapper around a 784 -> 30 -> 10 fully-connected network."""

    # Wrapped Equinox MLP; declared as a module field so Equinox tracks its
    # parameters as part of this module's pytree.
    mlp: eqx.Module

    def __init__(self, *, key: random.PRNGKey) -> None:
        # 28x28 flattened MNIST image in, 10 class logits out, a single
        # hidden layer of width 30 with ReLU activation.
        self.mlp = eqx.nn.MLP(
            in_size=28 * 28,
            out_size=10,
            width_size=30,
            depth=1,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, x):
        # Forward one (784,) example; batching is done by jax.vmap at the
        # call site (see compute_loss).
        return self.mlp(x)
# Guard left in on purpose: the dataset root below is machine-specific.
# BUG FIX: raise *before* constructing the dataset — the original raised
# after `MNIST(...)`, so `download=True` had already fetched data into the
# wrong location before the guard fired.  Delete this line once the path
# below points at your local dataset directory.
raise NotImplementedError("Change dataset path")

mnist_ds_jax = MNIST(
    "/home/pauje/Datasets",  # TODO: set to your local dataset root
    download=True,
    transform=transforms.ToTensor(),
)
# Plain torch DataLoader: batches arrive as torch tensors and are converted
# to numpy in the training loop (see main).
train_dl_jax = DataLoader(mnist_ds_jax, batch_size=BATCH_SIZE)
# Plain SGD with momentum over the model's array leaves.
optim = optax.sgd(learning_rate=LR, momentum=MOMENTUM)
def accuracy_jax(pred_Y, Y):
    """Fraction of predictions whose argmax matches the integer label.

    Args:
        pred_Y: ``(batch, classes)`` array of logits.
        Y: ``(batch,)`` array of integer class labels.

    Returns:
        Scalar accuracy in ``[0, 1]``.
    """
    # Softmax is strictly monotone within each row, so argmax over the raw
    # logits equals argmax over softmax(logits); the softmax call the
    # original made here was redundant work and has been removed.
    pred_labels = jnp.argmax(pred_Y, axis=1)
    return jnp.mean((Y == pred_labels) * 1.0)
def compute_loss(model, X, Y):
    """Mean cross-entropy of `model` on a batch, with logits as aux output.

    Args:
        model: callable mapping one flattened image to class logits.
        X: ``(batch, 784)`` array of inputs.
        Y: ``(batch,)`` array of integer labels.

    Returns:
        ``(mean_loss, logits)`` tuple.
    """
    logits = jax.vmap(model)(X)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, Y)
    return jnp.mean(per_example), logits


# Rebind so one call yields ((value, aux), grads), differentiating with
# respect to the model's array leaves only.
compute_loss = eqx.filter_value_and_grad(compute_loss, has_aux=True)
@eqx.filter_jit
def make_step(model, X, Y, opt_state):
    """One training step: gradients on (X, Y), optimizer update, apply.

    Returns the batch loss, the batch logits, the updated model, and the
    updated optimizer state.
    """
    (loss, logits), grads = compute_loss(model, X, Y)
    updates, new_opt_state = optim.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return loss, logits, new_model, new_opt_state
def main(model_jax, opt_state):
    """Run EPOCHS training passes over `train_dl_jax`, logging progress."""
    overall_start = time.time()
    for epoch in range(EPOCHS):
        epoch_start = time.time()
        for X, Y in train_dl_jax:
            # Torch tensors -> numpy, flattening each image to (784,).
            flat_X = X.numpy().reshape(X.shape[0], -1)
            labels = Y.numpy()
            loss, logits, model_jax, opt_state = make_step(
                model_jax, flat_X, labels, opt_state
            )
            acc = accuracy_jax(logits, labels)
            logger.info(f"Loss = {loss.item()}, Accuracy = {acc.item()}")
        logger.info(f"Epoch : {epoch} took {time.time() - epoch_start}")
    logger.info(f"Training took {time.time() - overall_start}")
if __name__ == "__main__":
    # Report which backend JAX selected (e.g. cpu/gpu/tpu).
    logger.info(jax.default_backend())
    # Fixed seed for reproducible weight initialisation.
    model_jax = MLP(key=random.PRNGKey(42))
    # Optimizer state is initialised over the array leaves of the model only.
    opt_state = optim.init(eqx.filter(model_jax, eqx.is_array))
    main(model_jax, opt_state)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment