# GitHub Gist by @vikashg (gist ID: vikashg/0f2131aec1b18eef20073a8ec59df894)
# Created September 19, 2023 17:12
import logging
import os
from monai.metrics import DiceMetric, GeneralizedDiceScore
from monai.losses import GeneralizedDiceFocalLoss
import json
import sys
from monai.visualize import plot_2d_or_3d_image
import tempfile
from glob import glob
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch
from monai.data import list_data_collate, decollate_batch, DataLoader
from monai.inferers import SimpleInferer
from monai.transforms import (
AsDiscrete, Activations, Transpose, Resized, EnsureChannelFirstd,
Compose, CropForegroundd, LoadImaged, Orientationd, RandFlipd, RandCropByPosNegLabeld,
RandShiftIntensityd, ScaleIntensityRanged, Spacingd, RandRotate90d, AddChanneld )
import monai
# Segmentation network: SegResNet over 3 spatial dimensions, taking a
# single-channel input and producing a single-channel output.
model = monai.networks.nets.SegResNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    init_filters=16,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    dropout_prob=0.2,
)
# Dictionary-based preprocessing applied to every validation sample; each
# sample is a dict with an "image" entry and a "mask" entry.
# NOTE(review): the Rand* entries are stochastic augmentations -- confirm that
# they are really intended in a validation pipeline.
validation_transforms = Compose(
    [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        # Bring image and mask to a fixed spatial size of 256 x 256 x 24.
        Resized(keys=["image", "mask"], spatial_size=(256, 256, 24)),
        # Map image intensities from [20, 1200] to [0, 1], clipping outliers.
        ScaleIntensityRanged(
            keys="image", a_min=20, a_max=1200, b_min=0, b_max=1, clip=True
        ),
        # One flip transform per spatial axis (0, 1, 2), in that order.
        *(
            RandFlipd(keys=["image", "mask"], spatial_axis=axis)
            for axis in range(3)
        ),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ]
)
# Metric object: generalized Dice score, constructed with
# include_background=False and the "mean_batch" reduction.
dice_metric = GeneralizedDiceScore(include_background=False, reduction="mean_batch")

# Toy data list: the same image/mask pair referenced ten times.
_file = {"image": "./image.nii.gz", "mask": "./segmentation.nii.gz"}
val_list = [_file for _ in range(10)]

batch_size = 2
val_ds = monai.data.Dataset(data=val_list, transform=validation_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Post-processing for raw network output: sigmoid, then threshold at 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# Move the network to the chosen device and switch it to evaluation mode.
model.to(device).eval()
# Validation pass: run the network on every batch, post-process the output and
# feed prediction/mask pairs to the cumulative Dice metric.
# Gradients are not needed for evaluation, so the loop runs under no_grad to
# avoid building an autograd graph for each forward pass.
with torch.no_grad():
    for _val in val_loader:
        val_image = _val['image'].to(device)
        val_mask = _val['mask'].to(device)
        _val_output = model(val_image)
        val_out = post_trans(_val_output)
        print(val_out.shape)
        # Calling the metric stores this batch's result in its internal buffer.
        dice_metric(y_pred=val_out, y=val_mask)

# Reduce the buffered per-batch results to the final score, then clear the
# buffer so a later validation pass starts from a clean state.
val_dice = dice_metric.aggregate()
dice_metric.reset()
print(val_dice)
# (GitHub page footer: sign up or sign in on GitHub to comment on this gist.)