A gist for debugging GeneralizedDiceScore
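The script below follows the usual MONAI training loop, but tracks validation quality with GeneralizedDiceScore instead of DiceMetric. As a minimal sketch of the accumulate / aggregate / reset pattern the validation loop exercises (toy tensors with arbitrary shapes, not part of the training script):

import torch
from monai.metrics import GeneralizedDiceScore

# Toy binarized prediction / label volumes: (batch=2, channel=1, 8, 8, 8).
pred = torch.randint(0, 2, (2, 1, 8, 8, 8)).float()
gt = torch.randint(0, 2, (2, 1, 8, 8, 8)).float()

gds = GeneralizedDiceScore(include_background=True, reduction="mean_batch")
gds(y_pred=pred, y=gt)   # accumulate one validation iteration
print(gds.aggregate())   # reduced score(s)
gds.reset()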
import logging
import os
from monai.metrics import DiceMetric, GeneralizedDiceScore
from monai.losses import GeneralizedDiceFocalLoss
import json
import sys
from monai.visualize import plot_2d_or_3d_image
import tempfile
from glob import glob
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch
from monai.data import list_data_collate, decollate_batch, DataLoader
from monai.inferers import SimpleInferer
from monai.transforms import (
AsDiscrete, Activations, Transpose, Resized, EnsureChannelFirstd,
Compose, CropForegroundd, LoadImaged, Orientationd, RandFlipd, RandCropByPosNegLabeld,
RandShiftIntensityd, ScaleIntensityRanged, Spacingd, RandRotate90d )
import monai
from model_def import ModelDefinition
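# model_def.py is not included in this gist; only ModelDefinition(model_name).get_model()
# can be inferred from its use below. A hypothetical minimal version (an assumption,
# not the author's actual file) might look like:
#
#     import monai
#
#     class ModelDefinition:
#         def __init__(self, name):
#             self.name = name
#
#         def get_model(self):
#             if self.name == "SegResNet":
#                 return monai.networks.nets.SegResNet(
#                     spatial_dims=3, in_channels=1, out_channels=1)
#             raise ValueError("unknown model: " + self.name)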
def main():
    model_name = "SegResNet"
    filename = './datalist.json'
    out_dir = './model_generalized/' + model_name
    batch_size = 2
    num_epochs = 2500
    val_interval = 1

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Read the train/validation file lists
    with open(filename, 'r') as fid:
        data_dict = json.load(fid)
    train_files = data_dict["train"]
    val_files = data_dict["valid"]
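    # Assumed layout of datalist.json (inferred from the keys used above and in
    # LoadImaged below; the file names are hypothetical):
    # {
    #   "train": [{"image": "img_001.nii.gz", "mask": "msk_001.nii.gz"}, ...],
    #   "valid": [{"image": "img_101.nii.gz", "mask": "msk_101.nii.gz"}, ...]
    # }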
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    train_transforms = Compose([
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        Resized(keys=["image", "mask"], spatial_size=(256, 256, 24)),
        ScaleIntensityRanged(keys="image", a_min=20, a_max=1200, b_min=0, b_max=1, clip=True),
        RandFlipd(keys=["image", "mask"], spatial_axis=0),
        RandFlipd(keys=["image", "mask"], spatial_axis=1),
        RandFlipd(keys=["image", "mask"], spatial_axis=2),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ])

    validation_transforms = Compose([
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        ScaleIntensityRanged(keys="image", a_min=20, a_max=1200, b_min=0, b_max=1, clip=True),
        Resized(keys=["image", "mask"], spatial_size=(256, 256, 24)),
    ])
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, collate_fn=list_data_collate)
    val_ds = monai.data.Dataset(data=val_files, transform=validation_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4,
                            collate_fn=list_data_collate)

    model_def = ModelDefinition(model_name)
    model = model_def.get_model()
    model.to(device)
    model.train()
    epoch_loss = 0
    step = 0
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=out_dir)

    # dice_metric = DiceMetric(include_background=True, reduction="mean")  # immediately overridden below
    loss_fn = GeneralizedDiceFocalLoss(include_background=False, sigmoid=True)
    dice_metric = GeneralizedDiceScore(include_background=True, reduction="mean_batch")
    for epoch in range(num_epochs):
        print("-" * 10)
        print("epoch {}/{}".format(epoch + 1, num_epochs))
        model.train()
        epoch_loss = 0
        step = 0

        for batch in tqdm(train_loader):
            step += 1
            x, y = batch["image"].to(device), batch["mask"].to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_image = val_data["image"].to(device)
                    val_labels = val_data["mask"].to(device)
                    _val_outputs = model(val_image)
                    val_outputs = post_trans(_val_outputs)
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)

                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(out_dir, 'best_metric_model.pth'))

                print("current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)

                # Log the last validation batch; distinct tags keep image, label and
                # prediction separate in TensorBoard.
                plot_2d_or_3d_image(Transpose((0, 1, 4, 3, 2))(val_image), epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(Transpose((0, 1, 4, 3, 2))(val_labels), epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(Transpose((0, 1, 4, 3, 2))(val_outputs), epoch + 1, writer, index=0, tag="output")

if __name__ == '__main__':
    main()
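Assuming this file is saved as train.py next to model_def.py and datalist.json (names as used above), it can be run directly with: python train.py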