Neural network (3D Tiramisu) for FLAIR-based T2-lesion segmentation
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
3D Tiramisu network for FLAIR-based T2-lesion segmentation
This code is unfortunately a huge mess. But, given CSV files with the
appropriate setup, you can run the command below (starting with
"python -u ...") to train the network used to generate the
segmentation results in the paper:
"A Structural Causal Model for MR Images of Multiple Sclerosis"
https://arxiv.org/abs/2103.03158
This script requires:
1. msseg (https://github.com/jcreinhold/msseg)
2. lesionqc (https://github.com/jcreinhold/lesionqc)
and the various other packages listed in the imports.
The CSV file is set up with "flair" and "t1" as headers, and each row
contains the full paths to the NIfTI files for the corresponding FLAIR
and T1-w images. We used the ISBI 2015 and MICCAI 2016 challenge data,
as well as some private labeled data, to train the network.
python -u tiramisu3d_only_flair.py \
--train-csv csv/all/train_weighted.csv \
--valid-csv csv/all/valid_weighted.csv \
-ic 1 \
--use-multitask \
--use-mixup \
--use-aug \
-db 4 4 4 4 \
-ub 4 4 4 4 \
-gr 12 \
-bnl 4 \
-lr 0.0002 \
-bt 0.8 0.99 \
-bs 7 \
-ps 64 64 64 \
-vbs 7 \
-vps 96 96 96 \
-da 25 \
-ne 100 \
-dr 0.1 \
-wd 0.000001 \
-ma 0.4 \
-mm 0.8 \
-spv 10 \
-v
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
Created on: Feb 28, 2021
"""
from typing import *
from argparse import ArgumentParser
import contextlib
from functools import partial
import logging
import os
from os.path import join
import sys
import warnings
import nibabel as nib
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW, RMSprop, lr_scheduler
from torch.utils.data import DataLoader
from scipy.ndimage.morphology import binary_fill_holes, generate_binary_structure
from skimage.morphology import remove_small_objects
from skimage.segmentation import (
inverse_gaussian_gradient,
morphological_chan_vese,
morphological_geodesic_active_contour
)
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.metrics.metric import NumpyMetric
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.parsing import AttributeDict
with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f):
import torchio
from torchio.transforms import (
Compose,
CropOrPad,
OneOf,
RandomAffine,
RandomElasticDeformation
)
from msseg.experiment.lightningtiramisu import LightningTiramisu
from msseg.model.tiramisu import Tiramisu3d
from msseg.cutmix import cutmix3d
from msseg.data import csv_to_subjectlist
from msseg.loss import binary_combo_loss
from msseg.util import n_dirname
from lesionqc import isbi15_score, corr
############################## Configuration ###################################
def arg_parser():
parser = ArgumentParser(description='train/test a 3D Tiramisu model')
required = parser.add_argument_group('Required')
required.add_argument('--train-csv', type=str, default=None,
help='path to csv with training images')
required.add_argument('--valid-csv', type=str, default=None,
help='path to csv with validation images')
required.add_argument('--test-csv', type=str, default=None,
help='path to csv with test images')
required.add_argument('--trained-model-path', type=str, default=None,
help='path to output the trained model')
required.add_argument('--out-path', type=str, default=None,
help='path to output the images in testing')
options = parser.add_argument_group('Options')
options.add_argument('-bs', '--batch-size', type=int, default=2,
help='training/test batch size [Default=2]')
options.add_argument('-vbs', '--valid-batch-size', type=int, default=2,
help='validation batch size [Default=2]')
options.add_argument('-hs', '--head-size', type=int, default=48,
help='size of head (for multi-task) [Default=48]')
options.add_argument('-lw', '--loss-weight', type=float, default=0.6,
help='weight of positive class in combo loss [Default=0.6]')
options.add_argument('-mg', '--multigpu', action='store_true', default=False,
help='use multiple gpus [Default=False]')
options.add_argument('-nw', '--num-workers', type=int, default=16,
help='number of CPU processors to use [Default=16]')
options.add_argument('-ps', '--patch-size', type=int, nargs=3, default=(96,96,96),
help='training/test patch size extracted from image [Default=96^3]')
options.add_argument('-vps', '--valid-patch-size', type=int, nargs=3, default=(128,128,128),
help='validation patch size extracted from image [Default=128^3]')
options.add_argument('-ql', '--queue-length', type=int, default=200,
help='queue length for torchio sampler [Default=200]')
options.add_argument('-rs', '--resume', type=str, default=None,
help='resume from this path [Default=None]')
options.add_argument('-sd', '--seed', type=int, default=0,
help='set seed for reproducibility [Default=0]')
options.add_argument('-spv', '--samples-per-volume', type=int, default=10,
help='samples per volume for torchio sampler [Default=10]')
options.add_argument('-th', '--threshold', type=float, default=0.5,
help='prob. threshold for seg [Default=0.5]')
options.add_argument('-uls', '--use-label-sampler', action='store_true',
help="use label sampler instead of uniform [Default=False]")
options.add_argument('-v', '--verbosity', action="count", default=0,
help="increase output verbosity (e.g., -vv is more than -v)")
train_options = parser.add_argument_group('Training Options')
train_options.add_argument('-bt', '--betas', type=float, default=(0.9,0.999), nargs=2,
help='adamw momentum parameters (or RMSprop momentum and alpha params)'
' [Default=(0.9,0.999)]')
train_options.add_argument('-da', '--decay-after', type=int, default=8,
help='decay learning rate after this number of epochs [Default=8]')
train_options.add_argument('-iw', '--isbi15score-weight', type=float, default=1.,
help='weight for isbi15 score in isbi15_score_minus_loss'
                               ' (1. is equal weighting) [Default=1.]')
train_options.add_argument('-lr', '--learning-rate', type=float, default=3e-4,
help='learning rate for the optimizer [Default=3e-4]')
train_options.add_argument('-lf', '--loss-function', type=str, default='combo',
choices=['combo','l1','mse'],
help='loss function to train the network [Default=combo]')
train_options.add_argument('-ne', '--n-epochs', type=int, default=64,
help='number of epochs [Default=64]')
train_options.add_argument('-ur', '--use-rmsprop', action='store_true',
help="use rmsprop instead of adam [Default=False]")
train_options.add_argument('-uw', '--use-weight', action='store_true',
help="use the weight field in the csv to weight subjects [Default=False]")
train_options.add_argument('-wd', '--weight-decay', type=float, default=1e-5,
help="weight decay parameter for adamw [Default=1e-5]")
train_options.add_argument('-sm', '--softmask', action='store_true',
help="use softmasks for training [Default=False]")
train_options.add_argument('--syn-weight', type=float, default=0.1,
help='weight of synthesis objective')
nn_options = parser.add_argument_group('Neural Network Options')
nn_options.add_argument('-ic', '--in-channels', type=int, default=1,
help='number of input channels [Default=1]')
nn_options.add_argument('-oc', '--out-channels', type=int, default=1,
help='number of output channels [Default=1]')
nn_options.add_argument('-dr', '--dropout-rate', type=float, default=0.1,
help='dropout probability [Default=0.1]')
nn_options.add_argument('-psd', '--p-shakedrop', type=float, default=0.,
help='shakedrop max probability (according to linear decay rule) [Default=0.]')
nn_options.add_argument('-in', '--init-type', type=str, default='he_uniform',
choices=('normal', 'xavier_normal', 'he_normal', 'he_uniform', 'orthogonal'),
help='use this type of initialization for the network [Default=he_uniform]')
nn_options.add_argument('-ing', '--init-gain', type=float, default=0.2,
help='use this initialization gain for initialization [Default=0.2]')
nn_options.add_argument('-db', '--down-blocks', type=int, default=(4,4,4,4,4), nargs='+',
help='tiramisu down block specification [Default=(4,4,4,4,4)]')
nn_options.add_argument('-ub', '--up-blocks', type=int, default=(4,4,4,4,4), nargs='+',
help='tiramisu up block specification [Default=(4,4,4,4,4)]')
nn_options.add_argument('-bnl', '--bottleneck-layers', type=int, default=4,
help='tiramisu bottleneck specification [Default=4]')
nn_options.add_argument('-gr', '--growth-rate', type=int, default=12,
help='tiramisu growth rate specification [Default=12]')
nn_options.add_argument('-ocfc', '--out-chans-first-conv', type=int, default=48,
help='tiramisu output channel size of first conv specification [Default=48]')
aug_options = parser.add_argument_group('Data Augmentation Options')
aug_options.add_argument('--use-pd', action='store_true', default=False,
help='use PD contrast as input [Default=False]')
aug_options.add_argument('--use-aug', action='store_true', default=False,
help='use data augmentation [Default=False]')
aug_options.add_argument('--use-mixup', action='store_true', default=False,
help='use mixup [Default=False]')
aug_options.add_argument('-ma', '--mixup-alpha', type=float, default=0.4,
help='mixup alpha parameter for beta dist. [Default=0.4]')
aug_options.add_argument('-mm', '--mixup-margin', type=float, default=0.,
help='mixup margin for asymmetric mixup (0=>off) [Default=0.]')
aug_options.add_argument('--mix-to-one', action='store_true', default=False,
help='mixup/cutmix a batch of 2 to 1 [Default=False]')
aug_options.add_argument('--use-cutmix', action='store_true', default=False,
help='use cutmix [Default=False]')
aug_options.add_argument('--use-multitask', action='store_true', default=False,
help='use multitask objective [Default=False]')
post_options = parser.add_argument_group('Post-processing Options')
post_options.add_argument('-mls', '--min-lesion-size', type=int, default=3,
                              help='in testing, remove lesions smaller than this many voxels [Default=3]')
post_options.add_argument('-bhf', '--binary-hole-fill', action='store_true', default=False,
                              help='in testing, perform binary hole filling [Default=False]')
post_options.add_argument('--use-morph-acwe', action='store_true', default=False,
help='use morphological chan-vese [Default=False]')
post_options.add_argument('--use-morph-gac', action='store_true', default=False,
help='use morphological geodesic active contour [Default=False]')
return parser
def create_exp_config(args):
exp_config = AttributeDict(
data_params=dict(
train_csv = args.train_csv,
valid_csv = args.valid_csv,
test_csv = args.test_csv,
trained_model_path = args.trained_model_path,
batch_size = args.batch_size,
valid_batch_size = args.valid_batch_size,
num_workers = args.num_workers,
patch_size = args.patch_size,
valid_patch_size = args.valid_patch_size,
queue_length = args.queue_length,
samples_per_volume = args.samples_per_volume,
softmask = args.softmask,
threshold = args.threshold,
use_aug = args.use_aug,
use_cutmix = args.use_cutmix,
use_label_sampler = args.use_label_sampler,
use_pd = args.use_pd,
use_mixup = args.use_mixup,
mixup_alpha = args.mixup_alpha,
mixup_margin = args.mixup_margin,
mix_to_one = args.mix_to_one,
use_multitask = args.use_multitask,
use_weight = args.use_weight,
),
lightning_params=dict(
init_params = dict(
init_type=args.init_type,
gain=args.init_gain
),
decay_after = args.decay_after,
multigpu = args.multigpu,
n_epochs = args.n_epochs,
network_dim = 3,
resume = args.resume,
seed = args.seed
),
network_params=dict(
in_channels = args.in_channels,
out_channels = args.out_channels,
down_blocks = args.down_blocks,
up_blocks = args.up_blocks,
bottleneck_layers = args.bottleneck_layers,
growth_rate = args.growth_rate,
out_chans_first_conv = args.out_chans_first_conv,
dropout_rate = args.dropout_rate,
p_shakedrop = args.p_shakedrop
),
optim_params=dict(
lr = args.learning_rate,
betas = args.betas,
weight_decay = args.weight_decay,
),
misc_params=dict(
head_size = args.head_size,
isbi15score_weight = args.isbi15score_weight,
loss_function = args.loss_function,
loss_weight = args.loss_weight,
syn_weight = args.syn_weight,
use_rmsprop = args.use_rmsprop
)
)
return exp_config
################################# Network ######################################
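# Pytorch-lightning metric wrapping isbi15_score from lesionqc; for batched
# inputs it averages the per-sample (non-reweighted) scores plus a
# predicted-vs-true lesion volume correlation term over the batch. NaN scores
# are mapped to zero.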
class ISBIScore(NumpyMetric):
def forward(self, y_hat, y):
y_, y_hat_ = y.squeeze(), y_hat.squeeze()
if y_.ndim == 3 and y_hat_.ndim == 3: # batch size 1
isbiscore = isbi15_score(y_hat_, y_)
elif y_.ndim == 4 and y_hat_.ndim == 4:
isbiscore = 0.
for y_i, y_hat_i in zip(y_, y_hat_):
isbiscore += isbi15_score(y_hat_i, y_i, reweighted=False)
dims = (1,2,3)
pred_vols = y_hat_.sum(axis=dims)
true_vols = y_.sum(axis=dims)
vol_corr = corr(pred_vols, true_vols)
isbiscore += vol_corr / 4
isbiscore /= y_.shape[0]
else:
raise ValueError(f'y ndim={y_.ndim}; y_hat ndim={y_hat_.ndim} not valid.')
if np.isnan(isbiscore):
isbiscore = 0.
return isbiscore
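# Small convolutional head (norm -> LeakyReLU -> 3x3x3 conv -> norm ->
# LeakyReLU -> 1x1x1 conv) used to split the shared Tiramisu features into
# separate segmentation and synthesis outputs in the multi-task setup.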
class Head(nn.Sequential):
def __init__(self, inc, outc):
super().__init__()
self.add_module('bn1', nn.InstanceNorm3d(inc, affine=True))
self.add_module('act1', nn.LeakyReLU(inplace=True))
self.add_module('pad', nn.ReplicationPad3d(1))
self.add_module('conv1', nn.Conv3d(inc, inc, 3, bias=False))
self.add_module('bn2', nn.InstanceNorm3d(inc, affine=True))
self.add_module('act2', nn.LeakyReLU(inplace=True))
self.add_module('conv2', nn.Conv3d(inc, outc, 1))
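# LightningModule wrapping the 3D Tiramisu; handles patch-based data loading
# via torchio, mixup/cutmix augmentation, the optional multi-task
# (segmentation + T1 synthesis) heads, and whole-image/patch-wise inference.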
class MSLightningTiramisu(LightningTiramisu):
def __init__(self,
hparams:AttributeDict,
train_subject_list:List[torchio.Subject]=None,
valid_subject_list:List[torchio.Subject]=None):
if self._use_multitask_w_head(hparams):
out_channels = hparams['network_params']['out_channels']
hparams['network_params']['out_channels'] = hparams['misc_params']['head_size']
super().__init__(hparams)
if self._use_multitask_w_head():
hparams['network_params']['out_channels'] = out_channels
self.syn_head = Head(hparams['misc_params']['head_size'], 1)
self.seg_head = Head(hparams['misc_params']['head_size'], 1)
self.train_subject_list = train_subject_list
self.valid_subject_list = valid_subject_list
self.isbiscore = ISBIScore('isbi15_score')
def _use_multitask_w_head(self, hparams=None):
if hparams is None:
hparams = self.hparams
return (hparams['data_params']['use_multitask'] and \
hparams['misc_params']['head_size'] > 0)
@property
def _use_pd(self):
return self.hparams['data_params']['use_pd']
@property
def _use_multitask(self):
return self.hparams['data_params']['use_multitask']
@property
def _use_weight(self):
return self.hparams['data_params']['use_weight']
@property
def _use_mixup(self):
return self.hparams['data_params']['use_mixup']
@property
def _mix_to_one(self):
return self.hparams['data_params']['mix_to_one'] and \
self.hparams['data_params']['batch_size'] == 2
@property
def _mixup_alpha(self):
return self.hparams['data_params']['mixup_alpha']
@property
def _mixup_margin(self):
return self.hparams['data_params']['mixup_margin']
@property
def _use_cutmix(self):
return self.hparams['data_params']['use_cutmix']
@property
def _threshold(self):
return self.hparams['data_params']['threshold']
@property
def _syn_weight(self):
return self.hparams['misc_params']['syn_weight']
@property
def _softmask(self):
return self.hparams['data_params']['softmask']
@property
def _isbi15score_weight(self):
return self.hparams['misc_params']['isbi15score_weight']
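    # AdamW by default (RMSprop if requested), with a learning rate held
    # constant for `decay_after` epochs and then decayed linearly over the
    # remaining epochs (see lambda_rule below).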
def configure_optimizers(self):
if self.hparams['misc_params']['use_rmsprop']:
self.hparams['optim_params']['momentum'] = self.hparams['optim_params']['betas'][0]
self.hparams['optim_params']['alpha'] = self.hparams['optim_params']['betas'][1]
del self.hparams['optim_params']['betas']
optimizer = RMSprop(self.parameters(), **self.hparams['optim_params'])
else:
optimizer = AdamW(self.parameters(), **self.hparams['optim_params'])
n_epochs = self.hparams['lightning_params']['n_epochs']
decay_after = self.hparams['lightning_params']['decay_after']
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch - decay_after) / float(n_epochs + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
return [optimizer], [scheduler]
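    # Training patches are drawn with a torchio Queue: subjects are
    # (optionally) spatially augmented, then patches are sampled with torchio's
    # UniformSampler or LabelSampler depending on --use-label-sampler.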
def train_dataloader(self):
if self.hparams['data_params']['use_aug']:
spatial = OneOf(
{RandomAffine(): 0.8, RandomElasticDeformation(): 0.2},
p=0.75,
)
transforms = [spatial]
transform = Compose(transforms)
subjects_dataset = torchio.SubjectsDataset(
self.train_subject_list, transform=transform)
else:
subjects_dataset = torchio.SubjectsDataset(
self.train_subject_list)
if self.hparams['data_params']['use_label_sampler']:
sampler = torchio.data.LabelSampler(
self.hparams['data_params']["patch_size"])
else:
sampler = torchio.data.UniformSampler(
self.hparams['data_params']["patch_size"])
patches_queue = torchio.Queue(
subjects_dataset,
self.hparams['data_params']["queue_length"],
self.hparams['data_params']["samples_per_volume"],
sampler,
num_workers=self.hparams['data_params']["num_workers"],
shuffle_subjects=True,
shuffle_patches=True)
train_dataloader = DataLoader(
patches_queue,
batch_size=self.hparams['data_params']["batch_size"])
return train_dataloader
def val_dataloader(self):
subjects_dataset = torchio.SubjectsDataset(
self.valid_subject_list,
transform=torchio.CropOrPad(
self.hparams['data_params']['valid_patch_size']))
val_dataloader = DataLoader(
subjects_dataset,
shuffle=True,
batch_size=self.hparams['data_params']['valid_batch_size'])
return val_dataloader
def _collate_batch(self, batch):
fl = batch['flair'][torchio.DATA]
y = batch['label'][torchio.DATA]
if self._use_pd:
pd = batch['pd'][torchio.DATA]
if self._use_weight:
w = batch['weight']
with torch.no_grad():
if self._use_pd:
x = torch.cat((fl,pd),1)
else:
x = fl
if 'div' in batch:
x /= batch['div'].view(-1,1,1,1,1)
out = (x,y,w) if self._use_weight else (x,y)
return out
def _collate_multitask_batch(self, batch):
t1 = batch['t1'][torchio.DATA]
fl = batch['flair'][torchio.DATA]
y = batch['label'][torchio.DATA]
if self._use_pd:
pd = batch['pd'][torchio.DATA]
if self._use_weight:
w = batch['weight']
with torch.no_grad():
if self._use_pd:
x = torch.cat((fl,pd),1)
else:
x = fl
if 'div' in batch:
x /= batch['div'].view(-1,1,1,1,1)
out = (x,y,t1,w) if self._use_weight else (x,y,t1)
return out
def _mixup_beta(self, n):
alpha = self._mixup_alpha
m = torch.distributions.beta.Beta(alpha, alpha)
return m.sample((n,1,1,1,1))
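    # Asymmetric mixup (roughly): when a positive mixup margin is set, mixed
    # labels are re-binarized so a voxel stays 1 only where the dominant
    # (above-margin) source labeled it as lesion, keeping the targets hard for
    # the segmentation loss.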
def _asymmetric_mixup_threshold(self, y, b):
if self._mixup_margin > 0.:
choose_orig = b > self._mixup_margin
choose_perm = (1. - b) > self._mixup_margin
mask = (y == (1. - b)) * choose_perm + (y == b) * choose_orig
mask = (y == 1.) | mask
y *= 0.
y[mask] = 1.
return y
def mix(self, x, y, w=None):
with torch.no_grad():
if self._use_mixup:
n = x.size(0)
rp = torch.randperm(n)
b = self._mixup_beta(n).to(x.device)
x_perm = x[rp].clone()
x = b * x + (1 - b) * x_perm
y = y.float()
y_perm = y[rp].clone()
y = b * y + (1 - b) * y_perm
y = self._asymmetric_mixup_threshold(y, b)
if w is not None:
w_perm = w[rp].clone()
w = b * w + (1 - b) * w_perm
if self._use_cutmix:
x, y = cutmix3d((x, y))
if self._mix_to_one:
x, y = x[0:1], y[0:1]
return (x, y) if w is None else (x,y,w)
def multitask_mix(self, x, seg, syn, w=None):
with torch.no_grad():
if self._use_mixup:
n = x.size(0)
rp = torch.randperm(n)
b = self._mixup_beta(n).to(x.device)
x_perm = x[rp].clone()
x = b * x + (1 - b) * x_perm
seg = seg.float()
seg_perm = seg[rp].clone()
seg = b * seg + (1 - b) * seg_perm
syn_perm = syn[rp].clone()
syn = b * syn + (1 - b) * syn_perm
seg = self._asymmetric_mixup_threshold(seg, b)
if w is not None:
w_perm = w[rp].clone()
w = b * w + (1 - b) * w_perm
if self._mix_to_one:
x, seg, syn = x[0:1], seg[0:1], syn[0:1]
out = (x,seg,syn) if w is None else (x,seg,syn,w)
return out
def _training_step(self, batch):
batch = self._collate_batch(batch)
if self._use_weight:
x, y, w = self.mix(*batch)
else:
x, y = self.mix(*batch)
y_hat = self(x)
loss = self.train_criterion(y_hat, y, reduction='none')
if self._use_weight:
mean_dims = tuple(range(1 - len(x.shape), 0))
loss = torch.mean(loss, dim=mean_dims)
loss *= w
loss = loss.mean()
tensorboard_logs = dict(
train_loss=loss)
return {'loss': loss, 'log': tensorboard_logs}
def _multitask_training_step(self, batch):
batch = self._collate_multitask_batch(batch)
if self._use_weight:
x, y, t1, w = self.multitask_mix(*batch)
else:
x, y, t1 = self.multitask_mix(*batch)
if self._use_multitask_w_head():
inter_x = self(x)
y_hat = self.seg_head(inter_x)
t1_hat = self.syn_head(inter_x)
else:
y_hat, t1_hat = torch.chunk(self(x), 2, dim=1)
seg_loss = self.train_criterion(y_hat, y, reduction='none')
syn_loss = self.train_syn_criterion(t1_hat, t1, reduction='none')
loss = seg_loss + self._syn_weight * syn_loss
if self._use_weight:
mean_dims = tuple(range(1 - len(x.shape), 0))
loss = torch.mean(loss, dim=mean_dims)
loss *= w
loss = loss.mean()
seg_loss = seg_loss.mean()
syn_loss = syn_loss.mean()
tensorboard_logs = dict(
train_loss=loss,
train_seg_loss=seg_loss,
train_syn_loss=syn_loss)
return {'loss': loss, 'log': tensorboard_logs}
def training_step(self, batch, batch_idx):
if self._use_multitask:
out_dict = self._multitask_training_step(batch)
else:
out_dict = self._training_step(batch)
return out_dict
def _validation_step(self, batch):
if self._use_weight:
x, y, w = self._collate_batch(batch)
else:
x, y = self._collate_batch(batch)
y_hat = self(x)
loss = self.valid_criterion(y_hat, y, reduction='none')
if self._use_weight:
mean_dims = tuple(range(1 - len(x.shape), 0))
loss = torch.mean(loss, dim=mean_dims)
loss *= w
loss = loss.mean()
with torch.no_grad():
y_hat_ = torch.sigmoid(y_hat) > self._threshold
isbiscore = self.isbiscore(y_hat_, y)
isbiscore = isbiscore.to(loss.device)
out_dict = dict(
val_loss=loss,
val_isbi15_score=isbiscore)
return out_dict, (x, y, y_hat)
def _multitask_validation_step(self, batch):
if self._use_weight:
x, y, t1, w = self._collate_multitask_batch(batch)
else:
x, y, t1 = self._collate_multitask_batch(batch)
if self._use_multitask_w_head():
inter_x = self(x)
y_hat = self.seg_head(inter_x)
t1_hat = self.syn_head(inter_x)
else:
y_hat, t1_hat = torch.chunk(self(x), 2, dim=1)
seg_loss = self.valid_criterion(y_hat, y, reduction='none')
syn_loss = self.valid_syn_criterion(t1_hat, t1, reduction='none')
loss = seg_loss + self._syn_weight * syn_loss
if self._use_weight:
mean_dims = tuple(range(1 - len(x.shape), 0))
loss = torch.mean(loss, dim=mean_dims)
loss *= w
loss = loss.mean()
seg_loss = seg_loss.mean()
syn_loss = syn_loss.mean()
with torch.no_grad():
y_hat_ = torch.sigmoid(y_hat) > self._threshold
isbiscore = self.isbiscore(y_hat_, y)
isbiscore = isbiscore.to(loss.device)
out_dict = dict(
val_loss=loss,
val_isbi15_score=isbiscore,
val_seg_loss=seg_loss,
val_syn_loss=syn_loss)
return out_dict, (x, y, t1, y_hat, t1_hat)
def validation_step(self, batch, batch_idx):
if self._use_multitask:
out_dict, imgs = self._multitask_validation_step(batch)
x, y, t1, y_hat, t1_hat = imgs
if batch_idx == 0:
with torch.no_grad():
mid_slice = x.shape[-1] // 2
fl_ = normalize(x[:,0:1,:,:,mid_slice])
if self._use_pd:
pd_ = normalize(x[:,1:2,:,:,mid_slice])
t1_ = normalize(t1[...,mid_slice])
t1_hat_ = normalize(t1_hat[...,mid_slice])
if self._softmask:
y_hat_ = normalize(torch.sigmoid(y_hat[...,mid_slice]))
y_ = normalize(y[...,mid_slice])
else:
y_hat_ = torch.sigmoid(y_hat[...,mid_slice]) > self._threshold
y_ = y[...,mid_slice] > 0.
n = self.current_epoch
self.logger.experiment.add_images('t1', t1_, n, dataformats='NCHW')
self.logger.experiment.add_images('flair', fl_, n, dataformats='NCHW')
self.logger.experiment.add_images('y', y_, n, dataformats='NCHW')
self.logger.experiment.add_images('t1_hat', t1_hat_, n, dataformats='NCHW')
self.logger.experiment.add_images('y_hat', y_hat_, n, dataformats='NCHW')
if self._use_pd:
self.logger.experiment.add_images('pd', pd_, n, dataformats='NCHW')
else:
out_dict, imgs = self._validation_step(batch)
x, y, y_hat = imgs
if batch_idx == 0:
with torch.no_grad():
mid_slice = x.shape[-1] // 2
                    # the collated input has FLAIR in channel 0 (and PD in channel 1)
                    fl_ = normalize(x[:,0:1,:,:,mid_slice])
                    if self._use_pd:
                        pd_ = normalize(x[:,1:2,:,:,mid_slice])
if self._softmask:
y_hat_ = normalize(torch.sigmoid(y_hat[...,mid_slice]))
y_ = normalize(y[...,mid_slice])
else:
y_hat_ = torch.sigmoid(y_hat[...,mid_slice]) > self._threshold
y_ = y[...,mid_slice] > 0.
n = self.current_epoch
self.logger.experiment.add_images('flair', fl_, n, dataformats='NCHW')
self.logger.experiment.add_images('pred', y_hat_, n, dataformats='NCHW')
self.logger.experiment.add_images('truth', y_, n, dataformats='NCHW')
if self._use_pd:
self.logger.experiment.add_images('pd', pd_, n, dataformats='NCHW')
return out_dict
@staticmethod
def _cat(x):
try:
x = torch.cat(x)
except RuntimeError:
x = torch.tensor(x)
return x
def validation_epoch_end(self, outputs):
avg_loss = self._cat([x['val_loss'] for x in outputs]).mean()
avg_isbi15_score = self._cat([x['val_isbi15_score'] for x in outputs]).mean()
if self._use_multitask:
avg_seg_loss = self._cat([x['val_seg_loss'] for x in outputs]).mean()
avg_syn_loss = self._cat([x['val_syn_loss'] for x in outputs]).mean()
avg_isbi15_score_minus_loss = self._isbi15score_weight * avg_isbi15_score - avg_seg_loss
tensorboard_logs = {'avg_val_loss': avg_loss,
'avg_val_isbi15_score': avg_isbi15_score,
'avg_val_isbi15_score_minus_loss': avg_isbi15_score_minus_loss,
'avg_val_seg_loss': avg_seg_loss,
'avg_val_syn_loss': avg_syn_loss}
else:
avg_isbi15_score_minus_loss = self._isbi15score_weight * avg_isbi15_score - avg_loss
tensorboard_logs = {'avg_val_loss': avg_loss,
'avg_val_isbi15_score': avg_isbi15_score,
'avg_val_isbi15_score_minus_loss': avg_isbi15_score_minus_loss}
return {'val_loss': avg_loss, 'log': tensorboard_logs}
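    # Whole-image inference: crop to a bounding box around the nonzero FLAIR
    # voxels, run the network once, and paste the sigmoid probabilities back
    # into a full-size volume.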
def process_whole_img(self, sample):
fl_fn = str(sample['flair'].path)
fl = nib.load(fl_fn).get_fdata(dtype=np.float32)
xs = [fl]
if self._use_pd:
pd_fn = str(sample['pd'].path)
pd = nib.load(pd_fn).get_fdata(dtype=np.float32)
xs.append(pd)
out = torch.from_numpy(np.zeros_like(fl))
h1,h2,w1,w2,d1,d2 = bbox3D(fl > 0.)
xs = [x[h1:h2,w1:w2,d1:d2] for x in xs]
x = np.stack(xs)[None,...]
x = torch.from_numpy(x).to(self.device)
self.eval()
with torch.no_grad():
if self._use_multitask:
if self._use_multitask_w_head():
logits = self.seg_head(self(x))
else:
logits, _ = torch.chunk(self(x), 2, dim=1)
else:
                logits = self(x)
            probits = torch.sigmoid(logits)
            # squeeze the (N=1, C=1) dims so the probability map matches `out`
            out[h1:h2,w1:w2,d1:d2] = probits.detach().cpu().squeeze()
torch.cuda.empty_cache()
return out
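    # Patch-wise inference with torchio's GridSampler/GridAggregator, using
    # overlapping patches (half the patch size by default) to reduce border
    # artifacts.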
def process_img_patches(self, sample, patch_overlap=None):
patch_size = self.hparams['data_params']['patch_size']
batch_size = self.hparams['data_params']['batch_size']
if patch_overlap is None:
            # patch_size is a length-3 sequence, so compute the overlap per dimension
            patch_overlap = tuple(ps // 2 for ps in patch_size)
grid_sampler = torchio.inference.GridSampler(
sample,
patch_size,
patch_overlap,
padding_mode='replicate'
)
patch_loader = torch.utils.data.DataLoader(
grid_sampler,
batch_size=batch_size)
aggregator = torchio.inference.GridAggregator(grid_sampler)
self.eval()
with torch.no_grad():
for patches_batch in patch_loader:
fl = patches_batch['flair'][torchio.DATA]
xs = [fl]
if self._use_pd:
pd_fn = str(sample['pd'].path)
pd = nib.load(pd_fn).get_fdata(dtype=np.float32)
xs.append(pd)
x = torch.cat(xs, 1).to(self.device)
locations = patches_batch[torchio.LOCATION]
if self._use_multitask:
if self._use_multitask_w_head():
logits = self.seg_head(self(x))
else:
logits, _ = torch.chunk(self(x), 2, dim=1)
else:
logits = self(x)
probits = torch.sigmoid(logits)
aggregator.add_batch(probits, locations)
out = aggregator.get_output_tensor().detach().cpu()
torch.cuda.empty_cache()
return out
############################# Helper functions #################################
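# Min-max normalize each image in a batch to [0, 1] (used only for TensorBoard
# visualization).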
def normalize(x):
dim = x.dim()
xmin, xmax = x.clone(), x.clone()
for i in range(1,dim):
xmin, _ = xmin.min(dim=i, keepdim=True)
xmax, _ = xmax.max(dim=i, keepdim=True)
return (x - xmin) / (xmax - xmin)
def to_np(x):
return x.detach().cpu().numpy()
def split_filename(filepath):
""" split a filepath into the directory, base, and extension """
path = os.path.dirname(filepath)
filename = os.path.basename(filepath)
base, ext = os.path.splitext(filename)
if ext == '.gz':
base, ext2 = os.path.splitext(base)
ext = ext2 + ext
return path, base, ext
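# Post-process a binary segmentation: optionally fill holes and remove
# connected components smaller than `mls` voxels.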
def clean_seg(x, bhf=False, mls=4):
if bhf:
structure = generate_binary_structure(3, 3)
x = binary_fill_holes(x, structure=structure)
if mls > 0:
x = remove_small_objects(x, min_size=mls, connectivity=3)
return x
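# Regression-style segmentation losses: apply a sigmoid to the logits and
# compare against the (possibly soft) mask with L1 or MSE.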
def l1_segmentation_loss(x, y, reduction='mean'):
x = torch.sigmoid(x)
return F.l1_loss(x, y, reduction=reduction)
def mse_segmentation_loss(x, y, reduction='mean'):
x = torch.sigmoid(x)
return F.mse_loss(x, y, reduction=reduction)
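# Bounding box of the nonzero region of a 3D volume, padded by `offset` voxels
# and clipped to the image bounds; returns (rmin, rmax, cmin, cmax, zmin, zmax).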
def bbox3D(img, offset=5):
r = np.any(img, axis=(1, 2))
c = np.any(img, axis=(0, 2))
z = np.any(img, axis=(0, 1))
rmin, rmax = np.where(r)[0][[0, -1]]
cmin, cmax = np.where(c)[0][[0, -1]]
zmin, zmax = np.where(z)[0][[0, -1]]
i, j, k = img.shape
return (max(rmin-offset,0), min(rmax+offset,i),
max(cmin-offset,0), min(cmax+offset,j),
max(zmin-offset,0), min(zmax+offset,k))
################################### Main #######################################
def main():
parser = arg_parser()
args = parser.parse_args()
if args.verbosity == 1:
level = logging.getLevelName('INFO')
elif args.verbosity >= 2:
level = logging.getLevelName('DEBUG')
else:
level = logging.getLevelName('WARNING')
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=level)
logger = logging.getLogger(__name__)
seed_everything(args.seed)
exp_config = create_exp_config(args)
model = None
if args.train_csv is not None and args.valid_csv is not None:
train_subject_list = csv_to_subjectlist(args.train_csv)
valid_subject_list = csv_to_subjectlist(args.valid_csv)
model = MSLightningTiramisu(
exp_config,
train_subject_list,
valid_subject_list)
if args.loss_function == 'combo':
model.seg_loss = binary_combo_loss
model.train_criterion = partial(model.seg_loss, weight=args.loss_weight)
model.valid_criterion = partial(model.seg_loss, weight=args.loss_weight)
elif args.loss_function == 'mse':
model.seg_loss = mse_segmentation_loss
model.train_criterion = model.seg_loss
model.valid_criterion = model.seg_loss
elif args.loss_function == 'l1':
model.seg_loss = l1_segmentation_loss
model.train_criterion = model.seg_loss
model.valid_criterion = model.seg_loss
else:
raise ValueError(f'{args.loss_function} not supported')
if args.use_multitask:
model.syn_loss = F.l1_loss
model.train_syn_criterion = model.syn_loss
model.valid_syn_criterion = model.syn_loss
n_epochs = exp_config.lightning_params['n_epochs']
logger.info(model)
gpu_kwargs = dict(gpus=2, distributed_backend='dp') if args.multigpu else \
dict(gpus=[1])
checkpoint_callback = ModelCheckpoint(
monitor='avg_val_isbi15_score_minus_loss',
save_top_k=1,
save_last=True,
mode='max'
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
trainer = Trainer(
benchmark=True,
check_val_every_n_epoch=1,
accumulate_grad_batches=1,
min_epochs=args.n_epochs,
max_epochs=args.n_epochs,
checkpoint_callback=checkpoint_callback,
resume_from_checkpoint=args.resume,
fast_dev_run=False,
**gpu_kwargs
)
torch.cuda.empty_cache()
trainer.fit(model)
if args.test_csv is not None:
device = torch.device('cuda:0')
if model is None:
logger.info(f"Loading: {args.trained_model_path}")
model = MSLightningTiramisu.load_from_checkpoint(
args.trained_model_path
)
model.to(device)
torch.cuda.empty_cache()
test_subject_list = csv_to_subjectlist(args.test_csv)
for test_subj in test_subject_list:
name = test_subj.name
logger.info(f'Processing: {name}')
if args.patch_size == [0,0,0]:
output = model.process_whole_img(test_subj)
else:
output = model.process_img_patches(test_subj)
prob_data = output.numpy().squeeze()
seg_data = prob_data > args.threshold
seg_data = clean_seg(
seg_data,
bhf=args.binary_hole_fill,
mls=args.min_lesion_size).astype(np.float32)
prob_fn = join(args.out_path, name + '_prob.nii.gz')
seg_fn = join(args.out_path, name + '_seg.nii.gz')
in_fn = str(test_subj['flair'].path)
in_nii = nib.load(in_fn)
in_data = in_nii.get_fdata()
assert in_data.shape == prob_data.shape, f"In: {in_data.shape} != Out: {prob_data.shape}"
prob_nii = nib.Nifti1Image(
prob_data,
in_nii.affine,
in_nii.header)
prob_nii.to_filename(prob_fn)
seg_nii = nib.Nifti1Image(
seg_data,
in_nii.affine,
in_nii.header)
seg_nii.to_filename(seg_fn)
if args.use_morph_gac:
logger.info('Starting morphological geodesic active contour')
fl_fn = str(test_subj['flair'].path)
fl = nib.load(fl_fn).get_fdata(dtype=np.float32)
flg = inverse_gaussian_gradient(fl)
seg_gac = morphological_geodesic_active_contour(
flg, 100, init_level_set=seg_data)
seg_gac_fn = join(args.out_path, name + '_seg_gac.nii.gz')
seg_gac_nii = nib.Nifti1Image(
seg_gac,
in_nii.affine,
in_nii.header)
seg_gac_nii.to_filename(seg_gac_fn)
if args.use_morph_acwe:
logger.info('Starting morphological Chan-Vese')
fl_fn = str(test_subj['flair'].path)
fl = nib.load(fl_fn).get_fdata(dtype=np.float32)
seg_acwe = morphological_chan_vese(
fl, 100, init_level_set=seg_data)
seg_acwe_fn = join(args.out_path, name + '_seg_acwe.nii.gz')
seg_acwe_nii = nib.Nifti1Image(
seg_acwe,
in_nii.affine,
in_nii.header)
seg_acwe_nii.to_filename(seg_acwe_fn)
return 0
if __name__ == "__main__":
sys.exit(main())