Skip to content

Instantly share code, notes, and snippets.

@HanaanY
Created September 6, 2018 13:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save HanaanY/e5adf4ba32aab7ebb9e235e003a3f96e to your computer and use it in GitHub Desktop.
Save HanaanY/e5adf4ba32aab7ebb9e235e003a3f96e to your computer and use it in GitHub Desktop.
from Project.DenseNet.densenet import densenet121_model
from Project.ss150.ss150 import DataGenerator
from keras.utils.io_utils import HDF5Matrix
from keras.utils import multi_gpu_model
from keras.optimizers import SGD
from keras import backend as K
from sklearn import metrics
import random
import numpy as np
import click
import pandas as pd
import os.path
# Seed Python's built-in RNG once at import time so that anything in this
# script drawing from the `random` module is repeatable between runs.
_RANDOM_SEED = 52
random.seed(_RANDOM_SEED)
@click.command()
@click.option('--data_path', default='/mnt/storage/home/hy16732/scratch/balanced.h5',
              help='hdf5 datafile')
@click.option('--batch_size', default=50, help='number of batches')
@click.option('--weights', default='weights.59-0.12.hdf5', help='load weights')
@click.option('--predict_dir',
              default='/mnt/storage/home/hy16732/scratch/D121/Results/20180826-103307',
              help='Predict directory')
@click.option('--outfile', default='metrics', help='outfile')
@click.option('--metadata', default='/mnt/storage/home/hy16732/scratch/giraffe_test.csv',
              help='metadata with more info about the data point')
@click.option('--tl/--no-tl', default=True, help='TL with ImageNet weights and normalisation')
@click.option('--save_path', default='test.hdf5',
              help='File the template model weights are written to')
def main(data_path, batch_size, weights, predict_dir, outfile, metadata, tl,
         save_path='test.hdf5'):
    """Re-save a 2-GPU DenseNet-121 checkpoint through its template model.

    Builds a 224x224x3, 2-class DenseNet-121, wraps it with multi_gpu_model
    (2 GPUs), loads PREDICT_DIR/WEIGHTS into the parallel wrapper and then
    writes the template model's weights to SAVE_PATH.
    """
    # NOTE(review): data_path, batch_size, outfile, metadata and tl are
    # accepted but unused here; the network is always built with
    # transfer_learning=False, so --tl/--no-tl has no effect. Confirm intent.
    weights_path = os.path.join(predict_dir, weights)

    model = densenet121_model(img_rows=224, img_cols=224,
                              color_type=3, num_classes=2,
                              transfer_learning=False)
    # Per the Keras multi_gpu_model docs, the wrapper is built from the
    # template model's layers, and weights should be saved via the template
    # (`model`) rather than via the returned wrapper.
    parallel_model = multi_gpu_model(model, gpus=2, cpu_merge=True)

    # Hack taken from https://stackoverflow.com/questions/48198031: expose the
    # optimizer's `lr` attribute as a Keras metric. The inner function's name
    # ("lr") is the name the metric is reported under.
    def get_lr(optimizer):
        def lr(y_true, y_pred):
            return optimizer.lr
        return lr

    # Learning rate is 0.001.
    sgd = SGD(lr=0.001, decay=1e-6, momentum=0.90, nesterov=True)
    learn_rate = get_lr(sgd)

    parallel_model.compile(optimizer=sgd, loss='categorical_crossentropy',
                           metrics=['accuracy', learn_rate])
    model.compile(optimizer=sgd, loss='categorical_crossentropy',
                  metrics=['accuracy', learn_rate])

    # The checkpoint is loaded into the parallel wrapper, presumably because
    # it was saved from a multi_gpu_model during training (TODO confirm).
    parallel_model.load_weights(weights_path, by_name=False)
    model.save_weights(save_path)
if __name__ == '__main__':
    # click parses sys.argv and fills in main()'s keyword arguments itself,
    # so the command is invoked with no explicit arguments.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment