import os
import time

import numpy as np
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms

from efficientnet_pytorch import EfficientNet

import ray.tune as tune
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.stopper import Stopper
class Classifier(nn.Module):
    """Frozen EfficientNet-B0 backbone with a small trainable classification head."""

    def __init__(self, n_classes, val_dropout):
        super(Classifier, self).__init__()
        self.effnet = EfficientNet.from_pretrained('efficientnet-b0')
        # Freeze the pretrained backbone; only the head below is trained.
        for param in self.effnet.parameters():
            param.requires_grad = False
        self.l1 = nn.Linear(1000, 256)  # EfficientNet-B0 ends in 1000 logits
        self.dropout = nn.Dropout(val_dropout)
        self.l2 = nn.Linear(256, n_classes)
        self.sm = nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        x = self.effnet(inputs)
        x = x.view(x.size(0), -1)
        x = self.dropout(self.l1(x))
        x = self.sm(self.l2(x))
        return x
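

# A minimal shape sanity check for the head, never called by the script itself.
# The values are illustrative assumptions: 9 classes, dropout 0.2, and a batch of
# two 208x208 RGB tensors (the larger of the two image sizes tuned below).
def _demo_classifier_shapes():
    model = Classifier(n_classes=9, val_dropout=0.2).eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 208, 208))
    assert out.shape == (2, 9)  # one row of log-probabilities per image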
class CustomStopper(Stopper):
    """Stop each trial after 20 iterations, and stop the whole run as soon as
    any trial reports more than 96% validation accuracy."""

    def __init__(self):
        self.should_stop = False

    def __call__(self, trial_id, result):
        max_iter = 20
        if not self.should_stop and result["accuracy"] > 0.96:
            self.should_stop = True
        return self.should_stop or result["training_iteration"] >= max_iter

    def stop_all(self):
        return self.should_stop
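

# A small sketch of the stopper semantics with hypothetical trial results: once a
# result crosses 96% accuracy, that trial stops and stop_all() halts the rest.
def _demo_stopper():
    stopper = CustomStopper()
    assert stopper("trial_1", {"accuracy": 0.90, "training_iteration": 5}) is False
    assert stopper("trial_1", {"accuracy": 0.97, "training_iteration": 6}) is True
    assert stopper.stop_all() is True  # remaining trials are stopped as well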
def data_loader(train_data, test_data=None, valid_size=None, batch_size=32,
                pin_memory=False, shuffle=True):
    """Build train/val/test DataLoaders depending on which inputs are given."""
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle,
                              pin_memory=pin_memory)
    if test_data is None and valid_size is None:
        return {'train': train_loader}
    if test_data is None and valid_size is not None:
        # Carve a validation split out of the training set itself; samplers keep
        # the validation images out of the training loader.
        data_len = len(train_data)
        indices = list(range(data_len))
        np.random.shuffle(indices)
        split1 = int(np.floor(valid_size * data_len))
        valid_idx, train_idx = indices[:split1], indices[split1:]
        train_loader = DataLoader(train_data, batch_size=batch_size,
                                  sampler=SubsetRandomSampler(train_idx),
                                  pin_memory=pin_memory)
        valid_loader = DataLoader(train_data, batch_size=batch_size,
                                  sampler=SubsetRandomSampler(valid_idx))
        return {'train': train_loader, 'val': valid_loader}
    if test_data is not None and valid_size is not None:
        # Split the held-out set into validation and test portions.
        data_len = len(test_data)
        indices = list(range(data_len))
        np.random.shuffle(indices)
        split1 = int(np.floor(valid_size * data_len))
        valid_idx, test_idx = indices[:split1], indices[split1:]
        valid_loader = DataLoader(test_data, batch_size=batch_size,
                                  sampler=SubsetRandomSampler(valid_idx))
        test_loader = DataLoader(test_data, batch_size=batch_size,
                                 sampler=SubsetRandomSampler(test_idx))
        return {'train': train_loader, 'val': valid_loader, 'test': test_loader}
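

# Usage sketch (valid_size=0.2 is the value train_model passes below): with a
# separate test set, that set is divided 20%/80% into validation and test
# loaders, while the training set backs the 'train' loader.
def _demo_data_loader(train_data, test_data):
    loaders = data_loader(train_data, test_data, valid_size=0.2, batch_size=32)
    assert set(loaders) == {'train', 'val', 'test'}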
def train_model(params, checkpoint_dir=None, num_epochs=25):
    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Classifier(n_classes=9, val_dropout=params['dropout']).to(device)
    # Load data
    transformations = transforms.Compose([transforms.Resize((params['img_size'], params['img_size'])),
                                          transforms.ToTensor()])  # additional augmentation might be conducted here
    train_data = torchvision.datasets.ImageFolder(root='C:/Users/Dirk/PycharmProjects/ImageAnalysis/Carrot/Train3/',
                                                  transform=transformations)
    test_data = torchvision.datasets.ImageFolder(root='C:/Users/Dirk/PycharmProjects/ImageAnalysis/Carrot/Test/',
                                                 transform=transformations)
    dataloaders = data_loader(train_data, test_data, valid_size=0.2,
                              batch_size=params['batch_size'], shuffle=params['shuffle'])
    # Optimizer
    if params['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=params['initial_lr'], eps=params['eps'],
                                     weight_decay=params['weight_decay'])
    elif params['optimizer'] == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=params['initial_lr'], momentum=params['momentum'],
                                    weight_decay=params['weight_decay'])
    else:
        raise ValueError("Unknown optimizer: {}".format(params['optimizer']))
    # Criterion. NLLLoss matches the model's LogSoftmax output; CrossEntropyLoss
    # applies log-softmax a second time on top of it.
    if params['criterion'] == 'NLL':
        criterion = torch.nn.NLLLoss()
    elif params['criterion'] == 'cross-entropy':
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError("Unknown criterion: {}".format(params['criterion']))
    # Restore model and optimizer state when PBT resumes this trial from a checkpoint.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
    since = time.time()
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        time.sleep(5)
        model.train()  # enable dropout during training
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(dataloaders['train'], 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0
                epoch_steps = 0  # reset with the loss so the average stays windowed
        # Validation loss
        model.eval()  # disable dropout for evaluation
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_steps += 1
        # Checkpoint so PBT can clone/perturb this trial, then report metrics.
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
        time.sleep(5)
        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
def test(model, device, test_loader):
    """Compute accuracy on a test DataLoader."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total
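

# Tiny illustrative check of test() on an in-memory dataset of random tensors
# (hypothetical data; never called by the script). Accuracy on random labels is
# arbitrary, so we only assert it lies in [0, 1].
def _demo_test():
    from torch.utils.data import TensorDataset
    model = Classifier(n_classes=9, val_dropout=0.0)
    data = TensorDataset(torch.randn(4, 3, 150, 150), torch.randint(0, 9, (4,)))
    acc = test(model, torch.device("cpu"), DataLoader(data, batch_size=2))
    assert 0.0 <= acc <= 1.0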
def main(num_samples=10, perturbation_interval=10, gpus_per_trial=1):
    params = {'batch_size': tune.choice([2, 4, 8, 16, 32]),
              'img_size': tune.choice([150, 208]),
              'initial_lr': tune.loguniform(1e-4, 1e-1),
              'shuffle': tune.choice([True, False]),
              'dropout': tune.choice([0.1, 0.2, 0.3, 0.5]),
              'optimizer': tune.choice(['adam', 'SGD']),
              'weight_decay': tune.choice([0, 0.005, 0.01, 0.05]),
              'momentum': tune.uniform(0.5, 0.99),  # SGD momentum must stay below 1
              'eps': tune.choice([0, 1e-10, 1e-09, 1e-08, 1e-07, 1e-06, 1e-05]),
              'criterion': tune.choice(['NLL', 'cross-entropy'])
              }
    scheduler = PopulationBasedTraining(
        time_attr='training_iteration',
        metric="accuracy",
        mode="max",
        perturbation_interval=perturbation_interval,
        hyperparam_mutations={'batch_size': tune.choice([2, 4, 8, 16, 32]),
                              'img_size': tune.choice([150, 208]),
                              'initial_lr': tune.loguniform(1e-4, 1e-1),
                              'dropout': tune.choice([0.1, 0.2, 0.3, 0.5]),
                              'weight_decay': tune.choice([0, 0.005, 0.01, 0.05]),
                              'momentum': tune.uniform(0.5, 0.99),
                              'eps': tune.choice([0, 1e-10, 1e-09, 1e-08, 1e-07, 1e-06, 1e-05])
                              })
    stopper = CustomStopper()
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration"])
    result = tune.run(train_model,
                      # keep_checkpoints_num=1,
                      # checkpoint_score_attr="accuracy",
                      resources_per_trial={"cpu": 4, "gpu": gpus_per_trial},
                      config=params,
                      num_samples=num_samples,
                      scheduler=scheduler,
                      progress_reporter=reporter,
                      stop=stopper,
                      verbose=1
                      )
    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))
    best_trained_model = Classifier(9, best_trial.config["dropout"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best_trained_model.to(device)
    best_checkpoint_dir = best_trial.checkpoint.value
    model_state, optimizer_state = torch.load(os.path.join(best_checkpoint_dir, "checkpoint"))
    best_trained_model.load_state_dict(model_state)
    # Trials run in separate worker processes, so the loaders built inside
    # train_model are not visible here; rebuild a test loader from the best config.
    transformations = transforms.Compose(
        [transforms.Resize((best_trial.config['img_size'], best_trial.config['img_size'])),
         transforms.ToTensor()])
    train_data = torchvision.datasets.ImageFolder(root='C:/Users/Dirk/PycharmProjects/ImageAnalysis/Carrot/Train3/',
                                                  transform=transformations)
    test_data = torchvision.datasets.ImageFolder(root='C:/Users/Dirk/PycharmProjects/ImageAnalysis/Carrot/Test/',
                                                 transform=transformations)
    loaders = data_loader(train_data, test_data, valid_size=0.2,
                          batch_size=best_trial.config['batch_size'])
    test_acc = test(best_trained_model, device, loaders['test'])
    print("Best trial test set accuracy: {}".format(test_acc))
if __name__ == "__main__":
    # You can change the number of GPUs per trial here:
    main(num_samples=3, perturbation_interval=3, gpus_per_trial=1)