Skip to content

Instantly share code, notes, and snippets.

@faustomilletari
Created June 2, 2020 13:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save faustomilletari/21bdafe446794fda9cfc140aed904c65 to your computer and use it in GitHub Desktop.
Save faustomilletari/21bdafe446794fda9cfc140aed904c65 to your computer and use it in GitHub Desktop.
"""
Eisen EMIDEC Challenge starter kit
NOTE: you need to register to the challenge, download and unpack the data in
order to be able to run the following example.
Find more info here: http://emidec.com/
This is released under MIT license. Do what you want with this code.
"""
import os
from eisen.datasets import EMIDEC
from eisen.models.segmentation import VNet
from eisen.io import LoadNiftyFromFilename
from eisen.transforms import (
ResampleNiftiVolumes,
NiftiToNumpy,
CropCenteredSubVolumes,
AddChannelDimension,
MapValues,
FixedMeanStdNormalization,
LabelMapToOneHot,
StackImagesChannelwise,
FilterFields
)
from eisen.ops.losses import DiceLoss
from eisen.ops.metrics import DiceMetric
from eisen.utils import EisenModuleWrapper
from eisen.utils.workflows import Training
from eisen.utils.logging import LoggingHook
from eisen.utils.logging import TensorboardSummaryHook
from eisen.utils.artifacts import SaveTorchModelHook
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.optim import Adam
"""
Constants defining important parameters of the algorithm.
CHANGE HERE WHAT SHOULD BE CHANGED TO FIT YOUR CONFIG.
"""
# User-configurable constants
PATH_DATA = './emidec_data'  # path of data as unpacked from the challenge files
PATH_ARTIFACTS = './results'  # path for model results

# Make sure the artifacts directory is available before anything writes to it.
# makedirs(..., exist_ok=True) also creates missing parent directories and does
# not fail when the directory already exists, which avoids the race inherent in
# a separate exists() check followed by mkdir().
os.makedirs(PATH_ARTIFACTS, exist_ok=True)

NUM_EPOCHS = 100  # how many times the training workflow is run
BATCH_SIZE = 4  # batch size given to the DataLoader
VOLUMES_RESOLUTION = [4, 4, 1]  # resampling target; original emidec data has 1 cubic mm voxel spacing
VOLUMES_PIXEL_SIZE = [64, 64, 16]  # size of the centered crop applied to image and label
CLASSES = [1, 2]  # label values turned into one channel each
"""
Define Readers and Transforms
In order to load data and prepare it for use by the network, we need to perform
I/O operations and define transforms to standardize the data.
You can add transforms or change the existing ones by editing this section.
"""
# Reader: loads the content of the 'image' and 'label' fields, resolving
# filenames against PATH_DATA.
read_tform = LoadNiftyFromFilename(['image', 'label'], PATH_DATA)

# Resampling to VOLUMES_RESOLUTION: linear interpolation for the image,
# nearest interpolation for the label map.
resample_tform_img = ResampleNiftiVolumes(['image'], VOLUMES_RESOLUTION, 'linear')
resample_tform_lbl = ResampleNiftiVolumes(['label'], VOLUMES_RESOLUTION, 'nearest')

# Conversion of both fields to numpy.
to_numpy_tform = NiftiToNumpy(['image', 'label'])

# Centered crop of both fields to VOLUMES_PIXEL_SIZE.
crop = CropCenteredSubVolumes(['image', 'label'], size=VOLUMES_PIXEL_SIZE)

# Image intensities are mapped between 0 and 1.
map_intensities = MapValues(['image'], min_value=0.0, max_value=1.0)

# A singleton channel dimension is added to the image.
add_channel_dim = AddChannelDimension(['image'])

# Labels are mapped to one channel per class, according to CLASSES.
map_labels = LabelMapToOneHot(['label'], CLASSES)

# Every field other than 'image' and 'label' is dropped from the data dictionary.
preserve_only_fields = FilterFields(['image', 'label'])

# Full pipeline, applied in the order listed, used to load and manipulate data.
tform = Compose(
    [
        read_tform,
        resample_tform_img,
        resample_tform_lbl,
        to_numpy_tform,
        crop,
        map_intensities,
        add_channel_dim,
        map_labels,
        preserve_only_fields,
    ]
)
# Dataset built from the EMIDEC data found in PATH_DATA (training=True);
# the transform chain defined above is handed over via `transform`.
dataset = EMIDEC(PATH_DATA, training=True, transform=tform)

# A PyTorch DataLoader loops through the data provided by the dataset:
# shuffled batches of BATCH_SIZE, loaded by 4 workers.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
"""
Building blocks: we define here:
* model
* loss
* metric
* optimizer
These components are used during training.
These blocks will be joined together in a workflow (e.g. a training workflow).
"""
# Building blocks: model, loss, metric and optimizer.

# Model: a VNet with one input channel and one output channel per entry of
# CLASSES. The wrapper takes the network input from the batch field "image"
# and names the network output "predictions".
model = EisenModuleWrapper(
    module=VNet(input_channels=1, output_channels=len(CLASSES)),
    input_names=['image'],
    output_names=['predictions'],
)

# Loss (CHANGE TASK HERE if needed!!): Dice loss comparing "predictions"
# with "label"; its output is called "dice_loss".
loss = EisenModuleWrapper(
    module=DiceLoss(dim=[2, 3, 4]),
    input_names=['predictions', 'label'],
    output_names=['dice_loss'],
)

# Metric (CHANGE TASK HERE if needed!!): Dice metric over the same two
# fields; its output is called "dice_metric".
metric = EisenModuleWrapper(
    module=DiceMetric(dim=[2, 3, 4]),
    input_names=['predictions', 'label'],
    output_names=['dice_metric'],
)

# Adam over the wrapped model's parameters with a learning rate of 0.001.
optimizer = Adam(model.parameters(), lr=0.001)
# Training workflow: ties together the model, the data loader, the loss,
# the optimizer and the metric. gpu=False is passed to the workflow.
training_workflow = Training(
    model=model,
    data_loader=data_loader,
    losses=[loss],
    optimizer=optimizer,
    metrics=[metric],
    gpu=False,
)

# Hooks used to monitor training and save models. Each one is given the
# workflow id, the phase name 'Training' and the artifacts directory.
training_loggin_hook = LoggingHook(
    training_workflow.id, 'Training', PATH_ARTIFACTS
)
training_summary_hook = TensorboardSummaryHook(
    training_workflow.id, 'Training', PATH_ARTIFACTS
)
save_model_hook = SaveTorchModelHook(
    training_workflow.id, 'Training', PATH_ARTIFACTS
)
# Run optimization for NUM_EPOCHS: each call to run() executes the training
# workflow once.
#
# The loop is guarded so that it only starts when this file is executed as a
# script. The DataLoader above is configured with worker processes
# (num_workers=4); where multiprocessing uses the "spawn" start method, each
# worker re-imports the main module, and an unguarded loop would then try to
# start training again inside every worker.
if __name__ == '__main__':
    for _ in range(NUM_EPOCHS):
        training_workflow.run()

# TODO: VALIDATION and INFERENCE code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment