Created
September 6, 2018 13:19
-
-
Save HanaanY/e5adf4ba32aab7ebb9e235e003a3f96e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Project.DenseNet.densenet import densenet121_model | |
from Project.ss150.ss150 import DataGenerator | |
from keras.utils.io_utils import HDF5Matrix | |
from keras.utils import multi_gpu_model | |
from keras.optimizers import SGD | |
from keras import backend as K | |
from sklearn import metrics | |
import random | |
import numpy as np | |
import click | |
import pandas as pd | |
import os.path | |
# Fix the seed of Python's built-in RNG so any use of `random` is repeatable
# from one run to the next.
random.seed(52)
@click.command()
@click.option('--data_path', default='/mnt/storage/home/hy16732/scratch/balanced.h5',
              help='hdf5 datafile')
@click.option('--batch_size', default=50, help='number of batches')
@click.option('--weights', default='weights.59-0.12.hdf5', help='load weights')
@click.option('--predict_dir', default='/mnt/storage/home/hy16732/scratch/D121/Results/20180826-103307', help='Predict directory')
@click.option('--outfile', default='metrics', help='outfile')
@click.option('--metadata', default='/mnt/storage/home/hy16732/scratch/giraffe_test.csv', help='metadata with more info about the data point')
@click.option('--tl/--no-tl', default=True, help='TL with ImageNet weights and normalisation')
def main(data_path, batch_size, weights, predict_dir, outfile, metadata, tl):
    """Load a multi-GPU checkpoint and re-save it from the template model.

    Builds a DenseNet-121 template model, wraps it with ``multi_gpu_model``,
    loads the checkpoint ``<predict_dir>/<weights>`` into the wrapper and then
    writes the template model's weights to ``test.hdf5`` in the current
    working directory.
    \f

    Args:
        data_path: Accepted for CLI compatibility; not used by this command.
        batch_size: Accepted for CLI compatibility; not used by this command.
        weights: File name of the checkpoint inside ``predict_dir``.
        predict_dir: Directory that contains the checkpoint file.
        outfile: Accepted for CLI compatibility; not used by this command.
        metadata: Accepted for CLI compatibility; not used by this command.
        tl: Accepted for CLI compatibility; not used by this command.
    """
    weights_path = os.path.join(predict_dir, weights)

    # NOTE(review): transfer_learning is hard-coded to False and the --tl flag
    # is not forwarded; presumably intentional because the weights come from
    # the checkpoint loaded below -- confirm.
    model = densenet121_model(img_rows=224, img_cols=224,
                              color_type=3, num_classes=2,
                              transfer_learning=False)
    parallel_model = multi_gpu_model(model, gpus=2, cpu_merge=True)

    # Hack taken verbatim from https://stackoverflow.com/questions/48198031:
    # expose the optimizer's learning rate as a Keras metric.
    def get_lr(optimizer):
        def lr(y_true, y_pred):
            return optimizer.lr
        return lr

    # Learning rate is 0.001. The original code passed `lr=lr`, but no `lr`
    # exists in this scope (the only one is the closure inside get_lr), so
    # the command died with a NameError before touching any weights.
    sgd = SGD(lr=0.001, decay=1e-6, momentum=0.90, nesterov=True)

    # Custom metric (see get_lr above).
    learn_rate = get_lr(sgd)
    parallel_model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy', learn_rate])
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy', learn_rate])

    # Load the checkpoint into the multi-GPU wrapper, then save from the
    # template model. Presumably the wrapper shares its layers with `model`,
    # so the saved file can later be loaded without the wrapper -- confirm
    # against the Keras multi_gpu_model documentation.
    parallel_model.load_weights(weights_path, by_name=False)
    model.save_weights('test.hdf5')
# Script entry point: click parses the command line and supplies main()'s
# arguments, so it is invoked here without any.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment