Created
March 6, 2021 17:07
-
-
Save jcreinhold/c9bfeb15ec4f768cd6af20e1662da8fc to your computer and use it in GitHub Desktop.
Neural network (3D Tiramisu) for FLAIR-based T2-lesion segmentation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
3D Tiramisu network for FLAIR-based T2-lesion segmentation | |
This code is unfortunately a huge mess. But, given the CSV files with | |
the appropriate setup, you can run the below command (starting with | |
"python -u ...") to generate the network used to generate the | |
segmentation results in the paper: | |
"A Structural Causal Model for MR Images of Multiple Sclerosis" | |
https://arxiv.org/abs/2103.03158. | |
This script requires: | |
1. msseg (https://github.com/jcreinhold/msseg) | |
2. lesionqc (https://github.com/jcreinhold/lesionqc) | |
and the various other packages listed in the imports. | |
The CSV file is set up with "flair" and "t1" as headers, and each row
holds the full paths to the corresponding FLAIR and T1-w NIfTI
images. We used the ISBI 2015 and MICCAI 2016 Challenge
Data, as well as some private labeled data, to train the network.
python -u tiramisu3d_only_flair.py \ | |
--train-csv csv/all/train_weighted.csv \ | |
--valid-csv csv/all/valid_weighted.csv \ | |
-ic 1 \ | |
--use-multitask \ | |
--use-mixup \ | |
--use-aug \ | |
-db 4 4 4 4 \ | |
-ub 4 4 4 4 \ | |
-gr 12 \ | |
-bnl 4 \ | |
-lr 0.0002 \ | |
-bt 0.8 0.99 \ | |
-bs 7 \ | |
-ps 64 64 64 \ | |
-vbs 7 \ | |
-vps 96 96 96 \ | |
-da 25 \ | |
-ne 100 \ | |
-dr 0.1 \ | |
-wd 0.000001 \ | |
-ma 0.4 \ | |
-mm 0.8 \ | |
-spv 10 \ | |
-v | |
Author: Jacob Reinhold (jacob.reinhold@jhu.edu) | |
Created on: Feb 28, 2021 | |
""" | |
from typing import * | |
from argparse import ArgumentParser | |
import contextlib | |
from functools import partial | |
import logging | |
import os | |
from os.path import join | |
import sys | |
import warnings | |
import nibabel as nib | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.optim import AdamW, RMSprop, lr_scheduler | |
from torch.utils.data import DataLoader | |
from scipy.ndimage.morphology import binary_fill_holes, generate_binary_structure | |
from skimage.morphology import remove_small_objects | |
from skimage.segmentation import ( | |
inverse_gaussian_gradient, | |
morphological_chan_vese, | |
morphological_geodesic_active_contour | |
) | |
from pytorch_lightning import Trainer, seed_everything | |
from pytorch_lightning.metrics.metric import NumpyMetric | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from pytorch_lightning.utilities.parsing import AttributeDict | |
with open(os.devnull, "w") as f: | |
with contextlib.redirect_stdout(f): | |
import torchio | |
from torchio.transforms import ( | |
Compose, | |
CropOrPad, | |
OneOf, | |
RandomAffine, | |
RandomElasticDeformation | |
) | |
from msseg.experiment.lightningtiramisu import LightningTiramisu | |
from msseg.model.tiramisu import Tiramisu3d | |
from msseg.cutmix import cutmix3d | |
from msseg.data import csv_to_subjectlist | |
from msseg.loss import binary_combo_loss | |
from msseg.util import n_dirname | |
from lesionqc import isbi15_score, corr | |
############################## Configuration ################################### | |
def arg_parser():
    """Build the command-line parser for training/testing the 3D Tiramisu.

    Arguments are organised into the groups 'Required', 'Options',
    'Training Options', 'Neural Network Options', 'Data Augmentation Options'
    and 'Post-processing Options'.  Note that nothing in the 'Required'
    group is enforced by argparse: every entry defaults to ``None``.

    Returns:
        ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = ArgumentParser(description='train/test a 3D Tiramisu model')

    required = parser.add_argument_group('Required')
    required.add_argument('--train-csv', type=str, default=None,
                          help='path to csv with training images')
    required.add_argument('--valid-csv', type=str, default=None,
                          help='path to csv with validation images')
    required.add_argument('--test-csv', type=str, default=None,
                          help='path to csv with test images')
    required.add_argument('--trained-model-path', type=str, default=None,
                          help='path to output the trained model')
    required.add_argument('--out-path', type=str, default=None,
                          help='path to output the images in testing')

    options = parser.add_argument_group('Options')
    options.add_argument('-bs', '--batch-size', type=int, default=2,
                         help='training/test batch size [Default=2]')
    options.add_argument('-vbs', '--valid-batch-size', type=int, default=2,
                         help='validation batch size [Default=2]')
    options.add_argument('-hs', '--head-size', type=int, default=48,
                         help='size of head (for multi-task) [Default=48]')
    options.add_argument('-lw', '--loss-weight', type=float, default=0.6,
                         help='weight of positive class in combo loss [Default=0.6]')
    options.add_argument('-mg', '--multigpu', action='store_true', default=False,
                         help='use multiple gpus [Default=False]')
    options.add_argument('-nw', '--num-workers', type=int, default=16,
                         help='number of CPU processors to use [Default=16]')
    options.add_argument('-ps', '--patch-size', type=int, nargs=3, default=(96,96,96),
                         help='training/test patch size extracted from image [Default=96^3]')
    options.add_argument('-vps', '--valid-patch-size', type=int, nargs=3, default=(128,128,128),
                         help='validation patch size extracted from image [Default=128^3]')
    options.add_argument('-ql', '--queue-length', type=int, default=200,
                         help='queue length for torchio sampler [Default=200]')
    options.add_argument('-rs', '--resume', type=str, default=None,
                         help='resume from this path [Default=None]')
    options.add_argument('-sd', '--seed', type=int, default=0,
                         help='set seed for reproducibility [Default=0]')
    options.add_argument('-spv', '--samples-per-volume', type=int, default=10,
                         help='samples per volume for torchio sampler [Default=10]')
    options.add_argument('-th', '--threshold', type=float, default=0.5,
                         help='prob. threshold for seg [Default=0.5]')
    options.add_argument('-uls', '--use-label-sampler', action='store_true',
                         help="use label sampler instead of uniform [Default=False]")
    options.add_argument('-v', '--verbosity', action="count", default=0,
                         help="increase output verbosity (e.g., -vv is more than -v)")

    train_options = parser.add_argument_group('Training Options')
    train_options.add_argument('-bt', '--betas', type=float, default=(0.9,0.999), nargs=2,
                               help='adamw momentum parameters (or RMSprop momentum and alpha params)'
                                    ' [Default=(0.9,0.999)]')
    train_options.add_argument('-da', '--decay-after', type=int, default=8,
                               help='decay learning rate after this number of epochs [Default=8]')
    # The second fragment starts with a space: the original concatenation ran
    # "...minus_loss" and "(1. is ..." together in the rendered help.
    train_options.add_argument('-iw', '--isbi15score-weight', type=float, default=1.,
                               help='weight for isbi15 score in isbi15_score_minus_loss'
                                    ' (1. is equal weighting) [Default=1.]')
    train_options.add_argument('-lr', '--learning-rate', type=float, default=3e-4,
                               help='learning rate for the optimizer [Default=3e-4]')
    train_options.add_argument('-lf', '--loss-function', type=str, default='combo',
                               choices=['combo','l1','mse'],
                               help='loss function to train the network [Default=combo]')
    train_options.add_argument('-ne', '--n-epochs', type=int, default=64,
                               help='number of epochs [Default=64]')
    train_options.add_argument('-ur', '--use-rmsprop', action='store_true',
                               help="use rmsprop instead of adam [Default=False]")
    train_options.add_argument('-uw', '--use-weight', action='store_true',
                               help="use the weight field in the csv to weight subjects [Default=False]")
    train_options.add_argument('-wd', '--weight-decay', type=float, default=1e-5,
                               help="weight decay parameter for adamw [Default=1e-5]")
    train_options.add_argument('-sm', '--softmask', action='store_true',
                               help="use softmasks for training [Default=False]")
    train_options.add_argument('--syn-weight', type=float, default=0.1,
                               help='weight of synthesis objective')

    nn_options = parser.add_argument_group('Neural Network Options')
    nn_options.add_argument('-ic', '--in-channels', type=int, default=1,
                            help='number of input channels [Default=1]')
    nn_options.add_argument('-oc', '--out-channels', type=int, default=1,
                            help='number of output channels [Default=1]')
    nn_options.add_argument('-dr', '--dropout-rate', type=float, default=0.1,
                            help='dropout probability [Default=0.1]')
    nn_options.add_argument('-psd', '--p-shakedrop', type=float, default=0.,
                            help='shakedrop max probability (according to linear decay rule) [Default=0.]')
    nn_options.add_argument('-in', '--init-type', type=str, default='he_uniform',
                            choices=('normal', 'xavier_normal', 'he_normal', 'he_uniform', 'orthogonal'),
                            help='use this type of initialization for the network [Default=he_uniform]')
    nn_options.add_argument('-ing', '--init-gain', type=float, default=0.2,
                            help='use this initialization gain for initialization [Default=0.2]')
    nn_options.add_argument('-db', '--down-blocks', type=int, default=(4,4,4,4,4), nargs='+',
                            help='tiramisu down block specification [Default=(4,4,4,4,4)]')
    nn_options.add_argument('-ub', '--up-blocks', type=int, default=(4,4,4,4,4), nargs='+',
                            help='tiramisu up block specification [Default=(4,4,4,4,4)]')
    nn_options.add_argument('-bnl', '--bottleneck-layers', type=int, default=4,
                            help='tiramisu bottleneck specification [Default=4]')
    nn_options.add_argument('-gr', '--growth-rate', type=int, default=12,
                            help='tiramisu growth rate specification [Default=12]')
    nn_options.add_argument('-ocfc', '--out-chans-first-conv', type=int, default=48,
                            help='tiramisu output channel size of first conv specification [Default=48]')

    aug_options = parser.add_argument_group('Data Augmentation Options')
    aug_options.add_argument('--use-pd', action='store_true', default=False,
                             help='use PD contrast as input [Default=False]')
    aug_options.add_argument('--use-aug', action='store_true', default=False,
                             help='use data augmentation [Default=False]')
    aug_options.add_argument('--use-mixup', action='store_true', default=False,
                             help='use mixup [Default=False]')
    aug_options.add_argument('-ma', '--mixup-alpha', type=float, default=0.4,
                             help='mixup alpha parameter for beta dist. [Default=0.4]')
    aug_options.add_argument('-mm', '--mixup-margin', type=float, default=0.,
                             help='mixup margin for asymmetric mixup (0=>off) [Default=0.]')
    aug_options.add_argument('--mix-to-one', action='store_true', default=False,
                             help='mixup/cutmix a batch of 2 to 1 [Default=False]')
    aug_options.add_argument('--use-cutmix', action='store_true', default=False,
                             help='use cutmix [Default=False]')
    aug_options.add_argument('--use-multitask', action='store_true', default=False,
                             help='use multitask objective [Default=False]')

    post_options = parser.add_argument_group('Post-processing Options')
    post_options.add_argument('-mls', '--min-lesion-size', type=int, default=3,
                              help='in testing, removes lesion smaller in voxels than this value [Default=3]')
    # Help text typo fixed: "preform" -> "perform".
    post_options.add_argument('-bhf', '--binary-hole-fill', action='store_true', default=False,
                              help='in testing, perform binary hole filling')
    post_options.add_argument('--use-morph-acwe', action='store_true', default=False,
                              help='use morphological chan-vese [Default=False]')
    post_options.add_argument('--use-morph-gac', action='store_true', default=False,
                              help='use morphological geodesic active contour [Default=False]')
    return parser
def create_exp_config(args):
    """Regroup the parsed command-line arguments into the experiment config.

    Args:
        args: namespace produced by the parser returned from ``arg_parser``.

    Returns:
        AttributeDict with the sections ``data_params``, ``lightning_params``,
        ``network_params``, ``optim_params`` and ``misc_params``.
    """
    data_params = {
        'train_csv': args.train_csv,
        'valid_csv': args.valid_csv,
        'test_csv': args.test_csv,
        'trained_model_path': args.trained_model_path,
        'batch_size': args.batch_size,
        'valid_batch_size': args.valid_batch_size,
        'num_workers': args.num_workers,
        'patch_size': args.patch_size,
        'valid_patch_size': args.valid_patch_size,
        'queue_length': args.queue_length,
        'samples_per_volume': args.samples_per_volume,
        'softmask': args.softmask,
        'threshold': args.threshold,
        'use_aug': args.use_aug,
        'use_cutmix': args.use_cutmix,
        'use_label_sampler': args.use_label_sampler,
        'use_pd': args.use_pd,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'mixup_margin': args.mixup_margin,
        'mix_to_one': args.mix_to_one,
        'use_multitask': args.use_multitask,
        'use_weight': args.use_weight,
    }
    lightning_params = {
        'init_params': {
            'init_type': args.init_type,
            'gain': args.init_gain,
        },
        'decay_after': args.decay_after,
        'multigpu': args.multigpu,
        'n_epochs': args.n_epochs,
        'network_dim': 3,  # this script always builds the 3D network
        'resume': args.resume,
        'seed': args.seed,
    }
    network_params = {
        'in_channels': args.in_channels,
        'out_channels': args.out_channels,
        'down_blocks': args.down_blocks,
        'up_blocks': args.up_blocks,
        'bottleneck_layers': args.bottleneck_layers,
        'growth_rate': args.growth_rate,
        'out_chans_first_conv': args.out_chans_first_conv,
        'dropout_rate': args.dropout_rate,
        'p_shakedrop': args.p_shakedrop,
    }
    optim_params = {
        'lr': args.learning_rate,
        'betas': args.betas,
        'weight_decay': args.weight_decay,
    }
    misc_params = {
        'head_size': args.head_size,
        'isbi15score_weight': args.isbi15score_weight,
        'loss_function': args.loss_function,
        'loss_weight': args.loss_weight,
        'syn_weight': args.syn_weight,
        'use_rmsprop': args.use_rmsprop,
    }
    return AttributeDict(
        data_params=data_params,
        lightning_params=lightning_params,
        network_params=network_params,
        optim_params=optim_params,
        misc_params=misc_params,
    )
################################# Network ###################################### | |
class ISBIScore(NumpyMetric):
    """Metric wrapper around ``isbi15_score`` for single volumes or batches."""

    def forward(self, y_hat, y):
        """Score a prediction against the ground truth.

        Both inputs are squeezed first.  Two 3D arrays are scored directly
        with ``isbi15_score``; two 4D arrays are treated as a batch along the
        first axis.  A NaN result is replaced by 0.

        Raises:
            ValueError: if the squeezed inputs are not both 3D or both 4D.
        """
        truth, pred = y.squeeze(), y_hat.squeeze()
        ranks = (truth.ndim, pred.ndim)
        if ranks == (3, 3):  # batch size 1
            score = isbi15_score(pred, truth)
        elif ranks == (4, 4):
            score = 0.
            for truth_i, pred_i in zip(truth, pred):
                score += isbi15_score(pred_i, truth_i, reweighted=False)
            spatial_axes = (1, 2, 3)
            pred_vols = pred.sum(axis=spatial_axes)
            true_vols = truth.sum(axis=spatial_axes)
            # NOTE(review): the volume-correlation term is added before the
            # division by batch size, so it is scaled by 1/batch as well --
            # confirm this is intended.
            score += corr(pred_vols, true_vols) / 4
            score /= truth.shape[0]
        else:
            raise ValueError(f'y ndim={truth.ndim}; y_hat ndim={pred.ndim} not valid.')
        return 0. if np.isnan(score) else score
class Head(nn.Sequential):
    """Task head: norm -> LeakyReLU -> padded 3x3x3 conv -> norm -> LeakyReLU -> 1x1x1 conv.

    Args:
        inc: number of input channels (kept through the first convolution).
        outc: number of output channels of the final 1x1x1 convolution.
    """

    def __init__(self, inc, outc):
        super().__init__()
        # Replication padding of 1 keeps the spatial size through the 3x3x3 conv.
        stages = (
            ('bn1', nn.InstanceNorm3d(inc, affine=True)),
            ('act1', nn.LeakyReLU(inplace=True)),
            ('pad', nn.ReplicationPad3d(1)),
            ('conv1', nn.Conv3d(inc, inc, 3, bias=False)),
            ('bn2', nn.InstanceNorm3d(inc, affine=True)),
            ('act2', nn.LeakyReLU(inplace=True)),
            ('conv2', nn.Conv3d(inc, outc, 1)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
class MSLightningTiramisu(LightningTiramisu): | |
def __init__(self,
             hparams:AttributeDict,
             train_subject_list:List[torchio.Subject]=None,
             valid_subject_list:List[torchio.Subject]=None):
    """Set up the Tiramisu module, optional task heads, data lists and metric.

    Args:
        hparams: experiment configuration (sections ``data_params``,
            ``network_params``, ``misc_params``, ... are read here).
        train_subject_list: subjects used by ``train_dataloader``.
        valid_subject_list: subjects used by ``val_dataloader``.
    """
    # With multitask heads enabled, the backbone is built with ``head_size``
    # output channels: ``out_channels`` is overridden only for the duration
    # of the parent constructor and restored right afterwards.
    if self._use_multitask_w_head(hparams):
        out_channels = hparams['network_params']['out_channels']
        hparams['network_params']['out_channels'] = hparams['misc_params']['head_size']
    super().__init__(hparams)
    # This second check reads ``self.hparams`` (no argument passed).
    if self._use_multitask_w_head():
        hparams['network_params']['out_channels'] = out_channels
        # Two heads on top of the shared features, each with one output channel.
        self.syn_head = Head(hparams['misc_params']['head_size'], 1)
        self.seg_head = Head(hparams['misc_params']['head_size'], 1)
    self.train_subject_list = train_subject_list
    self.valid_subject_list = valid_subject_list
    self.isbiscore = ISBIScore('isbi15_score')
def _use_multitask_w_head(self, hparams=None):
    """Tell whether multitask training with separate heads is configured.

    Args:
        hparams: configuration to inspect; ``self.hparams`` when ``None``.

    Returns:
        Truthy when ``use_multitask`` is set and ``head_size`` is positive.
    """
    config = self.hparams if hparams is None else hparams
    multitask = config['data_params']['use_multitask']
    return multitask and config['misc_params']['head_size'] > 0
@property
def _use_pd(self):
    """``data_params.use_pd``: PD contrast is used as an extra input."""
    data_cfg = self.hparams['data_params']
    return data_cfg['use_pd']

@property
def _use_multitask(self):
    """``data_params.use_multitask``: the multitask objective is enabled."""
    data_cfg = self.hparams['data_params']
    return data_cfg['use_multitask']

@property
def _use_weight(self):
    """``data_params.use_weight``: per-subject weights scale the loss."""
    data_cfg = self.hparams['data_params']
    return data_cfg['use_weight']

@property
def _use_mixup(self):
    """``data_params.use_mixup``: mixup is applied to training batches."""
    data_cfg = self.hparams['data_params']
    return data_cfg['use_mixup']

@property
def _mix_to_one(self):
    """Truthy when ``mix_to_one`` is set and the training batch size is 2."""
    data_cfg = self.hparams['data_params']
    return data_cfg['mix_to_one'] and data_cfg['batch_size'] == 2

@property
def _mixup_alpha(self):
    """``data_params.mixup_alpha``: parameter of the mixup Beta distribution."""
    data_cfg = self.hparams['data_params']
    return data_cfg['mixup_alpha']

@property
def _mixup_margin(self):
    """``data_params.mixup_margin``: margin for asymmetric mixup."""
    data_cfg = self.hparams['data_params']
    return data_cfg['mixup_margin']

@property
def _use_cutmix(self):
    """``data_params.use_cutmix``: cutmix is applied to training batches."""
    data_cfg = self.hparams['data_params']
    return data_cfg['use_cutmix']

@property
def _threshold(self):
    """``data_params.threshold``: probability cut-off for the segmentation."""
    data_cfg = self.hparams['data_params']
    return data_cfg['threshold']

@property
def _syn_weight(self):
    """``misc_params.syn_weight``: weight of the synthesis loss term."""
    misc_cfg = self.hparams['misc_params']
    return misc_cfg['syn_weight']

@property
def _softmask(self):
    """``data_params.softmask``: soft masks are used for training."""
    data_cfg = self.hparams['data_params']
    return data_cfg['softmask']

@property
def _isbi15score_weight(self):
    """``misc_params.isbi15score_weight``: weight of the ISBI-15 score term."""
    misc_cfg = self.hparams['misc_params']
    return misc_cfg['isbi15score_weight']
def configure_optimizers(self):
    """Create the optimizer and the linear learning-rate decay schedule.

    AdamW is used unless ``misc_params.use_rmsprop`` is set, in which case
    the two ``betas`` values are passed to RMSprop as ``momentum`` and
    ``alpha``.  The optimizer keyword arguments are built on a copy, so
    ``self.hparams['optim_params']`` is left unmodified (the previous
    in-place ``del ...['betas']`` made a second call raise ``KeyError`` and
    altered the stored hyper-parameters).

    Returns:
        ``([optimizer], [scheduler])``.
    """
    optim_kwargs = dict(self.hparams['optim_params'])
    if self.hparams['misc_params']['use_rmsprop']:
        # 'betas' may already be missing if the hparams were mutated by an
        # earlier version of this method; then momentum/alpha are present.
        betas = optim_kwargs.pop('betas', None)
        if betas is not None:
            optim_kwargs['momentum'] = betas[0]
            optim_kwargs['alpha'] = betas[1]
        optimizer = RMSprop(self.parameters(), **optim_kwargs)
    else:
        optimizer = AdamW(self.parameters(), **optim_kwargs)

    n_epochs = self.hparams['lightning_params']['n_epochs']
    decay_after = self.hparams['lightning_params']['decay_after']

    def lambda_rule(epoch):
        # Multiplier is 1.0 up to ``decay_after`` and then drops by
        # 1 / (n_epochs + 1) for every further epoch.
        return 1.0 - max(0, epoch - decay_after) / float(n_epochs + 1)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    return [optimizer], [scheduler]
def train_dataloader(self):
    """Build the patch-based training loader.

    Subjects are optionally augmented (``OneOf`` of affine / elastic
    deformation with weights 0.8 / 0.2 and ``p=0.75``), patches are drawn
    by a label or uniform sampler through a ``torchio.Queue``, and the
    queue is wrapped in a ``DataLoader`` with the training batch size.
    """
    data_cfg = self.hparams['data_params']

    if data_cfg['use_aug']:
        spatial = OneOf(
            {RandomAffine(): 0.8, RandomElasticDeformation(): 0.2},
            p=0.75,
        )
        subjects_dataset = torchio.SubjectsDataset(
            self.train_subject_list, transform=Compose([spatial]))
    else:
        subjects_dataset = torchio.SubjectsDataset(self.train_subject_list)

    if data_cfg['use_label_sampler']:
        sampler = torchio.data.LabelSampler(data_cfg["patch_size"])
    else:
        sampler = torchio.data.UniformSampler(data_cfg["patch_size"])

    patches_queue = torchio.Queue(
        subjects_dataset,
        data_cfg["queue_length"],
        data_cfg["samples_per_volume"],
        sampler,
        num_workers=data_cfg["num_workers"],
        shuffle_subjects=True,
        shuffle_patches=True,
    )
    return DataLoader(patches_queue, batch_size=data_cfg["batch_size"])
def val_dataloader(self):
    """Build the validation loader.

    Every validation subject is cropped or padded to ``valid_patch_size``
    and served (shuffled) in batches of ``valid_batch_size``.
    """
    data_cfg = self.hparams['data_params']
    crop_or_pad = torchio.CropOrPad(data_cfg['valid_patch_size'])
    subjects_dataset = torchio.SubjectsDataset(
        self.valid_subject_list, transform=crop_or_pad)
    return DataLoader(
        subjects_dataset,
        shuffle=True,
        batch_size=data_cfg['valid_batch_size'],
    )
def _collate_batch(self, batch):
    """Turn a torchio batch into ``(x, y)`` or ``(x, y, w)``.

    ``x`` is the FLAIR data, concatenated with PD along dim 1 when PD is
    enabled; ``y`` is the label data; ``w`` (``batch['weight']``) is only
    included when subject weighting is on.  If the batch has a ``'div'``
    entry, ``x`` is divided by it in place, one value per sample.
    """
    flair = batch['flair'][torchio.DATA]
    target = batch['label'][torchio.DATA]
    pd_img = batch['pd'][torchio.DATA] if self._use_pd else None
    weight = batch['weight'] if self._use_weight else None
    with torch.no_grad():
        x = torch.cat((flair, pd_img), 1) if self._use_pd else flair
        if 'div' in batch:
            # in place: without PD this also rescales the batch's FLAIR tensor
            x /= batch['div'].view(-1, 1, 1, 1, 1)
    if self._use_weight:
        return (x, target, weight)
    return (x, target)
def _collate_multitask_batch(self, batch):
    """Turn a torchio batch into ``(x, y, t1)`` or ``(x, y, t1, w)``.

    Same as ``_collate_batch`` but the T1 data is returned as an extra
    target: ``x`` is FLAIR (plus PD along dim 1 when enabled), ``y`` the
    label, ``t1`` the T1 data and ``w`` the optional subject weight.
    If the batch has a ``'div'`` entry, ``x`` is divided by it in place.
    """
    t1 = batch['t1'][torchio.DATA]
    flair = batch['flair'][torchio.DATA]
    target = batch['label'][torchio.DATA]
    pd_img = batch['pd'][torchio.DATA] if self._use_pd else None
    weight = batch['weight'] if self._use_weight else None
    with torch.no_grad():
        x = torch.cat((flair, pd_img), 1) if self._use_pd else flair
        if 'div' in batch:
            # in place: without PD this also rescales the batch's FLAIR tensor
            x /= batch['div'].view(-1, 1, 1, 1, 1)
    if self._use_weight:
        return (x, target, t1, weight)
    return (x, target, t1)
def _mixup_beta(self, n):
    """Draw ``n`` mixup coefficients from Beta(alpha, alpha).

    Returns:
        Tensor of shape ``(n, 1, 1, 1, 1)`` so that it broadcasts over a
        5D image batch.
    """
    concentration = self._mixup_alpha
    beta_dist = torch.distributions.beta.Beta(concentration, concentration)
    sample_shape = (n, 1, 1, 1, 1)
    return beta_dist.sample(sample_shape)
def _asymmetric_mixup_threshold(self, y, b):
    """Binarise a mixed label map when asymmetric mixup is enabled.

    With a positive margin, a voxel becomes 1 if it is already 1, or if it
    equals ``b`` (resp. ``1 - b``) while ``b`` (resp. ``1 - b``) exceeds
    the margin; everything else becomes 0.  ``y`` is modified in place.
    With a non-positive margin ``y`` is returned untouched.
    """
    margin = self._mixup_margin
    if not margin > 0.:
        return y
    keep_orig = b > margin
    keep_perm = (1. - b) > margin
    from_perm = (y == (1. - b)) & keep_perm
    from_orig = (y == b) & keep_orig
    lesion = (y == 1.) | (from_perm | from_orig)
    y *= 0.
    y[lesion] = 1.
    return y
def mix(self, x, y, w=None):
    """Apply the configured batch-mixing augmentations (mixup, cutmix).

    Args:
        x: input image batch.
        y: label batch (cast to float when mixup is on).
        w: optional per-sample weights; mixed with the same coefficients.

    Returns:
        ``(x, y)`` when ``w`` is ``None``, otherwise ``(x, y, w)``.
    """
    with torch.no_grad():
        if self._use_mixup:
            n = x.size(0)
            rp = torch.randperm(n)
            b = self._mixup_beta(n).to(x.device)
            x_perm = x[rp].clone()
            x = b * x + (1 - b) * x_perm
            y = y.float()
            y_perm = y[rp].clone()
            y = b * y + (1 - b) * y_perm
            y = self._asymmetric_mixup_threshold(y, b)
            if w is not None:
                # b is (n,1,1,1,1); give it w's rank so the mix stays
                # per-sample instead of broadcasting to (n,1,1,1,n).
                # assumes w has one entry per sample along dim 0 -- TODO confirm
                b_w = b.reshape([n] + [1] * max(w.dim() - 1, 0))
                w_perm = w[rp].clone()
                w = b_w * w + (1 - b_w) * w_perm
        if self._use_cutmix:
            x, y = cutmix3d((x, y))
        if self._mix_to_one:
            x, y = x[0:1], y[0:1]
            if w is not None:
                # keep the weight aligned with the single remaining sample
                w = w[0:1]
    return (x, y) if w is None else (x, y, w)
def multitask_mix(self, x, seg, syn, w=None):
    """Apply mixup jointly to the input, segmentation and synthesis targets.

    Args:
        x: input image batch.
        seg: segmentation label batch (cast to float when mixup is on).
        syn: synthesis target batch.
        w: optional per-sample weights; mixed with the same coefficients.

    Returns:
        ``(x, seg, syn)`` when ``w`` is ``None``, otherwise
        ``(x, seg, syn, w)``.
    """
    with torch.no_grad():
        if self._use_mixup:
            n = x.size(0)
            rp = torch.randperm(n)
            b = self._mixup_beta(n).to(x.device)
            x_perm = x[rp].clone()
            x = b * x + (1 - b) * x_perm
            seg = seg.float()
            seg_perm = seg[rp].clone()
            seg = b * seg + (1 - b) * seg_perm
            syn_perm = syn[rp].clone()
            syn = b * syn + (1 - b) * syn_perm
            seg = self._asymmetric_mixup_threshold(seg, b)
            if w is not None:
                # b is (n,1,1,1,1); give it w's rank so the mix stays
                # per-sample instead of broadcasting to (n,1,1,1,n).
                # assumes w has one entry per sample along dim 0 -- TODO confirm
                b_w = b.reshape([n] + [1] * max(w.dim() - 1, 0))
                w_perm = w[rp].clone()
                w = b_w * w + (1 - b_w) * w_perm
        if self._mix_to_one:
            x, seg, syn = x[0:1], seg[0:1], syn[0:1]
            if w is not None:
                # keep the weight aligned with the single remaining sample
                w = w[0:1]
    out = (x, seg, syn) if w is None else (x, seg, syn, w)
    return out
def _training_step(self, batch):
    """Single-task training step: collate, mix, forward and compute the loss.

    With subject weighting the element-wise loss is averaged over all
    non-batch dims and multiplied by the weights before the final mean.

    Returns:
        dict with the scalar ``'loss'`` and a ``'log'`` dict (``train_loss``).
    """
    mixed = self.mix(*self._collate_batch(batch))
    if self._use_weight:
        x, y, w = mixed
    else:
        x, y = mixed
    y_hat = self(x)
    loss = self.train_criterion(y_hat, y, reduction='none')
    if self._use_weight:
        non_batch_dims = tuple(range(1 - x.dim(), 0))
        loss = loss.mean(dim=non_batch_dims)
        loss *= w
    loss = loss.mean()
    return {'loss': loss, 'log': {'train_loss': loss}}
def _multitask_training_step(self, batch):
    """Multitask training step: segmentation loss plus weighted synthesis loss.

    The prediction pair comes either from the two heads applied to shared
    features, or from splitting the network output in two along dim 1.

    Returns:
        dict with the scalar ``'loss'`` and a ``'log'`` dict holding
        ``train_loss``, ``train_seg_loss`` and ``train_syn_loss``.
    """
    mixed = self.multitask_mix(*self._collate_multitask_batch(batch))
    if self._use_weight:
        x, y, t1, w = mixed
    else:
        x, y, t1 = mixed

    if self._use_multitask_w_head():
        features = self(x)
        y_hat = self.seg_head(features)
        t1_hat = self.syn_head(features)
    else:
        y_hat, t1_hat = torch.chunk(self(x), 2, dim=1)

    seg_loss = self.train_criterion(y_hat, y, reduction='none')
    syn_loss = self.train_syn_criterion(t1_hat, t1, reduction='none')
    loss = seg_loss + self._syn_weight * syn_loss
    if self._use_weight:
        non_batch_dims = tuple(range(1 - x.dim(), 0))
        loss = loss.mean(dim=non_batch_dims)
        loss *= w
    loss = loss.mean()

    logs = {
        'train_loss': loss,
        'train_seg_loss': seg_loss.mean(),
        'train_syn_loss': syn_loss.mean(),
    }
    return {'loss': loss, 'log': logs}
def training_step(self, batch, batch_idx):
    """Dispatch to the multitask or single-task training step.

    ``batch_idx`` is accepted for the Lightning interface but not used.
    """
    step = self._multitask_training_step if self._use_multitask else self._training_step
    return step(batch)
def _validation_step(self, batch):
    """Single-task validation: loss and ISBI-15 score for one batch.

    Returns:
        ``(out_dict, (x, y, y_hat))`` where ``out_dict`` holds ``val_loss``
        and ``val_isbi15_score``, and ``y_hat`` are the raw logits.
    """
    collated = self._collate_batch(batch)
    if self._use_weight:
        x, y, w = collated
    else:
        x, y = collated
    y_hat = self(x)
    loss = self.valid_criterion(y_hat, y, reduction='none')
    if self._use_weight:
        non_batch_dims = tuple(range(1 - x.dim(), 0))
        loss = loss.mean(dim=non_batch_dims)
        loss *= w
    loss = loss.mean()
    with torch.no_grad():
        # the score is computed on the thresholded probabilities
        binary_pred = torch.sigmoid(y_hat) > self._threshold
        score = self.isbiscore(binary_pred, y)
    score = score.to(loss.device)
    out_dict = {'val_loss': loss, 'val_isbi15_score': score}
    return out_dict, (x, y, y_hat)
def _multitask_validation_step(self, batch):
    """Multitask validation: losses and ISBI-15 score for one batch.

    Returns:
        ``(out_dict, (x, y, t1, y_hat, t1_hat))`` where ``out_dict`` holds
        ``val_loss``, ``val_isbi15_score``, ``val_seg_loss`` and
        ``val_syn_loss``.
    """
    collated = self._collate_multitask_batch(batch)
    if self._use_weight:
        x, y, t1, w = collated
    else:
        x, y, t1 = collated

    if self._use_multitask_w_head():
        features = self(x)
        y_hat = self.seg_head(features)
        t1_hat = self.syn_head(features)
    else:
        y_hat, t1_hat = torch.chunk(self(x), 2, dim=1)

    seg_loss = self.valid_criterion(y_hat, y, reduction='none')
    syn_loss = self.valid_syn_criterion(t1_hat, t1, reduction='none')
    loss = seg_loss + self._syn_weight * syn_loss
    if self._use_weight:
        non_batch_dims = tuple(range(1 - x.dim(), 0))
        loss = loss.mean(dim=non_batch_dims)
        loss *= w
    loss = loss.mean()
    seg_loss = seg_loss.mean()
    syn_loss = syn_loss.mean()

    with torch.no_grad():
        # the score is computed on the thresholded probabilities
        binary_pred = torch.sigmoid(y_hat) > self._threshold
        score = self.isbiscore(binary_pred, y)
    score = score.to(loss.device)

    out_dict = {
        'val_loss': loss,
        'val_isbi15_score': score,
        'val_seg_loss': seg_loss,
        'val_syn_loss': syn_loss,
    }
    return out_dict, (x, y, t1, y_hat, t1_hat)
def validation_step(self, batch, batch_idx):
    """Run one validation step and log example images for the first batch.

    For ``batch_idx == 0`` the middle slice (last axis) of the inputs,
    targets and predictions is sent to the logger.  The input ``x`` is laid
    out as built by the collate methods: FLAIR in channel 0 and, when PD is
    enabled, PD in channel 1.  The single-task branch previously read
    channels 0/2/3 as t1/flair/pd, which do not exist in that layout (the
    empty slice made ``normalize`` fail); it now logs FLAIR from channel 0
    and PD from channel 1, and no longer logs a 't1' image because no T1 is
    part of the single-task batch.

    Returns:
        The metrics dict produced by the (multitask) validation step.
    """
    multitask = self._use_multitask
    if multitask:
        out_dict, imgs = self._multitask_validation_step(batch)
        x, y, t1, y_hat, t1_hat = imgs
    else:
        out_dict, imgs = self._validation_step(batch)
        x, y, y_hat = imgs

    if batch_idx != 0:
        return out_dict

    with torch.no_grad():
        mid_slice = x.shape[-1] // 2
        fl_ = normalize(x[:, 0:1, :, :, mid_slice])
        if self._softmask:
            y_hat_ = normalize(torch.sigmoid(y_hat[..., mid_slice]))
            y_ = normalize(y[..., mid_slice])
        else:
            y_hat_ = torch.sigmoid(y_hat[..., mid_slice]) > self._threshold
            y_ = y[..., mid_slice] > 0.

        # (tag, image) pairs in logging order; the tags differ per mode
        if multitask:
            images = [
                ('t1', normalize(t1[..., mid_slice])),
                ('flair', fl_),
                ('y', y_),
                ('t1_hat', normalize(t1_hat[..., mid_slice])),
                ('y_hat', y_hat_),
            ]
        else:
            images = [
                ('flair', fl_),
                ('pred', y_hat_),
                ('truth', y_),
            ]
        if self._use_pd:
            images.append(('pd', normalize(x[:, 1:2, :, :, mid_slice])))

        n = self.current_epoch
        for tag, image in images:
            self.logger.experiment.add_images(tag, image, n, dataformats='NCHW')
    return out_dict
@staticmethod
def _cat(x):
    """Concatenate a list of tensors, falling back to ``torch.tensor(x)``.

    The fallback is used when ``torch.cat`` raises ``RuntimeError``.
    """
    try:
        return torch.cat(x)
    except RuntimeError:
        return torch.tensor(x)
def validation_epoch_end(self, outputs):
    """Average the per-batch validation metrics over the epoch.

    Besides the plain averages, ``avg_val_isbi15_score_minus_loss`` is
    logged: the weighted ISBI-15 score minus the segmentation loss
    (multitask) or minus the total loss (single task).

    Returns:
        dict with ``'val_loss'`` and the ``'log'`` dict.
    """
    def epoch_mean(key):
        return self._cat([step_out[key] for step_out in outputs]).mean()

    avg_loss = epoch_mean('val_loss')
    avg_isbi15_score = epoch_mean('val_isbi15_score')
    tensorboard_logs = {
        'avg_val_loss': avg_loss,
        'avg_val_isbi15_score': avg_isbi15_score,
    }
    if self._use_multitask:
        avg_seg_loss = epoch_mean('val_seg_loss')
        avg_syn_loss = epoch_mean('val_syn_loss')
        tensorboard_logs['avg_val_isbi15_score_minus_loss'] = (
            self._isbi15score_weight * avg_isbi15_score - avg_seg_loss)
        tensorboard_logs['avg_val_seg_loss'] = avg_seg_loss
        tensorboard_logs['avg_val_syn_loss'] = avg_syn_loss
    else:
        tensorboard_logs['avg_val_isbi15_score_minus_loss'] = (
            self._isbi15score_weight * avg_isbi15_score - avg_loss)
    return {'val_loss': avg_loss, 'log': tensorboard_logs}
def process_whole_img(self, sample):
    """Segment a whole subject in one forward pass.

    The FLAIR (and optional PD) volumes are loaded from the paths stored in
    ``sample``, cropped to the bounding box of the non-zero FLAIR voxels
    (``bbox3D(fl > 0.)``), run through the network in eval mode, and the
    sigmoid probabilities are written back into a zero volume of the
    original FLAIR shape.

    Returns:
        CPU tensor of lesion probabilities with the FLAIR volume's shape.
    """
    flair = nib.load(str(sample['flair'].path)).get_fdata(dtype=np.float32)
    channels = [flair]
    if self._use_pd:
        pd_vol = nib.load(str(sample['pd'].path)).get_fdata(dtype=np.float32)
        channels.append(pd_vol)
    out = torch.from_numpy(np.zeros_like(flair))

    h1, h2, w1, w2, d1, d2 = bbox3D(flair > 0.)
    cropped = [vol[h1:h2, w1:w2, d1:d2] for vol in channels]
    # stack channels, then add the leading batch axis
    x = torch.from_numpy(np.stack(cropped)[None, ...]).to(self.device)

    self.eval()
    with torch.no_grad():
        if not self._use_multitask:
            logits = self(x)
        elif self._use_multitask_w_head():
            logits = self.seg_head(self(x))
        else:
            logits, _ = torch.chunk(self(x), 2, dim=1)
        probits = torch.sigmoid(logits)
    out[h1:h2, w1:w2, d1:d2] = probits.detach().cpu()
    torch.cuda.empty_cache()
    return out
def process_img_patches(self, sample, patch_overlap=None):
    """Segment a subject patch by patch and aggregate the probabilities.

    Args:
        sample: torchio subject to segment.
        patch_overlap: overlap between neighbouring patches; defaults to
            half the configured patch size along every axis.

    Returns:
        CPU tensor with the aggregated lesion probabilities.

    Fixes over the previous version: the default overlap is computed per
    axis (``patch_size`` is a sequence, so ``patch_size // 2`` raised
    ``TypeError``), and the PD channel is taken from the current patch
    batch instead of loading the whole PD volume from disk as a numpy
    array, which could not be concatenated with the FLAIR patches.
    """
    patch_size = self.hparams['data_params']['patch_size']
    batch_size = self.hparams['data_params']['batch_size']
    if patch_overlap is None:
        if isinstance(patch_size, int):
            patch_overlap = patch_size // 2
        else:
            patch_overlap = tuple(size // 2 for size in patch_size)
    grid_sampler = torchio.inference.GridSampler(
        sample,
        patch_size,
        patch_overlap,
        padding_mode='replicate'
    )
    patch_loader = torch.utils.data.DataLoader(
        grid_sampler,
        batch_size=batch_size)
    aggregator = torchio.inference.GridAggregator(grid_sampler)
    self.eval()
    with torch.no_grad():
        for patches_batch in patch_loader:
            xs = [patches_batch['flair'][torchio.DATA]]
            if self._use_pd:
                # same channel order as in training: FLAIR first, then PD
                xs.append(patches_batch['pd'][torchio.DATA])
            x = torch.cat(xs, 1).to(self.device)
            locations = patches_batch[torchio.LOCATION]
            if not self._use_multitask:
                logits = self(x)
            elif self._use_multitask_w_head():
                logits = self.seg_head(self(x))
            else:
                logits, _ = torch.chunk(self(x), 2, dim=1)
            probits = torch.sigmoid(logits)
            aggregator.add_batch(probits, locations)
    out = aggregator.get_output_tensor().detach().cpu()
    torch.cuda.empty_cache()
    return out
############################# Helper functions ################################# | |
def normalize(x):
    """Rescale every sample of a batch to the range [0, 1].

    The minimum and maximum are taken over all dims except dim 0, so each
    entry along dim 0 is normalised independently.  Samples whose values
    are all equal map to 0 instead of NaN (the previous version divided by
    a zero range).

    Args:
        x: tensor whose dim 0 indexes the samples.

    Returns:
        Tensor of the same shape as ``x``.
    """
    xmin, xmax = x, x
    # reduce one trailing dim at a time, keeping dims for broadcasting
    for axis in range(1, x.dim()):
        xmin = xmin.min(dim=axis, keepdim=True)[0]
        xmax = xmax.max(dim=axis, keepdim=True)[0]
    span = xmax - xmin
    # guard against 0/0 for constant samples
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - xmin) / span
def to_np(x):
    """Convert a tensor to a numpy array (detached and moved to the CPU)."""
    detached = x.detach()
    return detached.cpu().numpy()
def split_filename(filepath):
    """Split a path into ``(directory, base name, extension)``.

    A trailing ``.gz`` is merged with the extension before it, so
    ``img.nii.gz`` yields the extension ``.nii.gz``.
    """
    directory, filename = os.path.split(filepath)
    stem, suffix = os.path.splitext(filename)
    if suffix == '.gz':
        stem, inner_suffix = os.path.splitext(stem)
        suffix = inner_suffix + suffix
    return directory, stem, suffix
def clean_seg(x, bhf=False, mls=4):
    """Post-process a binary segmentation.

    Args:
        x: binary segmentation volume.
        bhf: fill holes using a fully connected 3D structuring element.
        mls: passed as ``min_size`` to ``remove_small_objects``; the removal
            step is skipped when ``mls`` is 0 or negative.

    Returns:
        The cleaned segmentation.
    """
    if bhf:
        full_connectivity = generate_binary_structure(3, 3)
        x = binary_fill_holes(x, structure=full_connectivity)
    if mls <= 0:
        return x
    return remove_small_objects(x, min_size=mls, connectivity=3)
def l1_segmentation_loss(x, y, reduction='mean'):
    """L1 loss between the sigmoid of the logits ``x`` and the target ``y``."""
    probs = torch.sigmoid(x)
    return F.l1_loss(probs, y, reduction=reduction)
def mse_segmentation_loss(x, y, reduction='mean'):
    """Mean-squared-error loss between ``sigmoid(x)`` and the target ``y``.

    ``x`` is passed through a sigmoid first, so it is expected to hold raw
    logits; ``reduction`` is forwarded unchanged to ``F.mse_loss``.
    """
    probabilities = torch.sigmoid(x)
    loss = F.mse_loss(probabilities, y, reduction=reduction)
    return loss
def bbox3D(img, offset=5):
    """Bounding box of the non-zero region of a 3-D array, padded by ``offset``.

    Returns ``(rmin, rmax, cmin, cmax, zmin, zmax)``: for each axis the first
    and last occupied index, moved outwards by ``offset`` and clipped to
    ``0`` below and to the axis length above.

    Raises:
        IndexError: if ``img`` has no non-zero element.
    """
    def _extent(reduce_axes):
        # Indices along the remaining axis that hold any non-zero value.
        occupied = np.where(np.any(img, axis=reduce_axes))[0]
        first, last = occupied[[0, -1]]
        return first, last

    row_lo, row_hi = _extent((1, 2))
    col_lo, col_hi = _extent((0, 2))
    dep_lo, dep_hi = _extent((0, 1))
    n_rows, n_cols, n_deps = img.shape
    return (
        max(row_lo - offset, 0), min(row_hi + offset, n_rows),
        max(col_lo - offset, 0), min(col_hi + offset, n_cols),
        max(dep_lo - offset, 0), min(dep_hi + offset, n_deps),
    )
################################### Main ####################################### | |
def _save_nifti(data, ref_nii, out_fn):
    """Write ``data`` to ``out_fn`` as a NIfTI image reusing the affine and header of ``ref_nii``."""
    nib.Nifti1Image(data, ref_nii.affine, ref_nii.header).to_filename(out_fn)


def main():
    """Command-line entry point: optionally train, then optionally segment test images.

    Training runs when both ``--train-csv`` and ``--valid-csv`` are given.
    Inference runs when ``--test-csv`` is given, using the freshly trained
    model or, if none was trained, the checkpoint at
    ``args.trained_model_path``. For every test subject a probability map
    (``<name>_prob.nii.gz``) and a thresholded, cleaned segmentation
    (``<name>_seg.nii.gz``) are written to ``args.out_path``; the optional
    active-contour refinements add ``_seg_gac`` / ``_seg_acwe`` files.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        ValueError: for an unsupported loss function, a missing checkpoint
            path at inference time, or an output whose shape differs from
            the input FLAIR image.
    """
    args = arg_parser().parse_args()

    if args.verbosity == 1:
        level = logging.INFO
    elif args.verbosity >= 2:
        level = logging.DEBUG
    else:
        level = logging.WARNING
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)

    seed_everything(args.seed)
    exp_config = create_exp_config(args)

    model = None
    if args.train_csv is not None and args.valid_csv is not None:
        train_subject_list = csv_to_subjectlist(args.train_csv)
        valid_subject_list = csv_to_subjectlist(args.valid_csv)
        model = MSLightningTiramisu(
            exp_config,
            train_subject_list,
            valid_subject_list)

        # Attach the segmentation loss chosen on the command line.
        if args.loss_function == 'combo':
            model.seg_loss = binary_combo_loss
            model.train_criterion = partial(model.seg_loss, weight=args.loss_weight)
            model.valid_criterion = partial(model.seg_loss, weight=args.loss_weight)
        elif args.loss_function == 'mse':
            model.seg_loss = mse_segmentation_loss
            model.train_criterion = model.seg_loss
            model.valid_criterion = model.seg_loss
        elif args.loss_function == 'l1':
            model.seg_loss = l1_segmentation_loss
            model.train_criterion = model.seg_loss
            model.valid_criterion = model.seg_loss
        else:
            raise ValueError(f'{args.loss_function} not supported')

        if args.use_multitask:
            # The auxiliary synthesis output is trained with a plain L1 loss.
            model.syn_loss = F.l1_loss
            model.train_syn_criterion = model.syn_loss
            model.valid_syn_criterion = model.syn_loss

        logger.info(model)

        # NOTE(review): the single-GPU case pins device index 1 while inference
        # below uses cuda:0 -- looks machine-specific; confirm before reuse.
        gpu_kwargs = dict(gpus=2, distributed_backend='dp') if args.multigpu else \
            dict(gpus=[1])
        checkpoint_callback = ModelCheckpoint(
            monitor='avg_val_isbi15_score_minus_loss',
            save_top_k=1,
            save_last=True,
            mode='max'
        )
        # Silence warnings emitted while the Trainer is constructed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            trainer = Trainer(
                benchmark=True,
                check_val_every_n_epoch=1,
                accumulate_grad_batches=1,
                min_epochs=args.n_epochs,
                max_epochs=args.n_epochs,
                checkpoint_callback=checkpoint_callback,
                resume_from_checkpoint=args.resume,
                fast_dev_run=False,
                **gpu_kwargs
            )
        torch.cuda.empty_cache()
        trainer.fit(model)

    if args.test_csv is not None:
        # Prefer the first GPU, but stay usable on CPU-only machines.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if model is None:
            if args.trained_model_path is None:
                raise ValueError(
                    'A trained model path is required for testing when no '
                    'model was trained in this run')
            logger.info(f"Loading: {args.trained_model_path}")
            model = MSLightningTiramisu.load_from_checkpoint(
                args.trained_model_path
            )
        model.to(device)
        torch.cuda.empty_cache()

        for test_subj in csv_to_subjectlist(args.test_csv):
            name = test_subj.name
            logger.info(f'Processing: {name}')
            # A patch size of [0, 0, 0] selects whole-image inference.
            if args.patch_size == [0,0,0]:
                output = model.process_whole_img(test_subj)
            else:
                output = model.process_img_patches(test_subj)
            prob_data = output.numpy().squeeze()

            # The input image supplies the affine/header for every output.
            # Its shape comes from the header, so the voxel data is not
            # loaded just to validate the prediction shape.
            in_nii = nib.load(str(test_subj['flair'].path))
            if in_nii.shape != prob_data.shape:
                raise ValueError(f"In: {in_nii.shape} != Out: {prob_data.shape}")

            seg_data = clean_seg(
                prob_data > args.threshold,
                bhf=args.binary_hole_fill,
                mls=args.min_lesion_size).astype(np.float32)

            _save_nifti(prob_data, in_nii, join(args.out_path, name + '_prob.nii.gz'))
            _save_nifti(seg_data, in_nii, join(args.out_path, name + '_seg.nii.gz'))

            if args.use_morph_gac or args.use_morph_acwe:
                # Both refinements start from the same FLAIR intensities;
                # read them once.
                fl = in_nii.get_fdata(dtype=np.float32)
            if args.use_morph_gac:
                logger.info('Starting morphological geodesic active contour')
                flg = inverse_gaussian_gradient(fl)
                seg_gac = morphological_geodesic_active_contour(
                    flg, 100, init_level_set=seg_data)
                _save_nifti(seg_gac, in_nii, join(args.out_path, name + '_seg_gac.nii.gz'))
            if args.use_morph_acwe:
                logger.info('Starting morphological Chan-Vese')
                seg_acwe = morphological_chan_vese(
                    fl, 100, init_level_set=seg_data)
                _save_nifti(seg_acwe, in_nii, join(args.out_path, name + '_seg_acwe.nii.gz'))
    return 0
# Script entry point: run main() and hand its return value to sys.exit()
# so it becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment