@devforfu
Last active December 9, 2018 14:19
Training loop example
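from pathlib import Path

import torch.nn.functional as F
import torchvision.transforms as T
from torch import optim
from torchvision.datasets import MNIST

# make_phases, Net, CallbacksGroup, RollingLoss, Accuracy, Scheduler,
# OneCycleSchedule, StreamLogger and train are not defined in this gist;
# they are assumed to come from the accompanying training-loop code.

data_path = Path.home()/'data'/'mnist'
# (mean, std) passed to T.Normalize; commonly cited MNIST values are 0.1307 and 0.3081
mnist_stats = ((0.15,), (0.15,))
epochs = 3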
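
# Training set: light random-affine augmentation (up to ±5° rotation,
# ±5% translation, 0.9-1.1x scaling), then tensor conversion and normalization
train_ds = MNIST(
    data_path,
    train=True,
    download=True,
    transform=T.Compose([
        T.RandomAffine(5, translate=(0.05, 0.05), scale=(0.9, 1.1)),
        T.ToTensor(),
        T.Normalize(*mnist_stats)
    ])
)

# Validation set: no augmentation, same normalization
valid_ds = MNIST(
    data_path,
    train=False,
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize(*mnist_stats)
    ])
)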
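
# make_phases presumably wraps each dataset in a loader (train first, then valid);
# n_jobs is likely the number of loader worker processes
phases = make_phases(train_ds, valid_ds, bs=1024, n_jobs=4)
model = Net()
opt = optim.Adam(model.parameters(), lr=1e-2)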
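
# Callbacks: rolling loss, accuracy, a one-cycle schedule updated after every
# batch, and a stream logger. The schedule length t is the total number of
# training batches over all epochs.
cb = CallbacksGroup([
    RollingLoss(),
    Accuracy(),
    Scheduler(
        OneCycleSchedule(t=len(phases[0].loader) * epochs),
        mode='batch'
    ),
    StreamLogger()
])

train(model, opt, phases, cb, epochs=epochs, loss_fn=F.cross_entropy)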
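
The gist builds OneCycleSchedule with t = batches per epoch × epochs but does not show its body. The function below is a hypothetical, framework-free sketch of one common one-cycle shape (linear warm-up, then cosine decay). The warm-up fraction and the div/final_div factors are assumptions, not values from the gist, and the peak of 1e-2 only mirrors the lr passed to Adam.

```python
import math


def one_cycle_lr(step, total_steps, max_lr=1e-2, warmup_frac=0.3,
                 div=25.0, final_div=1e4):
    """Learning rate at batch `step` under an assumed one-cycle policy:
    linear warm-up from max_lr/div to max_lr, then cosine decay to max_lr/final_div."""
    warmup = int(total_steps * warmup_frac)
    start_lr = max_lr / div
    final_lr = max_lr / final_div
    if step < warmup:
        # Linear ramp up over the first `warmup` batches
        return start_lr + (max_lr - start_lr) * step / warmup
    # Cosine annealing from max_lr down to final_lr over the remaining batches
    progress = (step - warmup) / max(1, total_steps - warmup)
    return final_lr + (max_lr - final_lr) * (1 + math.cos(math.pi * progress)) / 2


# Gist settings: 60,000 training images, bs=1024, 3 epochs.
# Assumes the last partial batch is kept, giving 59 batches per epoch.
total = math.ceil(60000 / 1024) * 3
schedule = [one_cycle_lr(i, total) for i in range(total)]
print(total, round(max(schedule), 6))  # → 177 0.01
```

Because the gist passes mode='batch', a schedule like this would be stepped once per batch, which is why t counts batches rather than epochs.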
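
RollingLoss is likewise not defined here. Assuming it reports a smoothed training loss, a common way to do that is an exponential moving average with bias correction, sketched below as a standalone class. The name SmoothedLoss and the smooth=0.98 default are illustrative choices, not the gist's code.

```python
class SmoothedLoss:
    """Hypothetical stand-in for a rolling-loss tracker: an exponential moving
    average of batch losses with bias correction for the zero start value."""

    def __init__(self, smooth=0.98):
        self.smooth = smooth
        self.avg = 0.0
        self.n = 0

    def update(self, batch_loss):
        self.n += 1
        self.avg = self.smooth * self.avg + (1 - self.smooth) * batch_loss
        # Without this correction the early averages are pulled toward 0
        return self.avg / (1 - self.smooth ** self.n)


tracker = SmoothedLoss()
for loss in [2.3, 1.9, 1.4, 1.1, 0.9]:
    print(round(tracker.update(loss), 3))
```

The correction term makes the very first reported value equal the first batch loss, and a constant loss stays constant instead of slowly climbing from zero.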
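
CallbacksGroup lets train() notify several callbacks through a single object. The sketch below shows that composite pattern in plain Python under assumed hook names (epoch_started, batch_ended, epoch_ended); the real classes' hooks and signatures are not shown in the gist. LossLogger and fake_train are hypothetical, and fake_train skips the model and optimizer entirely.

```python
class Callback:
    """Base class with no-op hooks; subclasses override only what they need."""
    def epoch_started(self, **kwargs): pass
    def batch_ended(self, **kwargs): pass
    def epoch_ended(self, **kwargs): pass


class CallbackGroup(Callback):
    """Forwards every hook to each member callback, in registration order."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def _fire(self, event, **kwargs):
        for cb in self.callbacks:
            getattr(cb, event)(**kwargs)

    def epoch_started(self, **kwargs): self._fire('epoch_started', **kwargs)
    def batch_ended(self, **kwargs): self._fire('batch_ended', **kwargs)
    def epoch_ended(self, **kwargs): self._fire('epoch_ended', **kwargs)


class LossLogger(Callback):
    """Example member: collects per-batch losses and reports the epoch mean."""
    def __init__(self):
        self.history = []
        self.batch_losses = []

    def epoch_started(self, **kwargs):
        self.batch_losses = []

    def batch_ended(self, loss, **kwargs):
        self.batch_losses.append(loss)

    def epoch_ended(self, epoch, **kwargs):
        mean = sum(self.batch_losses) / len(self.batch_losses)
        self.history.append(mean)
        print(f'epoch {epoch}: mean loss {mean:.3f}')


def fake_train(cb, epochs, losses_per_epoch):
    """Stand-in loop: a real one would run forward/backward passes here."""
    for epoch in range(1, epochs + 1):
        cb.epoch_started(epoch=epoch)
        for loss in losses_per_epoch[epoch - 1]:
            cb.batch_ended(loss=loss)
        cb.epoch_ended(epoch=epoch)


logger = LossLogger()
fake_train(CallbackGroup([logger]), epochs=2,
           losses_per_epoch=[[1.0, 0.5], [0.4, 0.2]])
# prints: epoch 1: mean loss 0.750
#         epoch 2: mean loss 0.300
```

Since the group is itself a Callback, the training loop only ever talks to one object, and adding a metric or logger is a matter of appending it to the list.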