Last active
February 28, 2021 19:35
-
-
Save david1542/a0831761de35e7805cd9c99225076cb9 to your computer and use it in GitHub Desktop.
Section 7 - DL HW 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
def _infer_num_features(model):
    """Return the input width of *model*'s final classification layer.

    Inspects the model's last top-level child: if it exposes
    ``in_features`` directly (a bare Linear head) that value is returned;
    otherwise the child is indexed with ``[-1]`` (e.g. a Sequential
    classifier) and that element's ``in_features`` is returned.
    """
    last_layer = list(model.children())[-1]
    if hasattr(last_layer, 'in_features'):
        return last_layer.in_features
    return last_layer[-1].in_features


def evaluate_models(models_to_evaluate, *, data_dir=None, num_epochs=20,
                    batch_size=16, num_workers=4):
    """Fine-tune, save and collect metrics for several pretrained models.

    For each entry in ``models_to_evaluate`` this builds the pretrained
    network, replaces its final layer with one sized to the dataset's
    number of classes, trains it with ``train_model``, writes the weights
    to ``models_path/model_<name>.pth`` and the performance record to
    ``metrics_path/metrics_<name>.json``.

    Args:
        models_to_evaluate: Iterable of dicts with keys
            ``'model_name'`` (str, used in printed banners and file names),
            ``'model_class'`` (callable invoked as ``model_class(pretrained=True)``)
            and ``'replace_layer'`` (callable ``(model, num_ftrs, n_classes)``
            that installs the new output layer on ``model``).
        data_dir: Root folder holding ``train`` and ``val`` ImageFolder
            subdirectories. Defaults to the global ``polit_path``.
        num_epochs: Number of epochs passed to ``train_model``.
        batch_size: Mini-batch size for both DataLoaders.
        num_workers: Worker processes for both DataLoaders.

    Returns:
        dict mapping each ``model_name`` to the ``performance`` value
        returned by ``train_model``.
    """
    if data_dir is None:
        data_dir = polit_path
    phases = ['train', 'val']

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in phases}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
                   for x in phases}
    dataset_sizes = {x: len(image_datasets[x]) for x in phases}
    # The new output layer gets one unit per class found in the train split.
    n_classes = len(image_datasets['train'].classes)

    # Models metrics map: model_name -> performance
    metrics = {}
    for pack in models_to_evaluate:
        model_name = pack['model_name']
        model_class = pack['model_class']
        replace_layer = pack['replace_layer']

        print('------------------------')
        print(f'Evaluating {model_name}...')
        print('------------------------')

        model = model_class(pretrained=True)
        # NOTE: no parameters are frozen here. The optimizer below receives
        # model.parameters(), so the entire network is fine-tuned, not just
        # the replaced output layer.
        num_ftrs = _infer_num_features(model)
        replace_layer(model, num_ftrs, n_classes)
        # Move to the globally configured device (GPU if one was selected).
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        # Decay the learning rate by 10x every 7 epochs.
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7,
                                               gamma=0.1)

        model, performance = train_model(model,
                                         dataloaders,
                                         criterion,
                                         optimizer_ft,
                                         exp_lr_scheduler,
                                         dataset_sizes,
                                         num_epochs=num_epochs)

        # Persist the trained weights and the performance record.
        model_file = os.path.join(models_path, f'model_{model_name}.pth')
        torch.save(model.state_dict(), model_file)
        metrics_file = os.path.join(metrics_path, f'metrics_{model_name}.json')
        # NOTE(review): json.dump only accepts plain Python types; if
        # train_model records tensors in `performance`, convert them first.
        with open(metrics_file, 'w') as f:
            json.dump(performance, f)
        metrics[model_name] = performance
    return metrics
#################### | |
# Run the experiment | |
#################### | |
def resnet18_replace(model, num_ftrs, n_classes):
    """Attach a fresh ``num_ftrs -> n_classes`` Linear layer as ``model.fc``.

    Mutates ``model`` in place and returns ``None``.
    """
    new_head = nn.Linear(num_ftrs, n_classes)
    model.fc = new_head
def vgg16_replace(model, num_ftrs, n_classes):
    """Put a fresh ``num_ftrs -> n_classes`` Linear layer at ``model.classifier[6]``.

    Mutates ``model`` in place and returns ``None``. Assumes index 6 is the
    classifier's final layer (the torchvision VGG16 layout); confirm this
    before reusing it with other architectures.
    """
    classifier = model.classifier
    classifier[6] = nn.Linear(num_ftrs, n_classes)
# Each entry pairs a torchvision constructor with the function that swaps in
# its new output layer; evaluate_models processes them in list order.
models_to_evaluate = [
    dict(model_name='vgg16',
         model_class=models.vgg16,
         replace_layer=vgg16_replace),
    dict(model_name='resnet18',
         model_class=models.resnet18,
         replace_layer=resnet18_replace),
]

metrics = evaluate_models(models_to_evaluate)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment