Created
March 4, 2018 14:32
-
-
Save ahalterman/821b4db869160aa5ecf9ff1b60d7f91c to your computer and use it in GitHub Desktop.
Managing machine learning experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# many lines omitted above | |
def make_log(experiment_dir, X_train, X_test, Y_test, model, hist, custom_model):
    """Write a plain-text summary of a training run to experiment_dir/run_info.txt.

    The report records the timestamp, the git commit of the code, whether a
    custom model was used, the data dimensions, precision/recall at a range
    of probability cutpoints, and the per-epoch Keras training history.

    Args:
        experiment_dir: directory for this run; assumed to end with a path
            separator, since paths are built by string concatenation.
        X_train: training feature array; assumed 3-D — TODO confirm
            (shape[0] = examples, shape[2] = features per the report text).
        X_test, Y_test: held-out features/labels passed to precision_recall.
        model: the trained model evaluated by precision_recall.
        hist: Keras History object; hist.history is tabulated in the report.
        custom_model: truthy flag recorded in the report.
    """
    now = datetime.datetime.now()
    now = now.strftime("%Y-%m-%d %H:%M:%S")
    # get last commit hash; check_output returns bytes, so decode it —
    # otherwise the report reads "latest commit b'abc123'"
    commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
    # get precision and recall at a range of cutpoints
    cutoffs = [0.01, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60]
    precrecs = [precision_recall(X_test, Y_test, model, i) for i in cutoffs]
    # do some nice pandas formatting
    precrec_table = str(pd.DataFrame(precrecs, columns = ['cutpoint', 'precision', 'recall', 'fscore', 'div_zero_perc']))
    hist_table = str(pd.DataFrame(hist.history))
    run_info = """Model ran at {0} using code with latest commit {1}.
    Custom model: {2}\n
    Total training data was {3} examples, with {4} features.
    Evaluated on {5} evaluation examples.\n
    Precision/recall metrics:\n{6}\n
    Training history:\n {7}""".format(now,
                                      commit,
                                      str(custom_model),
                                      X_train.shape[0],
                                      X_train.shape[2],
                                      X_test.shape[0],
                                      precrec_table,
                                      hist_table)
    with open(experiment_dir + "run_info.txt", "w") as f:
        f.write(run_info)
@plac.annotations(
    experiment_dir=("Location of the run's folder with a config file", "option", "i", str))
def main(experiment_dir):
    """Run one experiment described by experiment_dir/config.txt.

    Loads (or regenerates) formatted training data, builds either a custom
    model from experiment_dir/custom_model.py or the default CNN, trains it,
    then writes a log, a model diagram, and the saved model into
    experiment_dir. `experiment_dir` is assumed to end with a path separator,
    since paths are built by string concatenation.
    """
    # Refuse to run (without explicit override) when the model code has
    # uncommitted changes, so every run is tied to a real commit hash.
    git_status = str(subprocess.check_output(['git', 'status']).strip())
    if bool(re.search("event_model.py", git_status)):
        print("You have uncommitted changes to `event_model.py`. Please commit them before proceeding to ensure reproducibility.")
        # don't shadow the builtin `quit`
        answer = input("Type 'testing' to continue or anything else to quit: ")
        if answer != "testing":
            print("Bye!")
            sys.exit(0)
    config = ConfigParser()
    config.read(experiment_dir + "config.txt")
    print("Importing data...")
    # ConfigParser values are strings, so a bare truthiness test would treat
    # "False" as true; compare against "True" explicitly (same convention as
    # the early_stopping check below).
    if config['Data']['use_cache'] == "True":
        print("Using cached formatted data. This is much faster, but changes to the feature factory won't appear if you do this!")
        cache_loc = config['Data']['cache_loc']
        with open(cache_loc, "rb") as f:
            formatted = pickle.load(f)
    else:
        print("Regenerating formatted data from scratch.")
        try:
            nlp = spacy.load(str(config['Model']['nlp_model']))
        except (OSError, KeyError):
            # spacy.load raises OSError for a missing model; KeyError covers
            # a missing config entry. Actually load the fallback model —
            # previously this only printed the message.
            print("Tried to load custom spaCy model but failed. Falling back to en_core_web_sm")
            nlp = spacy.load("en_core_web_sm")
        # NOTE(review): `nlp` is not used below — presumably import_data or
        # the feature factory picks it up globally; verify.
        formatted = import_data(minerva_dir = config['Data']['minerva_dir'],
                                prodigy_dir = config['Data']['prodigy_dir'])
    encoder = Encoder()
    X, _, Y = make_CNN_matrix(formatted, encoder)
    X_train, Y_train, X_test, Y_test = train_test_split(X, Y)
    # Allow `import custom_model` to find a per-experiment model definition.
    sys.path.append(experiment_dir)
    try:
        import custom_model
        model = custom_model.make_CNN_model(X, Y,
                       filter_size = int(config['Model']['filter_size']),
                       conv_dropout = float(config['Model']['conv_dropout']),
                       conv_activation = str(config['Model']['conv_activation']),
                       dense_units = int(config['Model']['dense_units']),
                       dense_dropout = float(config['Model']['dense_dropout']),
                       dense_activation = str(config['Model']['dense_activation']))
        print("Using a custom model.")
        # use a separate flag rather than rebinding the imported module name
        used_custom_model = True
    except ImportError:
        used_custom_model = False
        model = make_CNN_model(X, Y,
                       filter_size = int(config['Model']['filter_size']),
                       conv_dropout = float(config['Model']['conv_dropout']),
                       conv_activation = str(config['Model']['conv_activation']),
                       dense_units = int(config['Model']['dense_units']),
                       dense_dropout = float(config['Model']['dense_dropout']),
                       dense_activation = str(config['Model']['dense_activation']))
    es = config['Model']['early_stopping']
    if es == "True":  # str, not bool
        print("Using early stopping.")
        try:
            patience = int(config['Model']['patience'])
        except (KeyError, ValueError):
            # missing or non-numeric `patience` entry: use a sane default
            patience = 4
        callbacks = [EarlyStopping(monitor='val_categorical_accuracy', patience=patience)]
        epochs = 40
    else:
        epochs = int(config['Model']['epochs'])
        print("Using {0} epochs".format(epochs))
        callbacks = []
    hist = model.fit(X_train, Y_train,
                     epochs=epochs,
                     batch_size=int(config['Model']['batch_size']),
                     validation_split=0.2,
                     callbacks=callbacks)
    make_log(experiment_dir, X_train, X_test, Y_test, model, hist, used_custom_model)
    plot_model(model, show_shapes=True, to_file=experiment_dir+"model_diagram.pdf")
    model_file = experiment_dir + "CNN_event.h5"
    model.save(model_file)
    print("Completed run. Wrote results out to ", experiment_dir)
# Entry point: plac builds the CLI (the -i/experiment_dir option) from
# main's annotations and dispatches to it.
if __name__ == '__main__':
    plac.call(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment