Created
June 9, 2020 08:14
-
-
Save faustomilletari/1c1d9d671641e36e63199d26bb232d58 to your computer and use it in GitHub Desktop.
Covid19Challenge.eu Eisen starter kit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Eisen EU COVID-19 challenge starter kit
NOTE: you need to register for the challenge, then download and unpack the data,
in order to be able to run the following example.
Find more info here: https://www.covid19challenge.eu
Information about Eisen can be found at http://eisen.ai -- join the community on Slack: https://bit.ly/2L7i6OL
This is released under the MIT license. Do what you want with this code.
"""
import os

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from eisen.datasets import JsonDataset
from eisen.io import LoadNiftiFromFilename
from eisen.models.segmentation import VNet
from eisen.ops.losses import DiceLoss
from eisen.ops.metrics import DiceMetric
from eisen.transforms import (
    ResampleNiftiVolumes,
    NiftiToNumpy,
    CropCenteredSubVolumes,
    FixedMeanStdNormalization,
    AddChannelDimension,
    MapValues,
    FilterFields,
    LabelMapToOneHot
)
from eisen.utils import EisenModuleWrapper
from eisen.utils.artifacts import SaveTorchModelHook
from eisen.utils.logging import LoggingHook
from eisen.utils.logging import TensorboardSummaryHook
from eisen.utils.workflows import Training
# ### SEGMENTATION TASK ###############################################################################################
#
# This code is meant to provide an example of how to train a DL network on https://www.covid19challenge.eu data.
#
# We do not guarantee that the results of this particular neural network will be optimal compared to the state of
# the art; on the other hand, the code included here is highly flexible and allows for several user-made changes.
# Feel free to experiment with this code on your local hardware or on Colab.
#
# It is also possible to use this dataset on our computational platform hosted online (TBA). In this case, computation
# will happen using the resources of covid19challenge.eu. Check out the website to learn how to experiment with
# our data using the remote resources we provisioned for the challenge participants.
#
""" | |
Constants defining important parameters of the algorithm.
CHANGE HERE WHAT SHOULD BE CHANGED TO FIT YOUR EXPERIMENT NEEDS.
>>> IMPORTANT!!
This code saves Tensorboard summaries and model snapshots, and prints its output to the console.
You can watch the progress of your training job by pointing a tensorboard process to the output folder (PATH_ARTIFACTS).
"""
# Defining some constants
PATH_DATA = './'  # path of data as unpacked from the challenge files
PATH_ARTIFACTS = './results'  # output folder for logs, Tensorboard summaries and model snapshots
os.makedirs(PATH_ARTIFACTS, exist_ok=True)  # make sure the output folder exists before the hooks write to it

# Use the GPU only when CUDA is actually available, so the script also runs (slowly) on CPU-only
# machines instead of failing when CUDA is requested but missing.
# Set this to False explicitly to force CPU training.
USE_GPU = torch.cuda.is_available()
TRAINING = True  # NOTE: not read anywhere in this script; kept for experiments that extend it
NUM_EPOCHS = 100  # number of times the training workflow is run
BATCH_SIZE = 2
VOLUMES_RESOLUTION = [4, 4, 2]  # target spacing per axis used when resampling (physical units, presumably mm)
VOLUMES_PIXEL_SIZE = [128, 128, 128]  # size, in voxels per axis, of the centered crop fed to the network
LABELS = [1, 2, 3, 4, 5, 6, 7]  # label values that are one-hot encoded as segmentation targets
INPUT_CHANNELS = 1  # CT Data
OUTPUT_CHANNELS = len(LABELS)  # one output channel per label; a different label set can be achieved by transforming the labels
""" | |
Define Readers and Transforms
In order to load data and prepare it for use by the network, we need to perform
I/O operations and define transforms that standardize the data.
You can add transforms or change the existing ones by editing the transform chain below.
"""
# readers: for images and labels
read_tform = LoadNiftiFromFilename(['image', 'label'], PATH_DATA)

# Image manipulation transforms. Here we declare the components of the transform chain.
# We want to resample images to a common resolution so that they are all comparable and each pixel has
# the same physical meaning in terms of millimeters
resample_tform_img = ResampleNiftiVolumes(
    ['image'],
    VOLUMES_RESOLUTION,
    'linear'
)

# the labels are interpolated with 'nearest' because they are discrete
# and we should not create spurious intermediate label values
resample_tform_lbl = ResampleNiftiVolumes(
    ['label'],
    VOLUMES_RESOLUTION,
    'nearest'
)

# We bring the data from Nifti to numpy so we can work further
to_numpy_tform = NiftiToNumpy(['image', 'label'])

# Cropping the resampled images (and labels) so that they all have the same size in voxels
crop = CropCenteredSubVolumes(fields=['image', 'label'], size=VOLUMES_PIXEL_SIZE)

# normalization of intensities. There may be more than one valid choice of method here;
# an alternative normalization transform is left commented out below
# normalize_ct = FixedMeanStdNormalization(['image'], mean=208.0, std=388.0)
normalize_ct = MapValues(['image'], min_value=0.0, max_value=1.0)

# add a channel dimension to the data so that the image is 4-D with a single channel (required by the network)
add_channel = AddChannelDimension(['image'])

# labels are integers, but they are mapped to a one-hot encoding (one entry per value in LABELS) for learning
map_labels = LabelMapToOneHot(['label'], LABELS)

# the previous transforms may have added other fields; we keep only 'image' and 'label'
# because in this case they are the only things we need to train
preserve_only_fields = FilterFields(['image', 'label'])

# create a transform to load and manipulate data. Order matters: volumes are read and resampled
# as Nifti, converted to numpy, and only then cropped, normalized and encoded.
tform = Compose([
    read_tform,
    resample_tform_img,
    resample_tform_lbl,
    to_numpy_tform,
    crop,
    normalize_ct,
    add_channel,
    map_labels,
    preserve_only_fields
])

# WARNING: you can split the json content and create 2 files, one for training and the other for validation.
# There are also other ways to accomplish the same using PyTorch samplers.
# In this example we skip validation and we create a dataset from the training set only.
training_dataset = JsonDataset(
    PATH_DATA,
    json_file='dataset.json',
    transform=tform
)

# Number of background worker processes that load and transform volumes in parallel.
# Lower it on machines with few CPU cores or little RAM; 0 loads data in the main process.
NUM_WORKERS = 4

# Data loader: a pytorch DataLoader is used here to loop through the data as provided by the dataset.
data_loader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # page-locked host memory speeds up host-to-GPU copies; it is useless without a GPU
    pin_memory=USE_GPU,
)
""" | |
Building blocks: here we define the
* model
* loss
* metric
* optimizer
These components are used during training.
These blocks will be joined together in a workflow (e.g. a training workflow).
"""
def main():
    """Assemble model, loss, metric and optimizer into a training workflow and run it.

    Runs the workflow ``NUM_EPOCHS`` times over ``data_loader``. Logs, Tensorboard summaries
    and model snapshots are written to ``PATH_ARTIFACTS`` by the hooks created below.

    This function is only invoked under the ``__main__`` guard at the bottom of the file.
    ``data_loader`` uses worker processes; with the ``spawn`` start method (the default on
    Windows and macOS) every worker re-imports this module. Training code placed at module
    top level would then be re-executed inside each worker, which fails while the worker is
    still bootstrapping.
    """
    # specify model and loss (building blocks)
    model = EisenModuleWrapper(
        module=VNet(input_channels=INPUT_CHANNELS, output_channels=OUTPUT_CHANNELS),
        input_names=['image'],
        output_names=['predictions']
    )
    # The wrappers connect named fields: the model's 'predictions' are compared against 'label'.
    # dim=[2, 3, 4] presumably selects the spatial axes of (batch, channel, x, y, z) tensors.
    loss = EisenModuleWrapper(
        module=DiceLoss(dim=[2, 3, 4]),
        input_names=['predictions', 'label'],
        output_names=['dice_loss']
    )
    metric = EisenModuleWrapper(
        module=DiceMetric(dim=[2, 3, 4]),
        input_names=['predictions', 'label'],
        output_names=['dice_metric']
    )
    optimizer = Adam(model.parameters(), 0.001)  # learning rate 1e-3

    # join all blocks into a workflow (training workflow)
    training_workflow = Training(
        model=model,
        losses=[loss],
        data_loader=data_loader,
        optimizer=optimizer,
        metrics=[metric],
        gpu=USE_GPU
    )

    # The hooks are never passed to the workflow; they are tied to it only through training_workflow.id.
    # NOTE(review): presumably they subscribe to workflow events when constructed, so references are
    # kept alive for the whole run even though the names are not used afterwards -- confirm with Eisen docs.
    # create Hook to monitor training
    training_logging_hook = LoggingHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
    # create Hook to automatically populate a tensorboard and get summaries
    training_summary_hook = TensorboardSummaryHook(
        training_workflow.id, 'Training', PATH_ARTIFACTS, show_all_axes=True
    )
    # create Hook to automatically save the model
    save_model_hook = SaveTorchModelHook(training_workflow.id, 'Training', PATH_ARTIFACTS)

    # run optimization for NUM_EPOCHS
    for _ in range(NUM_EPOCHS):
        training_workflow.run()

    # TODO: VALIDATION and INFERENCE code (remember to split the data into train/validation set)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment