Created
November 30, 2022 03:43
-
-
Save cmclean/a7e01b916f07955b2693112dcd3edb60 to your computer and use it in GitHub Desktop.
Patch to DOI: 10.5281/zenodo.7154413 to support chronological age training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/ml-based-vcdr/learning/configs/base.py b/ml-based-vcdr/learning/configs/base.py | |
index 904528e..eb8b8f4 100644 | |
--- a/ml-based-vcdr/learning/configs/base.py | |
+++ b/ml-based-vcdr/learning/configs/base.py | |
@@ -92,6 +92,10 @@ def get_config() -> ml_collections.ConfigDict: | |
'staircase': True, | |
}) | |
+ config.callbacks = ml_collections.ConfigDict({ | |
+ 'checkpoint_monitor': 'val_loss', | |
+ }) | |
+ | |
config.outcomes = [ | |
ml_collections.ConfigDict({ | |
'name': 'vertical_cup_to_disc', | |
diff --git a/ml-based-vcdr/learning/configs/eye_age.py b/ml-based-vcdr/learning/configs/eye_age.py | |
new file mode 100644 | |
index 0000000..a254e47 | |
--- /dev/null | |
+++ b/ml-based-vcdr/learning/configs/eye_age.py | |
@@ -0,0 +1,110 @@ | |
+# Copyright 2022 Google LLC. | |
+# | |
+# Redistribution and use in source and binary forms, with or without | |
+# modification, are permitted provided that the following conditions are met: | |
+# | |
+# 1. Redistributions of source code must retain the above copyright notice, this | |
+# list of conditions and the following disclaimer. | |
+# | |
+# 2. Redistributions in binary form must reproduce the above copyright notice, | |
+# this list of conditions and the following disclaimer in the documentation | |
+# and/or other materials provided with the distribution. | |
+# | |
+# 3. Neither the name of the copyright holder nor the names of its contributors | |
+# may be used to endorse or promote products derived from this software | |
+# without specific prior written permission. | |
+# | |
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
+"""Contains the default training configuration.""" | |
+import ml_collections | |
+ | |
+ | |
+def get_config() -> ml_collections.ConfigDict: | |
+ """Returns the default hyperparameter configuration.""" | |
+ config = ml_collections.ConfigDict() | |
+ | |
+ config.seed = None | |
+ | |
+ # misc. training | |
+ config.train = ml_collections.ConfigDict({ | |
+ 'use_mixed_precision': True, | |
+ 'use_distributed_training': False, | |
+ 'max_num_steps': 400000, | |
+ 'log_step_freq': 283, | |
+ 'fit_verbose': 1, | |
+ 'initial_epoch': 0, | |
+ }) | |
+ | |
+ # dataset and augmentation | |
+ config.dataset = ml_collections.ConfigDict({ | |
+ 'train': '/mnt/disks/data/train/train.tfrecord*', | |
+ 'eval': '/mnt/disks/data/train/eval.tfrecord*', | |
+ 'test': '/mnt/disks/data/train/test.tfrecord*', | |
+ 'predict': '/mnt/disks/data/predict/predict.tfrecord*', | |
+ 'num_train_examples': 217289, | |
+ 'batch_size': 16, | |
+ 'image_size': (587, 587), | |
+ 'random_horizontal_flip': True, | |
+ 'random_vertical_flip': True, | |
+ 'random_brightness_max_delta': 0.1147528, | |
+ 'random_saturation_lower': 0.5597273, | |
+ 'random_saturation_upper': 1.2748845, | |
+ 'random_hue_max_delta': 0.0251488, | |
+ 'random_contrast_lower': 0.9996807, | |
+ 'random_contrast_upper': 1.7704824, | |
+ 'use_cache': False, | |
+ }) | |
+ | |
+ # model architecture | |
+ config.model = ml_collections.ConfigDict({ | |
+ 'backbone': 'inceptionv3', | |
+ 'backbone_drop_rate': 0.8, | |
+ 'input_shape': (587, 587, 3), | |
+ 'weights': 'imagenet', | |
+ 'weight_decay': 0.00004, | |
+ }) | |
+ | |
+ # optimizer | |
+ config.opt = ml_collections.ConfigDict({ | |
+ 'optimizer': 'adam', | |
+ 'learning_rate': 0.001, | |
+ 'beta_1': 0.9, | |
+ 'beta_2': 0.999, | |
+ 'epsilon': 0.1, | |
+ 'amsgrad': False, | |
+ 'use_model_averaging': True, | |
+ 'update_model_averaging_weights': False, | |
+ }) | |
+ | |
+ config.schedule = ml_collections.ConfigDict({ | |
+ 'schedule': 'exponential', | |
+ 'epochs_per_decay': 48, | |
+ 'decay_rate': 0.99, | |
+ 'staircase': True, | |
+ }) | |
+ | |
+ config.callbacks = ml_collections.ConfigDict({ | |
+ 'checkpoint_monitor': 'val_loss', | |
+ }) | |
+ | |
+ config.outcomes = [ | |
+ ml_collections.ConfigDict({ | |
+ 'name': 'age', | |
+ 'type': 'regression', | |
+ 'num_classes': 1, | |
+ 'loss': 'mae', | |
+ 'loss_weight': 1.0, | |
+ }), | |
+ ] | |
+ | |
+ return config | |
+ | |
diff --git a/ml-based-vcdr/learning/metrics.py b/ml-based-vcdr/learning/metrics.py | |
index 5ef4cf2..1e77717 100644 | |
--- a/ml-based-vcdr/learning/metrics.py | |
+++ b/ml-based-vcdr/learning/metrics.py | |
@@ -248,4 +248,7 @@ def get_loss(config: ml_collections.ConfigDict) -> tf.losses.Loss: | |
if loss_name == 'mse': | |
return tf.keras.losses.MeanSquaredError() | |
+ if loss_name == 'mae': | |
+ return tf.keras.losses.MeanAbsoluteError() | |
+ | |
raise ValueError(f'Unknown loss name: {loss_name}') | |
diff --git a/ml-based-vcdr/learning/predict_utils.py b/ml-based-vcdr/learning/predict_utils.py | |
index d40df74..590655a 100644 | |
--- a/ml-based-vcdr/learning/predict_utils.py | |
+++ b/ml-based-vcdr/learning/predict_utils.py | |
@@ -24,8 +24,7 @@ | |
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
-r"""Utilities for generating model predictions.""" | |
-import collections | |
+"""Utilities for generating model predictions.""" | |
from typing import Dict, List | |
import numpy as np | |
@@ -39,7 +38,7 @@ ID_KEY = 'id' | |
# `OUTCOME_COLUMN_MAP['glaucoma_gradability'][1]` corresponds to the second | |
# multi-class target of the model's 'glaucoma_gradability' head. | |
OUTCOME_COLUMN_MAP = { | |
- 'id': { | |
+ ID_KEY: { | |
0: 'image_id' | |
}, | |
'glaucoma_gradability': { | |
@@ -60,7 +59,10 @@ OUTCOME_COLUMN_MAP = { | |
}, | |
'vertical_cup_to_disc': { | |
0: 'vertical_cup_to_disc:VERTICAL_CUP_TO_DISC' | |
- } | |
+ }, | |
+ 'age': { | |
+ 0: 'age:AGE' | |
+ }, | |
} | |
@@ -77,21 +79,28 @@ def generate_predictions( | |
stateful_metrics=None, | |
unit_name='step') | |
- # Build the list of mode outputs. | |
+ # Iterate over the dataset, accumulating model predictions for each outcome. | |
+ # Note: "output_names" is a public attribute of `tf.keras.Model` and denotes a | |
+ # list of string names for model outputs. The order of "output_names" | |
+ # corresponds to the order of output tensors returned by `model()`. | |
output_names = model.output_names.copy() | |
- output_names.append(ID_KEY) | |
- | |
- # Predict outcomes for all examples and build a dictionary of output arrays. | |
- batched_predictions = collections.defaultdict(list) | |
- for (inputs_batch, _, _) in predict_ds: | |
- predict_batch = model.predict_on_batch(inputs_batch) | |
- predict_batch.append(inputs_batch[ID_KEY].numpy()) | |
- for output_name, ndarray in zip(output_names, predict_batch): | |
- batched_predictions[output_name].append(ndarray) | |
+ is_single_headed = len(output_names) == 1 | |
+ predict_dict = {name: [] for name in [ID_KEY] + output_names} | |
+ for batch_input, _, _ in predict_ds: | |
+ predict_dict[ID_KEY].append(batch_input[ID_KEY].numpy()) | |
+ # Note: this check is required since a multi-headed TensorFlow model returns | |
+ # a list of output tensors while a single-headed model returns a single | |
+ # output tensor (rather than a list of size 1). | |
+ model_output = model(batch_input) | |
+ if is_single_headed: | |
+ predict_dict[output_names[0]].append(model_output.numpy()) | |
+ else: | |
+ for name, output_tensor in zip(output_names, model_output): | |
+ predict_dict[name].append(output_tensor.numpy()) | |
progbar.add(1) | |
- print() | |
- return batched_predictions | |
+ print() | |
+ return predict_dict | |
def merge_batched_predictions( | |
diff --git a/ml-based-vcdr/learning/train.py b/ml-based-vcdr/learning/train.py | |
index b2fcc23..3ca7ed5 100644 | |
--- a/ml-based-vcdr/learning/train.py | |
+++ b/ml-based-vcdr/learning/train.py | |
@@ -76,7 +76,7 @@ def get_callbacks( | |
update_weights=update_weights, | |
save_best_only=True, | |
save_weights_only=True, | |
- monitor='val_vertical_cup_to_disc_loss', | |
+ monitor=config.callbacks.checkpoint_monitor, | |
mode='min', | |
save_freq='epoch')) | |
@@ -88,7 +88,7 @@ def get_callbacks( | |
filepath=checkpoint_file, | |
save_best_only=True, | |
save_weights_only=True, | |
- monitor='val_vertical_cup_to_disc_loss', | |
+ monitor=config.callbacks.checkpoint_monitor, | |
mode='min', | |
save_freq='epoch')) | |
@@ -98,7 +98,7 @@ def get_callbacks( | |
filepath=checkpoint_file, | |
save_best_only=True, | |
save_weights_only=True, | |
- monitor='val_vertical_cup_to_disc_loss', | |
+ monitor=config.callbacks.checkpoint_monitor, | |
mode='min', | |
save_freq='epoch')) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This gist is licensed under the BSD 3-clause license: https://opensource.org/license/bsd-3-clause/
It was applied to https://zenodo.org/record/7154413 to support publication of Ahadi et al, eLife, 2023: https://doi.org/10.7554/eLife.82364 .