Skip to content

Instantly share code, notes, and snippets.

@jphdotam
Last active July 9, 2024 10:10
Show Gist options
  • Save jphdotam/b4b6b582cf75f7c873b1891627d45eb3 to your computer and use it in GitHub Desktop.
Save jphdotam/b4b6b582cf75f7c873b1891627d45eb3 to your computer and use it in GitHub Desktop.
import os
import math
from glob import glob
import pydicom
import numpy as np
import onnxruntime
from loguru import logger
from matplotlib import pyplot as plt
from skimage.transform import resize
# Windows folder holding the three exported DICOM series for this study.
_SERIES_ROOT = r"C:\Users\James\Desktop\flow_maps_perfusion_1_"

# One folder per input series (mag, magn, phase); each is globbed for *.dcm below.
MAG_DIR = _SERIES_ROOT + r"\series0086-Body"
MAGN_DIR = _SERIES_ROOT + r"\series0087-Body"
PHASE_DIR = _SERIES_ROOT + r"\series0088-Body"

# Exported ONNX segmentation model; a relative path, so resolved against the
# current working directory.
ONNX_MODEL_PATH = "../deploy/models/024.yaml__1lujhsy3__epoch=39__loss_val=0.004__iou_val=0.879.ckpt_cuda.onnx"
def pad_video_to_square(video, n_channels=None):
    """Zero-pad every frame of ``video`` to a square, keeping the content centred.

    Only the shorter spatial axis is padded, up to the length of the longer
    one. The fill is zeros of ``video.dtype``. ``depad_video_from_square``
    uses the same offset arithmetic to undo this.

    Args:
        video: array of shape ``(n_frames, H, W)`` or ``(n_frames, H, W, C)``.
            Any trailing axes after ``W`` are carried through unchanged.
        n_channels: channel count for 4-D input. Optional: when ``None`` it is
            taken from ``video.shape[3]``.

    Returns:
        Array of shape ``(n_frames, D, D, ...)`` with ``D = max(H, W)``.
        If ``video`` is already square it is returned as-is (not copied).
    """
    n_frames, h_orig, w_orig = video.shape[:3]
    if h_orig == w_orig:
        return video
    new_dim = max(h_orig, w_orig)
    # Trailing (channel) axes: honour an explicit n_channels for 4-D input,
    # otherwise mirror whatever the input has (empty tuple for 3-D video).
    if video.ndim == 4 and n_channels is not None:
        trailing = (n_channels,)
    else:
        trailing = video.shape[3:]
    new_video = np.zeros((n_frames, new_dim, new_dim, *trailing), dtype=video.dtype)
    # Offset only the shorter (padded) axis so the original content sits in the middle.
    h_from = 0 if h_orig > w_orig else (new_dim // 2 - h_orig // 2)
    w_from = 0 if h_orig < w_orig else (new_dim // 2 - w_orig // 2)
    new_video[:, h_from:h_from + h_orig, w_from:w_from + w_orig] = video
    return new_video
def depad_video_from_square(video, h_orig, w_orig, n_channels=None):
    """Crop a centre-padded square video back to ``h_orig`` x ``w_orig`` frames.

    This reverses ``pad_video_to_square``. It recomputes the same centring
    offsets and slices the original region out of every frame. The offsets
    assume the square side equals ``max(h_orig, w_orig)``.

    Args:
        video: array of shape ``(n_frames, D, D)`` or ``(n_frames, D, D, C)``.
            Trailing axes are kept as-is.
        h_orig: frame height before padding.
        w_orig: frame width before padding.
        n_channels: optional expected channel count, checked against the
            4-D input when given.

    Returns:
        A view of ``video`` with shape ``(n_frames, h_orig, w_orig, ...)``.

    Raises:
        ValueError: if the frames are not square, or if ``n_channels`` is given
            and does not match the channel axis.
    """
    _, new_h, new_w, *rest = video.shape
    if new_h != new_w:
        raise ValueError(f"Expected square video, got {video.shape}")
    if n_channels is not None and len(rest) == 1 and rest[0] != n_channels:
        raise ValueError(f"Expected {n_channels} channels, got {video.shape}")
    new_dim = new_h
    h_from = 0 if h_orig > w_orig else (new_dim // 2 - h_orig // 2)
    w_from = 0 if h_orig < w_orig else (new_dim // 2 - w_orig // 2)
    # Basic slicing leaves any trailing channel axes intact, so no 3-D/4-D branch is needed.
    return video[:, h_from:h_from + h_orig, w_from:w_from + w_orig]
def load_and_normalise_dicoms(*args, normalise=True):
    """Load each positional list of DICOM paths into one float32 stack.

    Each positional argument is a sequence of file paths forming one series.
    The series' pixel arrays are stacked in the given order, with files along
    axis 0, into a single ``np.float32`` array.

    Args:
        *args: one sequence of DICOM file paths per series.
        normalise: if True, min-max rescale each series as a whole to [0, 1].
            A single min/max is taken over all frames of the series, not per
            frame.

    Returns:
        A list with one array per positional argument, in the same order.
        An empty path list yields an empty float32 array. A series whose
        pixels are all equal is rescaled to all zeros rather than NaN.
    """
    out = []
    for dicom_paths in args:
        dcms = [pydicom.dcmread(dicom_path) for dicom_path in dicom_paths]
        imgs = np.array([dcm.pixel_array for dcm in dcms]).astype(np.float32)
        # min()/max() raise on an empty array, so only rescale non-empty series.
        if normalise and imgs.size:
            imgs -= imgs.min()
            value_range = imgs.max()
            # A constant-valued series would otherwise compute 0 / 0 -> NaN.
            if value_range > 0:
                imgs /= value_range
        out.append(imgs)
    return out
def mag_magn_phase_to_predicted_mask(mag_magn_phase: np.ndarray,
                                     ort_session: onnxruntime.InferenceSession,
                                     inference_dim=(320,320)):
    """Predict a per-pixel class map for a stacked (mag, magn, phase) input.

    Pipeline:
      1. centre-pad each series to square;
      2. resize every frame to ``inference_dim``;
      3. run ``ort_session`` on a frames-first batch;
      4. resize the logits back to the padded size;
      5. crop the padding off;
      6. argmax over the class axis.

    Args:
        mag_magn_phase: the three series stacked along axis 0. After resizing
            it must be ``(3, n_frames, *inference_dim)`` (asserted below), so
            the input is expected to be ``(3, n_frames, H, W)``.
        ort_session: ONNX Runtime session fed a single input named ``'input'``.
            Its first output is treated as ``(n_frames, n_classes, h, w)``
            logits. This is implied by the transpose below and not checked.
        inference_dim: ``(h, w)`` that every frame is resized to before
            inference.

    Returns:
        ``(n_frames, H, W)`` array of argmax class indices at the original
        frame size.
    """
    def _resize_each_frame(stack, size):
        # Resize every 2-D frame of a (K, N, h, w) stack, keeping the K and N axes.
        return np.array([[resize(frame, size) for frame in frames] for frames in stack])

    # Remember the raw geometry so padding and resizing can be undone at the end.
    n_frames = len(mag_magn_phase[0])
    *_, orig_h, orig_w = mag_magn_phase.shape
    square_side = max(orig_h, orig_w)

    logger.debug(f"Pre padding: {mag_magn_phase.shape=}")
    mag_magn_phase = np.array(list(map(pad_video_to_square, mag_magn_phase)))

    logger.debug(f"Pre resize: {mag_magn_phase.shape=}")
    mag_magn_phase = _resize_each_frame(mag_magn_phase, inference_dim)

    logger.debug(f"Post resize: {mag_magn_phase.shape=}")
    assert mag_magn_phase.shape == (3, n_frames, *inference_dim), f"Expected shape {(3, n_frames, *inference_dim)}, got {mag_magn_phase.shape}"

    # Frames become the batch axis: (3, N, h, w) -> (N, 3, h, w).
    x = mag_magn_phase.transpose(1, 0, 2, 3)
    logger.debug(f"Pre forward pass: {x.shape=} {x.min()=} {x.max()=} {x.mean()=} {x.std()=}")
    for channel_idx in range(3):
        channel_frames = x[:, channel_idx]
        logger.debug(f"\tChannel {channel_idx} min={channel_frames.min()} max={channel_frames.max()} mean={channel_frames.mean()} std={channel_frames.std()}")

    model_outputs = ort_session.run(None, {'input': x})
    # (N, n_classes, h, w) -> (n_classes, N, h, w)
    pred_logit = model_outputs[0].transpose(1, 0, 2, 3)

    # Resize the continuous logits (not the final labels) so interpolation and
    # anti-aliasing never blend integer class ids.
    logger.debug(f"Pre de-resize: {pred_logit.shape=}")
    pred_logit = _resize_each_frame(pred_logit, (square_side, square_side))

    logger.debug(f"Pre de-pad: {pred_logit.shape=}")
    cropped = []
    for class_frames in pred_logit:
        cropped.append(depad_video_from_square(class_frames, orig_h, orig_w))
    pred_logit = np.array(cropped)

    logger.debug(f"Pre argmax: {pred_logit.shape=}")
    pred_cls = pred_logit.argmax(axis=0)  # (N, H, W)
    assert pred_cls.shape == (n_frames, orig_h, orig_w), f"Expected shape {n_frames, orig_h, orig_w}, got {pred_cls.shape=}"
    logger.debug(f"Post argmax: {pred_cls.shape=}")
    return pred_cls
# --- Script: segment one study and plot the predicted class map of every frame ---

# NOTE(review): frames are ordered by filename (sorted glob). Confirm this matches
# acquisition order (e.g. DICOM InstanceNumber) for this export.
dicom_paths_mag = sorted(glob(os.path.join(MAG_DIR, '*.dcm')))
dicom_paths_magn = sorted(glob(os.path.join(MAGN_DIR, '*.dcm')))
dicom_paths_phase = sorted(glob(os.path.join(PHASE_DIR, '*.dcm')))
# Fail early with a clear message instead of a cryptic error from an empty stack.
for series_dir, series_paths in ((MAG_DIR, dicom_paths_mag),
                                 (MAGN_DIR, dicom_paths_magn),
                                 (PHASE_DIR, dicom_paths_phase)):
    if not series_paths:
        raise FileNotFoundError(f"No .dcm files found in {series_dir}")

mag, magn, phase = load_and_normalise_dicoms(dicom_paths_mag, dicom_paths_magn, dicom_paths_phase)
# NOTE(review): the magn channel is deliberately fed as zeros. Presumably the
# model is meant to run without it; confirm before changing.
mag_magn_phase = np.array([mag, np.zeros_like(magn), phase])
ort_session = onnxruntime.InferenceSession(ONNX_MODEL_PATH)
pred_cls = mag_magn_phase_to_predicted_mask(mag_magn_phase, ort_session)

# Plot the predicted class map of every frame on a near-square grid.
n_frames = len(mag_magn_phase[0])
n_sqrt = math.ceil(np.sqrt(n_frames))
# squeeze=False keeps `axs` a 2-D array even for a 1x1 grid (a single frame),
# so .flatten() always works.
fig, axs = plt.subplots(n_sqrt, n_sqrt, figsize=(20, 20), squeeze=False)
# Fixed colour limits give each class id the same grey level in every frame.
# Otherwise imshow rescales each frame to its own min/max.
vmax = max(int(pred_cls.max()), 1)
for i, ax in enumerate(axs.flatten()):
    if i < n_frames:
        ax.imshow(pred_cls[i], cmap='gray', vmin=0, vmax=vmax)
        ax.set_title(f"Frame {i}")
    else:
        ax.axis('off')
plt.tight_layout()
# Without this the figure is never displayed when run as a plain script.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment