Skip to content

Instantly share code, notes, and snippets.

@buttercutter
Created February 27, 2021 10:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save buttercutter/ae0e48974ccf4bdee07c9d69148cf21b to your computer and use it in GitHub Desktop.
Save buttercutter/ae0e48974ccf4bdee07c9d69148cf21b to your computer and use it in GitHub Desktop.
DARTS: DIFFERENTIABLE ARCHITECTURE SEARCH https://arxiv.org/pdf/1806.09055.pdf
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
# True when torch reports an available CUDA device.
USE_CUDA = torch.cuda.is_available()

# Experiment constants. Reference cited by the author:
# https://arxiv.org/pdf/1806.09055.pdf#page=12
TEST_DATASET_RATIO = 0.5  # 50 percent of the dataset is dedicated for testing purpose
SIZE_OF_HIDDEN_LAYERS = 64
NUM_EPOCHS = 50
LEARNING_RATE = 0.025
MOMENTUM = 0.9
NUM_OF_CHANNELS = 16

# CIFAR-10 input pipeline, following
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Images are converted to tensors, then each of the three channels is
# normalized with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# train=True split, served in shuffled mini-batches of 4.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# train=False split, served in order (no shuffling) in mini-batches of 4.
valset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
valloader = torch.utils.data.DataLoader(
    dataset=valset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable names for the ten class indices.
classes = (
    'plane', 'car', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck',
)
class Net1(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    ``train()`` instantiates this network for the training-loss term
    (annotated there as ``Ltrain(w±, alpha)``).

    The first linear layer expects ``16 * 5 * 5`` features per sample, which
    is what the two 5x5 convolutions and 2x2 poolings produce for
    3-channel 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch ``x``.

        Raises:
            RuntimeError: if the per-sample feature count after the conv
                stages does not match what ``fc1`` expects (e.g. the input
                is not 32x32).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``x.view(-1, 16 * 5 * 5)``, this can never silently fold a
        # wrongly sized input into a different batch size; a mismatch
        # surfaces as a shape error in ``fc1`` instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
class Net2(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    ``train()`` instantiates this network for the validation-loss term
    (annotated there as ``Lval(w*, alpha)``).

    The first linear layer expects ``16 * 5 * 5`` features per sample, which
    is what the two 5x5 convolutions and 2x2 poolings produce for
    3-channel 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch ``x``.

        Raises:
            RuntimeError: if the per-sample feature count after the conv
                stages does not match what ``fc1`` expects (e.g. the input
                is not 32x32).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``x.view(-1, 16 * 5 * 5)``, this can never silently fold a
        # wrongly sized input into a different batch size; a mismatch
        # surfaces as a shape error in ``fc1`` instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# https://translate.google.com/translate?sl=auto&tl=en&u=http://khanrc.github.io/nas-4-darts-tutorial.html
def train():
    """Train ``Net1`` and ``Net2`` side by side for ``NUM_EPOCHS`` epochs.

    Each step draws one batch from the module-level ``trainloader`` (fed to
    ``net1``) and one from ``valloader`` (fed to ``net2``), runs an
    independent SGD update on each network, saves ``net1`` to
    ``./net1.pth``, and computes the finite-difference quantities sketched
    from DARTS equation (8).

    Returns nothing. Side effect: ``./net1.pth`` is overwritten on every
    step.
    """
    net1 = Net1()  # for Ltrain(w±, alpha)
    net2 = Net2()  # for Lval(w*, alpha)

    criterion = nn.CrossEntropyLoss()
    optimizer1 = optim.SGD(net1.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    optimizer2 = optim.SGD(net2.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    # needs to save intermediate trained model for Lval
    path = './net1.pth'

    for _ in range(NUM_EPOCHS):
        # zip() pairs one training batch with one validation batch and stops
        # as soon as the shorter of the two loaders is exhausted.
        for (train_inputs, train_labels), (val_inputs, val_labels) in zip(
                trainloader, valloader):
            # do train thing

            # zero the parameter gradients
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # forward + backward + optimize
            outputs1 = net1(train_inputs)
            outputs2 = net2(val_inputs)

            loss1 = criterion(outputs1, train_labels)
            loss2 = criterion(outputs2, val_labels)

            loss1.backward()
            loss2.backward()

            optimizer1.step()
            optimizer2.step()

            torch.save(net1, path)

            # DARTS's approximate architecture gradient. Refer to equation (8)
            # epsilon = 0.01 / ||grad_w' Lval||_2. The gradients that
            # loss2.backward() left on net2's parameters are taken as that
            # gradient (they are still present after optimizer2.step()).
            val_grad_norm = torch.norm(torch.stack([
                p.grad.norm()
                for p in net2.parameters()
                if p.grad is not None
            ]))
            epsilon = 0.01 / val_grad_norm

            # Central-difference form: divide by (2 * epsilon), not
            # "/ 2 * epsilon", which would multiply by epsilon.
            # NOTE(review): equation (8) differences grad_alpha Ltrain at w+
            # and w-. This script defines no architecture parameters alpha
            # yet, so the scalar loss difference from the original sketch is
            # kept as a placeholder and the result is not consumed - confirm
            # the intended numerator before relying on it.
            arch_grad_estimate = (loss1.detach() - loss2.detach()) / (2 * epsilon)

    # do test thing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment