Skip to content

Instantly share code, notes, and snippets.

@faustomilletari
Created May 3, 2020 14:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save faustomilletari/107b89a0b46fb817ad61c3228152067b to your computer and use it in GitHub Desktop.
Save faustomilletari/107b89a0b46fb817ad61c3228152067b to your computer and use it in GitHub Desktop.
# You will need to download this data. Run the following commands in your terminal:
# wget "https://www.dropbox.com/s/duja4l7ce4ka720/tr_im.nii.gz"
# wget "https://www.dropbox.com/s/0uq78sh2iqzjofv/tr_mask.nii.gz"
# wget "https://www.dropbox.com/s/5amzci0b2sodhpn/val_im.nii.gz"
# IMPORTANT!!!
# You need to have eisen installed from git via pip install --upgrade git+https://github.com/eisen-ai/eisen-core.git
import matplotlib.pyplot as plt
import numpy as np
from eisen.datasets import MedSegCovid19
from eisen.models.segmentation import UNet
from eisen.transforms import (
    AddChannelDimension,
    LabelMapToOneHot,
    FixedMeanStdNormalization,
)
from eisen.ops.losses import DiceLoss
from eisen.ops.metrics import DiceMetric
from eisen.utils import EisenModuleWrapper
from eisen.utils.workflows import Training, Testing
from eisen.utils.logging.logs import LoggingHook
from eisen.utils.logging import TensorboardSummaryHook
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.optim import Adam
# --- Paths and run settings -------------------------------------------------
PATH_DATA = './'              # directory the dataset files are read from
PATH_ARTIFACTS = './results'  # directory handed to the logging/summary hooks
NUM_EPOCHS = 100              # number of times the training workflow is run
BATCH_SIZE = 2                # samples per batch for both data loaders

# --- Per-field transforms ---------------------------------------------------
# Each transform is told which entries of a data sample it acts on
# ('image' or 'label').
# NOTE(review): judging by the class names these normalize intensities with a
# fixed mean/std, add a channel axis, and one-hot encode labels 1, 2 and 3 --
# confirm against the Eisen transform documentation.
map_intensities = FixedMeanStdNormalization(['image'], mean=208.0, std=388.0)
add_channel = AddChannelDimension(['image'])
label_to_onehot = LabelMapToOneHot(['label'], classes=[1, 2, 3])

# create a transform to manipulate and load data
# (training chain: normalization, channel dimension, then label encoding)
tform = Compose([map_intensities, add_channel, label_to_onehot])
# Training dataset: image volume 'tr_im.nii.gz' with mask 'tr_mask.nii.gz',
# both located under PATH_DATA; `tform` is passed as the sample transform.
dset_train = MedSegCovid19(
    PATH_DATA,
    'tr_im.nii.gz',
    mask_file='tr_mask.nii.gz',
    transform=tform,
)

# Pure PyTorch data loader for training: shuffled batches of BATCH_SIZE
# samples, fetched by 4 worker processes.
data_loader_train = DataLoader(
    dataset=dset_train,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle=True,
)
# Model, loss and metric. Each torch module is wrapped in an
# EisenModuleWrapper that names its inputs and outputs.
# NOTE(review): presumably the workflow uses these names to route entries of
# the batch ('image', 'label') and intermediate results ('predictions')
# between modules -- confirm against the Eisen documentation.

# UNet taking a 1-channel image and producing 3 output channels.
model = EisenModuleWrapper(
    module=UNet(input_channels=1, output_channels=3),
    input_names=['image'],
    output_names=['predictions'],
)

# Dice loss on (predictions, label), published as 'dice_loss'.
loss = EisenModuleWrapper(
    module=DiceLoss(dim=[2, 3]),
    input_names=['predictions', 'label'],
    output_names=['dice_loss'],
)

# Dice metric on (predictions, label), published as 'dice_metric'.
metric = EisenModuleWrapper(
    module=DiceMetric(dim=[2, 3]),
    input_names=['predictions', 'label'],
    output_names=['dice_metric'],
)

# Adam over all model parameters with learning rate 0.001.
optimizer = Adam(params=model.parameters(), lr=0.001)
# Assemble the training workflow from the pieces defined above
# (model, loss, metric, optimizer and training data loader); gpu=True
# requests GPU execution.
training = Training(
    data_loader=data_loader_train,
    model=model,
    losses=[loss],
    metrics=[metric],
    optimizer=optimizer,
    gpu=True,
)

# Hooks are tied to this workflow through its id and write to PATH_ARTIFACTS.
# The first shows training logs, the second saves Tensorboard summaries.
train_loggin_hook = LoggingHook(training.id, 'Training', PATH_ARTIFACTS)
train_summary_hook = TensorboardSummaryHook(training.id, 'Training', PATH_ARTIFACTS)

# Run the optimization NUM_EPOCHS times (one workflow run per epoch).
for _ in range(NUM_EPOCHS):
    training.run()
# INFERENCE -------
# Inference-time transform chain: the same image transforms used for training,
# without the label step because the validation data is loaded without a mask.
tform = Compose([
    map_intensities,
    add_channel,
])

# create a dataset from the test/val set (mask_file=None: no labels are loaded)
dset_test = MedSegCovid19(PATH_DATA, 'val_im.nii.gz', mask_file=None, transform=tform)

# create data loader for testing, this functionality is pure pytorch;
# shuffle=False keeps the samples in dataset order
data_loader_test = DataLoader(
    dset_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

# testing workflow reusing the trained model (moved to the GPU), no metrics
testing = Testing(
    model=model.cuda(),
    data_loader=data_loader_test,
    metrics=[],
    gpu=True,
)

# saves summaries for the testing workflow
testing_summary_hook = TensorboardSummaryHook(testing.id, 'Testing', PATH_ARTIFACTS)

# Run the model batch by batch and display every prediction over its input.
for entry in data_loader_test:
    # the workflow is called directly on the batch, so move the image to the GPU first
    entry['image'] = entry['image'].cuda()

    output, _, _ = testing(entry)

    # detach() drops any autograd history before converting to numpy
    predictions = output['predictions'].detach().cpu().numpy()
    inputs = entry['image'].detach().cpu().numpy()

    # Iterate over the samples actually present in this batch: the last batch
    # can hold fewer than BATCH_SIZE samples, so indexing with
    # range(BATCH_SIZE) could run past the end of the arrays.
    # NOTE(review): assumes both arrays are laid out (batch, channel, H, W) -- confirm.
    for sample_image, sample_prediction in zip(inputs, predictions):
        # first channel of the input, with a trailing axis so that it
        # broadcasts against the channel-last prediction
        image = sample_image[0][..., np.newaxis]

        # min-max normalize image; a constant image has zero range and would
        # otherwise produce a division by zero (NaNs), so map it to zeros
        lowest = np.min(image)
        value_range = np.max(image) - lowest
        if value_range > 0:
            image = (image - lowest) / value_range
        else:
            image = np.zeros_like(image)

        # channel-first -> channel-last, the layout imshow expects
        prediction = sample_prediction.transpose((1, 2, 0))

        plt.imshow(prediction + image)
        plt.show()
@Zhiwei-Zhai
Copy link

Hi, an error is reported when running this code, at model = EisenModuleWrapper(module=UNet(input_channels=1, output_channels=3), input_names=['image'], output_names=['predictions'])

Traceback (most recent call last):
  File "/exports/radiologie-hpc/zzhai/program/pycharm-2019.3.3/plugins/python/helpers/pydev/pydevd.py", line 1434, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/exports/radiologie-hpc/zzhai/program/pycharm-2019.3.3/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/exports/lkeb-hpc/zzhai/project/work/COVID19_project/Eisen-covid19-MedSeg.py", line 59, in <module>
    model = EisenModuleWrapper(module=UNet(input_channels=1, output_channels=3), input_names=['image'], output_names=['predictions']) 
  File "/home/zzhai/miniconda3/envs/p3monai/lib/python3.6/site-packages/eisen/utils/__init__.py", line 96, in __init__
    self.module = module(*args, **kwargs)
  File "/home/zzhai/miniconda3/envs/p3monai/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'x'

@faustomilletari
Copy link
Author

Hello, are you running the latest version of Eisen (the one from the Git repository)?

if not, please try installing...
pip install --upgrade git+https://github.com/eisen-ai/eisen-core.git

@Zhiwei-Zhai
Copy link

Hi Faustomilletari,
Thanks for your reply. It solved my problem. I had tried pip install --upgrade git+https://github.com/eisen-ai/eisen-core.git, but three errors were reported. I then installed 'eisen-core==0.0.5, torch==1.4.0, torchvision==0.5.0', which might have caused the previous error.

ERROR: eisen 0.1.6 has requirement eisen-core==0.0.5, but you'll have eisen-core 0.0.6 which is incompatible.
ERROR: eisen-cli 0.0.5 has requirement torch==1.4.0, but you'll have torch 1.5.0 which is incompatible.
ERROR: eisen-cli 0.0.5 has requirement torchvision==0.5.0, but you'll have torchvision 0.6.0 which is incompatible.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment