# ABC challenge starter kit with Eisen
"""
Eisen ABC Challenge starter kit
NOTE: you need to register for the challenge, then download and unpack the data,
before you can run the following example.
Find more info here: https://abcs.mgh.harvard.edu
This is released under MIT license. Do what you want with this code.
"""
from eisen.datasets import ABCDataset
from eisen.models.segmentation import VNet
from eisen.io import LoadITKFromFilename
from eisen.transforms import (
ResampleITKVolumes,
ITKToNumpy,
CropCenteredSubVolumes,
AddChannelDimension,
MapValues,
FixedMeanStdNormalization,
LabelMapToOneHot,
StackImagesChannelwise,
FilterFields
)
from eisen.ops.losses import DiceLoss
from eisen.ops.metrics import DiceMetric
from eisen.utils import EisenModuleWrapper
from eisen.utils.workflows import Training
from eisen.utils.logging import LoggingHook
from eisen.utils.logging import TensorboardSummaryHook
from eisen.utils.artifacts import SaveTorchModelHook
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.optim import Adam
"""
Constants defining important parameters of the algorithm.
CHANGE HERE WHAT SHOULD BE CHANGED TO FIT YOUR CONFIG.
"""
# Defining some constants
PATH_DATA = './abc_data' # path of data as unpacked from the challenge files
PATH_ARTIFACTS = './results' # path for model results
TASK = 'task1'
NUM_EPOCHS = 100
BATCH_SIZE = 2
VOLUMES_RESOLUTION = [2, 2, 1.5]
VOLUMES_PIXEL_SIZE = [128, 128, 128]
if TASK == 'task1':
n_out_chan = 5
label_field = 'label_task1'
else:
n_out_chan = 10
label_field = 'label_task2'
"""
Define Readers and Transforms
In order to load data and prepare it for being used by the network, we need to operate
I/O operations and define transforms to standardize data.
You can add transforms or change the existing ones by editing this
"""
# readers: for images and labels
read_tform = LoadITKFromFilename(['ct', 't1', 't2', label_field], PATH_DATA)
# image manipulation transforms
resample_tform_img = ResampleITKVolumes(
['ct', 't1', 't2'],
VOLUMES_RESOLUTION,
'linear'
)
resample_tform_lbl = ResampleITKVolumes(
[label_field],
VOLUMES_RESOLUTION,
'nearest'
)
to_numpy_tform = ITKToNumpy(['ct', 't1', 't2', label_field])
crop = CropCenteredSubVolumes(fields=['ct', 't1', 't2', label_field], size=VOLUMES_PIXEL_SIZE)
map_intensities = MapValues(['t1', 't2'], min_value=0.0, max_value=1.0)
normalize_ct = FixedMeanStdNormalization(['ct'], mean=208.0, std=388.0)
if TASK == 'task1':
map_labels = LabelMapToOneHot([label_field], [1, 2, 3, 4, 5])
else:
map_labels = LabelMapToOneHot([label_field], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
stack_modalities = StackImagesChannelwise(['ct', 't1', 't2'], 'image')
preserve_only_fields = FilterFields(['image', 'label_task1'])
# create a transform to manipulate and load data
tform = Compose([
read_tform,
resample_tform_img,
resample_tform_lbl,
to_numpy_tform,
crop,
map_intensities,
normalize_ct,
map_labels,
stack_modalities,
preserve_only_fields
])
# Training split of the ABC dataset, with the `tform` chain passed as its transform.
dataset = ABCDataset(
    PATH_DATA,
    transform=tform,
    training=True,
    # NOTE(review): False presumably selects the nested layout produced by
    # unpacking the challenge archives -- confirm in the ABCDataset docs.
    flat_dir_structure=False,
)
# A plain PyTorch DataLoader yields shuffled mini-batches of BATCH_SIZE samples,
# using 4 worker processes for loading.
data_loader = DataLoader(
    dataset,
    num_workers=4,
    shuffle=True,
    batch_size=BATCH_SIZE,
)
"""
Building blocks: we define here:
* model
* loss
* metric
* optimizer
These components are used during training.
These blocks will be joined together in a workflow (Eg. training workflow).
"""
# specify model and loss (building blocks)
model = EisenModuleWrapper(
module=VNet(input_channels=3, output_channels=n_out_chan),
input_names=['image'],
output_names=['predictions']
)
# CHANGE TASK HERE if needed!!
loss = EisenModuleWrapper(
module=DiceLoss(dim=[2, 3, 4]),
input_names=['predictions', label_field],
output_names=['dice_loss']
)
# CHANGE TASK HERE if needed!!
metric = EisenModuleWrapper(
module=DiceMetric(dim=[2, 3, 4]),
input_names=['predictions', label_field],
output_names=['dice_metric']
)
optimizer = Adam(model.parameters(), 0.001)
# Assemble the data loader, model, optimizer, losses and metrics into a single
# training workflow; gpu=True requests that it run on the GPU.
training_workflow = Training(
    data_loader=data_loader,
    model=model,
    optimizer=optimizer,
    losses=[loss],
    metrics=[metric],
    gpu=True,
)
# create hooks to monitor training and save models.
# Each hook gets the workflow id, a phase name and the artifacts directory.
# NOTE(review): the hooks are never handed to the workflow directly; they appear
# to attach via training_workflow.id, so keep these references alive for the run.
training_logging_hook = LoggingHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
training_summary_hook = TensorboardSummaryHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
save_model_hook = SaveTorchModelHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
# run optimization for NUM_EPOCHS (presumably one pass over data_loader per run())
for _ in range(NUM_EPOCHS):
    training_workflow.run()
# TODO: VALIDATION and INFERENCE code