Skip to content

Instantly share code, notes, and snippets.

@mirekphd
Forked from richardliaw/fullscript.py
Created June 20, 2020 15:14
Show Gist options
  • Save mirekphd/0a862e536cbe2dff88fe442011225fb1 to your computer and use it in GitHub Desktop.
Save mirekphd/0a862e536cbe2dff88fe442011225fb1 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
import torch.optim as optim
from ray import tune
from ray.tune.examples.mnist_pytorch import get_data_loaders, train, test
import ray
import sys
# Optionally connect to an existing Ray cluster whose address is passed as the
# first command-line argument; with no argument, ray.init is not called here.
# `address` replaces the deprecated `redis_address` keyword, which newer Ray
# releases reject.
if len(sys.argv) > 1:
    ray.init(address=sys.argv[1])
import torch.nn as nn
import torch.nn.functional as F
class ConvNet(nn.Module):
    """Small CNN classifier for single-channel 28x28 images (MNIST here).

    Shape flow for an input of shape (N, 1, 28, 28):
    conv 3x3 -> (N, 3, 26, 26); max-pool 3 -> (N, 3, 8, 8);
    flatten -> (N, 192); linear -> (N, 10) log-probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        # 192 = 3 channels * 8 * 8 positions left after conv + pooling;
        # this only holds for 28x28 inputs.
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        # Flatten each sample while keeping the batch dimension fixed.
        # view(-1, 192) would let the batch size float and silently mix
        # features across samples for wrongly sized inputs; here a size
        # mismatch raises in self.fc instead.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
def train_mnist(config):
    """Tune trainable: train a ConvNet on MNIST with SGD and report accuracy.

    Args:
        config: hyperparameters sampled by Tune. Requires "lr" and
            "momentum". Optionally accepts "epochs" (number of train/test
            rounds, default 10) and "checkpoint_freq" (default 5).

    After every round the test accuracy is logged to Tune as
    ``mean_accuracy``. The model weights are written to ``./model.pth`` on
    every ``checkpoint_freq``-th round (starting with round 0) and after the
    final round, so the fully trained weights are always saved.
    """
    epochs = config.get("epochs", 10)
    checkpoint_freq = config.get("checkpoint_freq", 5)
    device = torch.device("cpu")  # the same device for every round

    model = ConvNet()
    train_loader, test_loader = get_data_loaders()
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])

    for i in range(epochs):
        train(model, optimizer, train_loader, device)
        acc = test(model, test_loader, device)
        if i % checkpoint_freq == 0 or i == epochs - 1:
            # Save before reporting: the report call is presumably where Tune
            # can stop this trial (e.g. via ASHA), so code after it may never
            # run. Per the original intent, this lands in the trial directory.
            torch.save(model.state_dict(), "./model.pth")
        # NOTE(review): tune.track.log is Ray's legacy reporting API; newer
        # releases use tune.report instead. Update it together with the Ray
        # version this script targets.
        tune.track.log(mean_accuracy=acc)
from ray.tune.schedulers import ASHAScheduler
# Hyperparameter search space: the learning rate is picked from a fixed set of
# values, and the SGD momentum is drawn uniformly from [0.1, 0.9].
search_space = dict(
    lr=tune.choice([0.001, 0.01, 0.1]),
    momentum=tune.uniform(0.1, 0.9),
)

# Run 30 sampled configurations of train_mnist. The ASHA scheduler ranks
# trials by the reported "mean_accuracy" (higher is better) and stops weak
# ones early. grace_period=1 is the minimum trial age before stopping; its
# unit is the scheduler's default time attribute (confirm for your Ray version).
analysis = tune.run(
    train_mnist,
    num_samples=30,
    scheduler=ASHAScheduler(
        metric="mean_accuracy",
        mode="max",
        grace_period=1,
    ),
    config=search_space,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment