Skip to content

Instantly share code, notes, and snippets.

@cmclean
Created November 30, 2022 03:43
Show Gist options
  • Save cmclean/a7e01b916f07955b2693112dcd3edb60 to your computer and use it in GitHub Desktop.
Save cmclean/a7e01b916f07955b2693112dcd3edb60 to your computer and use it in GitHub Desktop.
Patch to DOI: 10.5281/zenodo.7154413 to support chronological age training
diff --git a/ml-based-vcdr/learning/configs/base.py b/ml-based-vcdr/learning/configs/base.py
index 904528e..eb8b8f4 100644
--- a/ml-based-vcdr/learning/configs/base.py
+++ b/ml-based-vcdr/learning/configs/base.py
@@ -92,6 +92,10 @@ def get_config() -> ml_collections.ConfigDict:
'staircase': True,
})
+ config.callbacks = ml_collections.ConfigDict({
+ 'checkpoint_monitor': 'val_loss',
+ })
+
config.outcomes = [
ml_collections.ConfigDict({
'name': 'vertical_cup_to_disc',
diff --git a/ml-based-vcdr/learning/configs/eye_age.py b/ml-based-vcdr/learning/configs/eye_age.py
new file mode 100644
index 0000000..a254e47
--- /dev/null
+++ b/ml-based-vcdr/learning/configs/eye_age.py
@@ -0,0 +1,110 @@
+# Copyright 2022 Google LLC.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Contains the training configuration for chronological age prediction."""
+import ml_collections
+
+
+def get_config() -> ml_collections.ConfigDict:
+ """Returns the default hyperparameter configuration."""
+ config = ml_collections.ConfigDict()
+
+ config.seed = None
+
+ # misc. training
+ config.train = ml_collections.ConfigDict({
+ 'use_mixed_precision': True,
+ 'use_distributed_training': False,
+ 'max_num_steps': 400000,
+ 'log_step_freq': 283,
+ 'fit_verbose': 1,
+ 'initial_epoch': 0,
+ })
+
+ # dataset and augmentation
+ config.dataset = ml_collections.ConfigDict({
+ 'train': '/mnt/disks/data/train/train.tfrecord*',
+ 'eval': '/mnt/disks/data/train/eval.tfrecord*',
+ 'test': '/mnt/disks/data/train/test.tfrecord*',
+ 'predict': '/mnt/disks/data/predict/predict.tfrecord*',
+ 'num_train_examples': 217289,
+ 'batch_size': 16,
+ 'image_size': (587, 587),
+ 'random_horizontal_flip': True,
+ 'random_vertical_flip': True,
+ 'random_brightness_max_delta': 0.1147528,
+ 'random_saturation_lower': 0.5597273,
+ 'random_saturation_upper': 1.2748845,
+ 'random_hue_max_delta': 0.0251488,
+ 'random_contrast_lower': 0.9996807,
+ 'random_contrast_upper': 1.7704824,
+ 'use_cache': False,
+ })
+
+ # model architecture
+ config.model = ml_collections.ConfigDict({
+ 'backbone': 'inceptionv3',
+ 'backbone_drop_rate': 0.8,
+ 'input_shape': (587, 587, 3),
+ 'weights': 'imagenet',
+ 'weight_decay': 0.00004,
+ })
+
+ # optimizer
+ config.opt = ml_collections.ConfigDict({
+ 'optimizer': 'adam',
+ 'learning_rate': 0.001,
+ 'beta_1': 0.9,
+ 'beta_2': 0.999,
+ 'epsilon': 0.1,
+ 'amsgrad': False,
+ 'use_model_averaging': True,
+ 'update_model_averaging_weights': False,
+ })
+
+ config.schedule = ml_collections.ConfigDict({
+ 'schedule': 'exponential',
+ 'epochs_per_decay': 48,
+ 'decay_rate': 0.99,
+ 'staircase': True,
+ })
+
+ config.callbacks = ml_collections.ConfigDict({
+ 'checkpoint_monitor': 'val_loss',
+ })
+
+ config.outcomes = [
+ ml_collections.ConfigDict({
+ 'name': 'age',
+ 'type': 'regression',
+ 'num_classes': 1,
+ 'loss': 'mae',
+ 'loss_weight': 1.0,
+ }),
+ ]
+
+ return config
+
diff --git a/ml-based-vcdr/learning/metrics.py b/ml-based-vcdr/learning/metrics.py
index 5ef4cf2..1e77717 100644
--- a/ml-based-vcdr/learning/metrics.py
+++ b/ml-based-vcdr/learning/metrics.py
@@ -248,4 +248,7 @@ def get_loss(config: ml_collections.ConfigDict) -> tf.losses.Loss:
if loss_name == 'mse':
return tf.keras.losses.MeanSquaredError()
+ if loss_name == 'mae':
+ return tf.keras.losses.MeanAbsoluteError()
+
raise ValueError(f'Unknown loss name: {loss_name}')
diff --git a/ml-based-vcdr/learning/predict_utils.py b/ml-based-vcdr/learning/predict_utils.py
index d40df74..590655a 100644
--- a/ml-based-vcdr/learning/predict_utils.py
+++ b/ml-based-vcdr/learning/predict_utils.py
@@ -24,8 +24,7 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-r"""Utilities for generating model predictions."""
-import collections
+"""Utilities for generating model predictions."""
from typing import Dict, List
import numpy as np
@@ -39,7 +38,7 @@ ID_KEY = 'id'
# `OUTCOME_COLUMN_MAP['glaucoma_gradability'][1]` corresponds to the second
# multi-class target of the model's 'glaucoma_gradability' head.
OUTCOME_COLUMN_MAP = {
- 'id': {
+ ID_KEY: {
0: 'image_id'
},
'glaucoma_gradability': {
@@ -60,7 +59,10 @@ OUTCOME_COLUMN_MAP = {
},
'vertical_cup_to_disc': {
0: 'vertical_cup_to_disc:VERTICAL_CUP_TO_DISC'
- }
+ },
+ 'age': {
+ 0: 'age:AGE'
+ },
}
@@ -77,21 +79,28 @@ def generate_predictions(
stateful_metrics=None,
unit_name='step')
- # Build the list of mode outputs.
+ # Iterate over the dataset, accumulating model predictions for each outcome.
+ # Note: "output_names" is a public attribute of `tf.keras.Model` and denotes a
+ # list of string names for model outputs. The order of "output_names"
+ # corresponds to the order of output tensors returned by `model()`.
output_names = model.output_names.copy()
- output_names.append(ID_KEY)
-
- # Predict outcomes for all examples and build a dictionary of output arrays.
- batched_predictions = collections.defaultdict(list)
- for (inputs_batch, _, _) in predict_ds:
- predict_batch = model.predict_on_batch(inputs_batch)
- predict_batch.append(inputs_batch[ID_KEY].numpy())
- for output_name, ndarray in zip(output_names, predict_batch):
- batched_predictions[output_name].append(ndarray)
+ is_single_headed = len(output_names) == 1
+ predict_dict = {name: [] for name in [ID_KEY] + output_names}
+ for batch_input, _, _ in predict_ds:
+ predict_dict[ID_KEY].append(batch_input[ID_KEY].numpy())
+ # Note: this check is required since a multi-headed TensorFlow model returns
+ # a list of output tensors while a single-headed model returns a single
+ # output tensor (rather than a list of size 1).
+ model_output = model(batch_input)
+ if is_single_headed:
+ predict_dict[output_names[0]].append(model_output.numpy())
+ else:
+ for name, output_tensor in zip(output_names, model_output):
+ predict_dict[name].append(output_tensor.numpy())
progbar.add(1)
- print()
- return batched_predictions
+ print()
+ return predict_dict
def merge_batched_predictions(
diff --git a/ml-based-vcdr/learning/train.py b/ml-based-vcdr/learning/train.py
index b2fcc23..3ca7ed5 100644
--- a/ml-based-vcdr/learning/train.py
+++ b/ml-based-vcdr/learning/train.py
@@ -76,7 +76,7 @@ def get_callbacks(
update_weights=update_weights,
save_best_only=True,
save_weights_only=True,
- monitor='val_vertical_cup_to_disc_loss',
+ monitor=config.callbacks.checkpoint_monitor,
mode='min',
save_freq='epoch'))
@@ -88,7 +88,7 @@ def get_callbacks(
filepath=checkpoint_file,
save_best_only=True,
save_weights_only=True,
- monitor='val_vertical_cup_to_disc_loss',
+ monitor=config.callbacks.checkpoint_monitor,
mode='min',
save_freq='epoch'))
@@ -98,7 +98,7 @@ def get_callbacks(
filepath=checkpoint_file,
save_best_only=True,
save_weights_only=True,
- monitor='val_vertical_cup_to_disc_loss',
+ monitor=config.callbacks.checkpoint_monitor,
mode='min',
save_freq='epoch'))
@cmclean
Copy link
Author

cmclean commented Mar 31, 2023

This gist is licensed under the BSD 3-clause license: https://opensource.org/license/bsd-3-clause/

It was applied to https://zenodo.org/record/7154413 to support publication of Ahadi et al, eLife, 2023: https://doi.org/10.7554/eLife.82364 .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment