Skip to content

Instantly share code, notes, and snippets.

@jcreinhold
Created December 9, 2021 16:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jcreinhold/cd50852ec9b9c417daab7501194893b9 to your computer and use it in GitHub Desktop.
Save jcreinhold/cd50852ec9b9c417daab7501194893b9 to your computer and use it in GitHub Desktop.
Normalize CT images in PyTorch
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Normalize the intensity of a CT image
Author: Jacob Reinhold
"""
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import Tuple, Union
import torch
import torchio as tio
def split_filename(filepath: Union[Path, str]) -> Tuple[Path, str, str]:
    """Break a file path into (parent directory, base name, extension).

    A trailing ``.gz`` is treated as part of a compound extension, so
    ``"dir/img.nii.gz"`` yields ``(Path("dir"), "img", ".nii.gz")`` while
    ``"dir/img.nii"`` yields ``(Path("dir"), "img", ".nii")``.
    """
    fp = Path(filepath)
    directory = fp.parent
    suffix = fp.suffix
    stem = Path(fp.stem)
    if suffix != ".gz":
        return directory, str(stem), suffix
    # compressed file: peel one more suffix off the stem (e.g. ".nii" of ".nii.gz")
    return directory, str(stem.stem), stem.suffix + suffix
def normalize_ct(
    ct_tensor: torch.Tensor,
    *,
    linear_min: float = -100.0,
    linear_max: float = 300.0,
    scale: float = 0.1,
    replace_background: bool = True,
    linear_to_01: bool = True,
    prop_background: float = 0.05,
) -> torch.Tensor:
    """Set the background of a CT image to the min. foreground value and
    squash intensities outside a linear window.

    Intensities in ``[linear_min, linear_max)`` are kept as-is; intensities
    below ``linear_min`` or at/above ``linear_max`` are compressed toward the
    nearest window edge with slope ``scale``. Optionally the result is then
    rescaled so that ``linear_min`` maps to 0 and ``linear_max`` maps to 1.

    Args:
        ct_tensor: CT intensities. The tensor is not modified. Integer
            tensors are converted to float32 before processing.
        linear_min: lower edge of the linear window.
        linear_max: upper edge of the linear window; must exceed
            ``linear_min``.
        scale: slope applied to intensities outside the window.
        replace_background: if True, find the largest gap among the lowest
            ``prop_background`` fraction of unique intensities and raise every
            value below that gap to the first value above it. Skipped when
            that fraction covers fewer than two unique values.
        linear_to_01: if True, map ``linear_min`` -> 0 and ``linear_max`` -> 1.
        prop_background: fraction of the sorted unique intensities searched
            for the background/foreground gap.

    Returns:
        A new floating-point tensor with the same shape as ``ct_tensor``.

    Raises:
        ValueError: if ``linear_max <= linear_min``.
    """
    if linear_max <= linear_min:
        raise ValueError(
            f"linear_max ({linear_max}) must be greater than "
            f"linear_min ({linear_min})"
        )
    # Use a float tensor so the in-place ops on `normalized` below are valid;
    # `ct` itself is only ever replaced out-of-place, never written into, so
    # the caller's tensor is left untouched.
    ct = ct_tensor if ct_tensor.is_floating_point() else ct_tensor.float()
    if replace_background:
        unq = torch.unique(ct)  # sorted ascending by default
        n = int(prop_background * len(unq))
        # At least two candidates are needed to form a gap; with fewer,
        # torch.argmax would be called on an empty tensor and raise.
        if n >= 2:
            gaps = torch.diff(unq[:n])
            min_val_fg = unq[torch.argmax(gaps) + 1]
            # set backgrnd to min val. in foregrnd (out-of-place)
            ct = ct.clamp(min=min_val_fg.item())
    below = ct < linear_min
    inside = (ct >= linear_min) & (ct < linear_max)
    above = ct >= linear_max
    # Values matching none of the masks (NaN) stay at 0 in `normalized`.
    normalized = torch.zeros_like(ct)
    normalized[below] = scale * (ct[below] - linear_min) + linear_min
    normalized[inside] = ct[inside]
    normalized[above] = scale * (ct[above] - linear_max) + linear_max
    if linear_to_01:
        normalized -= linear_min
        normalized /= linear_max - linear_min
    return normalized
def _build_parser() -> ArgumentParser:
    """Create the command-line parser used by :func:`main`."""
    parser = ArgumentParser(description="Normalize CT image for image processing")
    parser.add_argument(
        "image_path", type=str, help="path of the CT image to normalize"
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default=None,
        help="where to save the result (default: input directory and base "
        "name with --output-type as extension)",
    )
    parser.add_argument(
        "-ot",
        "--output-type",
        type=str,
        default=".nii",
        help="extension for the default output path; ignored when -o is given",
    )
    parser.add_argument(
        "--linear-min",
        type=float,
        default=-100.0,
        help="lower edge of the linear intensity window",
    )
    parser.add_argument(
        "--linear-max",
        type=float,
        default=300.0,
        help="upper edge of the linear intensity window",
    )
    parser.add_argument(
        "--outside-linear-scale",
        type=float,
        default=0.1,
        help="slope applied to intensities outside the linear window",
    )
    parser.add_argument(
        "--no-linear-to-01",
        action="store_true",
        help="do not rescale the linear window to [0, 1]",
    )
    parser.add_argument(
        "--no-replace-background",
        action="store_true",
        help="do not raise background intensities to the min. foreground value",
    )
    parser.add_argument(
        "--prop-background",
        type=float,
        default=0.05,
        help="fraction of lowest unique intensities searched for the "
        "background/foreground gap",
    )
    parser.add_argument(
        "--no-make-out-dir",
        action="store_true",
        help="do not create the output directory if it is missing",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    return parser


def main() -> int:
    """Command-line entry point: normalize one CT image and save it.

    Loads ``image_path`` with torchio, applies :func:`normalize_ct` with the
    command-line options, and writes the result.

    Returns:
        Process exit status (0 on success).
    """
    args = _build_parser().parse_args()
    if args.verbose:
        print(f"Normalizing image: {args.image_path}")
    image = tio.ScalarImage(args.image_path)
    normalized = normalize_ct(
        image.data.float(),
        linear_min=args.linear_min,
        linear_max=args.linear_max,
        scale=args.outside_linear_scale,
        replace_background=not args.no_replace_background,
        linear_to_01=not args.no_linear_to_01,
        prop_background=args.prop_background,
    )
    image.set_data(normalized)
    if args.output_path is None:
        ext = args.output_type
        # Accept "nii.gz" as well as ".nii.gz"; without the dot the extension
        # would be glued onto the base name ("imgnii.gz").
        if ext and not ext.startswith("."):
            ext = "." + ext
        root, base, _ = split_filename(args.image_path)
        # NOTE(review): for a ".nii" input with the default output type this
        # path equals the input path, so the input file is overwritten.
        output_path = root / (base + ext)
    else:
        output_path = Path(args.output_path)
    if args.verbose:
        print(f"Saving normalized image: {output_path}")
    if not args.no_make_out_dir:
        output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment