ABC challenge starter kit with Eisen
""" | |
Eisen ABC Challenge starter kit | |
NOTE: you need to register to the challenge, download and unpack the data in | |
order to be able to run the following example. | |
Find more info here: https://abcs.mgh.harvard.edu | |
This is released under MIT license. Do what you want with this code. | |
""" | |
from eisen.datasets import ABCDataset
from eisen.models.segmentation import VNet
from eisen.io import LoadITKFromFilename
from eisen.transforms import (
    ResampleITKVolumes,
    ITKToNumpy,
    CropCenteredSubVolumes,
    AddChannelDimension,
    MapValues,
    FixedMeanStdNormalization,
    LabelMapToOneHot,
    StackImagesChannelwise,
    FilterFields
)
from eisen.ops.losses import DiceLoss
from eisen.ops.metrics import DiceMetric
from eisen.utils import EisenModuleWrapper
from eisen.utils.workflows import Training
from eisen.utils.logging import LoggingHook
from eisen.utils.logging import TensorboardSummaryHook
from eisen.utils.artifacts import SaveTorchModelHook

from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.optim import Adam
""" | |
Constants defining important parameters of the algorithm. | |
CHANGE HERE WHAT SHOULD BE CHANGED TO FIT YOUR CONFIG. | |
""" | |
# Defining some constants
PATH_DATA = './abc_data'  # path of data as unpacked from the challenge files
PATH_ARTIFACTS = './results'  # path for model results

TASK = 'task1'
NUM_EPOCHS = 100
BATCH_SIZE = 2
VOLUMES_RESOLUTION = [2, 2, 1.5]
VOLUMES_PIXEL_SIZE = [128, 128, 128]
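# The hooks created further below are given PATH_ARTIFACTS as their output location;
# if that folder does not exist yet, creating it up front avoids write errors later on
# (plain standard-library call, not part of the Eisen API).
import os

os.makedirs(PATH_ARTIFACTS, exist_ok=True)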
if TASK == 'task1':
    n_out_chan = 5
    label_field = 'label_task1'
else:
    n_out_chan = 10
    label_field = 'label_task2'
""" | |
Define Readers and Transforms | |
In order to load data and prepare it for being used by the network, we need to operate | |
I/O operations and define transforms to standardize data. | |
You can add transforms or change the existing ones by editing this | |
""" | |
# readers: for images and labels
read_tform = LoadITKFromFilename(['ct', 't1', 't2', label_field], PATH_DATA)

# image manipulation transforms
resample_tform_img = ResampleITKVolumes(
    ['ct', 't1', 't2'],
    VOLUMES_RESOLUTION,
    'linear'
)

resample_tform_lbl = ResampleITKVolumes(
    [label_field],
    VOLUMES_RESOLUTION,
    'nearest'
)

to_numpy_tform = ITKToNumpy(['ct', 't1', 't2', label_field])

crop = CropCenteredSubVolumes(fields=['ct', 't1', 't2', label_field], size=VOLUMES_PIXEL_SIZE)

map_intensities = MapValues(['t1', 't2'], min_value=0.0, max_value=1.0)

normalize_ct = FixedMeanStdNormalization(['ct'], mean=208.0, std=388.0)
if TASK == 'task1':
    map_labels = LabelMapToOneHot([label_field], [1, 2, 3, 4, 5])
else:
    map_labels = LabelMapToOneHot([label_field], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

stack_modalities = StackImagesChannelwise(['ct', 't1', 't2'], 'image')
preserve_only_fields = FilterFields(['image', label_field])  # keep only the fields needed downstream
# create a transform to manipulate and load data
tform = Compose([
    read_tform,
    resample_tform_img,
    resample_tform_lbl,
    to_numpy_tform,
    crop,
    map_intensities,
    normalize_ct,
    map_labels,
    stack_modalities,
    preserve_only_fields
])
# create a dataset from the training set of the ABC dataset
dataset = ABCDataset(
    PATH_DATA,
    training=True,
    flat_dir_structure=False,  # check documentation
    transform=tform
)
# data loader: a PyTorch DataLoader is used here to loop through the data provided by the dataset
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)
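# Optional sanity check: flip the flag below to pull a single batch and verify that
# the pipeline produces tensors of the expected shape before launching a full run.
# The 'image' and label fields are the ones assembled by the transforms above; the
# exact shapes depend on VOLUMES_PIXEL_SIZE and the selected task.
DEBUG_CHECK_SHAPES = False
if DEBUG_CHECK_SHAPES:
    batch = next(iter(data_loader))
    print('image batch shape:', batch['image'].shape)      # e.g. [BATCH_SIZE, 3, 128, 128, 128]
    print('label batch shape:', batch[label_field].shape)  # e.g. [BATCH_SIZE, n_out_chan, 128, 128, 128]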
""" | |
Building blocks: we define here: | |
* model | |
* loss | |
* metric | |
* optimizer | |
These components are used during training. | |
These blocks will be joined together in a workflow (Eg. training workflow). | |
""" | |
# specify model and loss (building blocks)
model = EisenModuleWrapper(
    module=VNet(input_channels=3, output_channels=n_out_chan),
    input_names=['image'],
    output_names=['predictions']
)
# the loss compares the predictions with the one-hot label of the selected task
loss = EisenModuleWrapper(
    module=DiceLoss(dim=[2, 3, 4]),
    input_names=['predictions', label_field],
    output_names=['dice_loss']
)

# the metric is computed on the same fields as the loss
metric = EisenModuleWrapper(
    module=DiceMetric(dim=[2, 3, 4]),
    input_names=['predictions', label_field],
    output_names=['dice_metric']
)
optimizer = Adam(model.parameters(), lr=0.001)
# join all blocks into a workflow (training workflow)
training_workflow = Training(
    model=model,
    losses=[loss],
    data_loader=data_loader,
    optimizer=optimizer,
    metrics=[metric],
    gpu=True
)
# create hooks to monitor training and save models
training_logging_hook = LoggingHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
training_summary_hook = TensorboardSummaryHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
save_model_hook = SaveTorchModelHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
# run optimization for NUM_EPOCHS
for i in range(NUM_EPOCHS):
    training_workflow.run()

# TODO: VALIDATION and INFERENCE code
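# A minimal sketch of how validation could be wired up, assuming Eisen's Validation
# workflow (eisen.utils.workflows.Validation) accepts roughly the same arguments as
# Training, minus the optimizer. One simple option is to hold out part of the training
# set with torch's random_split. Treat this as a starting point rather than a finished
# recipe; it would replace the training loop above, and the Training workflow would
# then be built with train_loader instead of data_loader.
#
# from torch.utils.data import random_split
# from eisen.utils.workflows import Validation
#
# num_val = len(dataset) // 5  # hold out roughly 20% of the training cases
# train_subset, val_subset = random_split(dataset, [len(dataset) - num_val, num_val])
#
# train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
#
# validation_workflow = Validation(
#     model=model,
#     losses=[loss],
#     data_loader=val_loader,
#     metrics=[metric],
#     gpu=True
# )
#
# validation_logging_hook = LoggingHook(validation_workflow.id, 'Validation', PATH_ARTIFACTS)
#
# for i in range(NUM_EPOCHS):
#     training_workflow.run()
#     validation_workflow.run()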