Skip to content

Instantly share code, notes, and snippets.

@neeleshbisht99
Created June 7, 2024 02:44
Show Gist options
  • Save neeleshbisht99/8bdc320fc49c69f3b90af06c9ceab2fa to your computer and use it in GitHub Desktop.
Save neeleshbisht99/8bdc320fc49c69f3b90af06c9ceab2fa to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "99WdgO6AMrqq"
},
"source": [
"# 3D image classification from CT scans\n",
"\n",
"**Author:** []()<br>\n",
"**Date created:** 2020/09/23<br>\n",
"**Last modified:** 2024/01/11<br>\n",
"**Description:** Train a 3D convolutional neural network to predict presence of pneumonia."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Skp-bvaEMrqt"
},
"source": [
"## Introduction\n",
"\n",
"This example will show the steps needed to build a 3D convolutional neural network (CNN)\n",
"to predict the presence of viral pneumonia in computer tomography (CT) scans. 2D CNNs are\n",
"commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D\n",
"equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan),\n",
"3D CNNs are a powerful model for learning representations for volumetric data.\n",
"\n",
"## References\n",
"\n",
"- [A survey on Deep Learning Advances on Different 3D DataRepresentations](https://arxiv.org/abs/1808.01462)\n",
"- [VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition](https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)\n",
"- [FusionNet: 3D Object Classification Using MultipleData Representations](https://arxiv.org/abs/1607.05695)\n",
"- [Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction](https://arxiv.org/abs/2007.13224)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uTNm8DT6Mrqv"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-06T23:48:58.844723Z",
"iopub.status.busy": "2024-06-06T23:48:58.844193Z",
"iopub.status.idle": "2024-06-06T23:49:07.863602Z",
"shell.execute_reply": "2024-06-06T23:49:07.862735Z",
"shell.execute_reply.started": "2024-06-06T23:48:58.844697Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: nibabel in /usr/local/lib/python3.11/dist-packages (5.2.1)\n",
"Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.11/dist-packages (from nibabel) (1.26.3)\n",
"Requirement already satisfied: packaging>=17 in /usr/local/lib/python3.11/dist-packages (from nibabel) (23.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: SimpleITK in /usr/local/lib/python3.11/dist-packages (2.3.1)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: grad-cam in /usr/local/lib/python3.11/dist-packages (1.5.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from grad-cam) (1.26.3)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from grad-cam) (9.5.0)\n",
"Requirement already satisfied: torch>=1.7.1 in /usr/local/lib/python3.11/dist-packages (from grad-cam) (2.1.1+cu121)\n",
"Requirement already satisfied: torchvision>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from grad-cam) (0.16.1+cu121)\n",
"Requirement already satisfied: ttach in /usr/local/lib/python3.11/dist-packages (from grad-cam) (0.0.3)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from grad-cam) (4.66.1)\n",
"Requirement already satisfied: opencv-python in /usr/local/lib/python3.11/dist-packages (from grad-cam) (4.8.0.76)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from grad-cam) (3.7.3)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (from grad-cam) (1.3.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (4.9.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (2023.6.0)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (2.1.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from torchvision>=0.8.2->grad-cam) (2.31.0)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (1.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (4.47.2)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (23.2)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib->grad-cam) (2.4.7)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (2.8.2)\n",
"Requirement already satisfied: scipy>=1.5.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->grad-cam) (1.11.2)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->grad-cam) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->grad-cam) (3.2.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->grad-cam) (1.16.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.7.1->grad-cam) (2.1.4)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->torchvision>=0.8.2->grad-cam) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests->torchvision>=0.8.2->grad-cam) (3.3)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->torchvision>=0.8.2->grad-cam) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests->torchvision>=0.8.2->grad-cam) (2020.6.20)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.11/dist-packages (from sympy->torch>=1.7.1->grad-cam) (1.3.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"!pip install nibabel\n",
"!pip install SimpleITK\n",
"!pip install grad-cam"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-06T23:23:48.643328Z",
"iopub.status.busy": "2024-06-06T23:23:48.642588Z",
"iopub.status.idle": "2024-06-06T23:23:48.647494Z",
"shell.execute_reply": "2024-06-06T23:23:48.647036Z",
"shell.execute_reply.started": "2024-06-06T23:23:48.643290Z"
},
"id": "gbalBhvpMrqv"
},
"outputs": [],
"source": [
"import os\n",
"import zipfile\n",
"import random\n",
"import numpy as np\n",
"import nibabel as nib\n",
"import matplotlib.pyplot as plt\n",
"from scipy import ndimage\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torchvision.transforms import Compose, ToTensor, Lambda\n",
"from torch.optim.lr_scheduler import ExponentialLR"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FTwYU8fqMrqw"
},
"source": [
"## Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings\n",
"\n",
"In this example, we use a subset of the\n",
"[MosMedData: Chest CT Scans with COVID-19 Related Findings](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1).\n",
"This dataset consists of lung CT scans with COVID-19 related findings, as well as without such findings.\n",
"\n",
"We will be using the associated radiological findings of the CT scans as labels to build\n",
"a classifier to predict presence of viral pneumonia.\n",
"Hence, the task is a binary classification problem."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-06T23:23:48.648479Z",
"iopub.status.busy": "2024-06-06T23:23:48.648124Z",
"iopub.status.idle": "2024-06-06T23:23:48.653068Z",
"shell.execute_reply": "2024-06-06T23:23:48.652564Z",
"shell.execute_reply.started": "2024-06-06T23:23:48.648458Z"
},
"id": "Wq2u7jgcMrqx"
},
"outputs": [],
"source": [
"import urllib.request\n",
"\n",
"def download_and_extract(url, output_dir):\n",
" zip_path = os.path.join(output_dir, os.path.basename(url))\n",
" urllib.request.urlretrieve(url, zip_path)\n",
" with zipfile.ZipFile(zip_path, 'r') as z_fp:\n",
" z_fp.extractall(output_dir)\n",
"\n",
"# Create directories\n",
"output_dir = os.path.join(os.getcwd(), \"MosMedData\")\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"# Download and extract data\n",
"# download_and_extract(\"https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip\", output_dir)\n",
"# download_and_extract(\"https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip\", output_dir)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "phaGcd73Mrqx"
},
"source": [
"## Loading data and preprocessing\n",
"\n",
"The files are provided in Nifti format with the extension .nii. To read the\n",
"scans, we use the `nibabel` package.\n",
"You can install the package via `pip install nibabel`. CT scans store raw voxel\n",
"intensity in Hounsfield units (HU). They range from -1024 to above 2000 in this dataset.\n",
"Above 400 are bones with different radiointensity, so this is used as a higher bound. A threshold\n",
"between -1000 and 400 is commonly used to normalize CT scans.\n",
"\n",
"To process the data, we do the following:\n",
"\n",
"* We first rotate the volumes by 90 degrees, so the orientation is fixed\n",
"* We scale the HU values to be between 0 and 1.\n",
"* We resize width, height and depth.\n",
"\n",
"Here we define several helper functions to process the data. These functions\n",
"will be used when building training and validation datasets."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-06T23:23:48.654319Z",
"iopub.status.busy": "2024-06-06T23:23:48.653888Z",
"iopub.status.idle": "2024-06-06T23:23:48.659862Z",
"shell.execute_reply": "2024-06-06T23:23:48.659196Z",
"shell.execute_reply.started": "2024-06-06T23:23:48.654296Z"
},
"id": "28MnrIoLMrqy"
},
"outputs": [],
"source": [
"\n",
"import nibabel as nib\n",
"\n",
"from scipy import ndimage\n",
"\n",
"def read_nifti_file(filepath):\n",
" scan = nib.load(filepath)\n",
" scan = scan.get_fdata()\n",
" return scan\n",
"\n",
"def normalize(volume):\n",
" min, max = -1000, 400\n",
" volume[volume < min] = min\n",
" volume[volume > max] = max\n",
" volume = (volume - min) / (max - min)\n",
" volume = volume.astype(\"float32\")\n",
" return volume\n",
"\n",
"def resize_volume(img):\n",
" desired_depth, desired_width, desired_height = 64, 128, 128\n",
" current_depth, current_width, current_height = img.shape[-1], img.shape[0], img.shape[1]\n",
" depth_factor, width_factor, height_factor = current_depth / desired_depth, current_width / desired_width, current_height / desired_height\n",
" depth_factor, width_factor, height_factor = 1 / depth_factor, 1 / width_factor, 1 / height_factor\n",
" img = ndimage.rotate(img, 90, reshape=False)\n",
" img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)\n",
" return img\n",
"\n",
"def process_scan(path):\n",
" volume = read_nifti_file(path)\n",
" volume = normalize(volume)\n",
" volume = resize_volume(volume)\n",
" return volume\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6JlnAmXlMrqz"
},
"source": [
"Let's read the paths of the CT scans from the class directories."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.execute_input": "2024-06-06T23:23:48.661962Z",
"iopub.status.busy": "2024-06-06T23:23:48.661548Z",
"iopub.status.idle": "2024-06-06T23:31:56.578684Z",
"shell.execute_reply": "2024-06-06T23:31:56.577871Z",
"shell.execute_reply.started": "2024-06-06T23:23:48.661942Z"
},
"id": "al541TsPMrqz",
"outputId": "570542d8-c00e-434d-88bb-2396f5c19b65"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CT scans with normal lung tissue: 100\n",
"CT scans with abnormal lung tissue: 100\n"
]
}
],
"source": [
"class CTScanDataset(Dataset):\n",
" def __init__(self, scan_paths, labels, transform=None):\n",
" self.scan_paths = scan_paths\n",
" self.labels = labels\n",
" self.transform = transform\n",
"\n",
" def __len__(self):\n",
" return len(self.scan_paths)\n",
"\n",
" def __getitem__(self, idx):\n",
" scan = self.scan_paths[idx]\n",
" label = self.labels[idx]\n",
" if self.transform:\n",
" scan = self.transform(scan)\n",
" return scan, label\n",
"\n",
"# Prepare data\n",
"normal_scan_paths = [os.path.join(output_dir, \"CT-0\", x) for x in os.listdir(os.path.join(output_dir, \"CT-0\"))]\n",
"abnormal_scan_paths = [os.path.join(output_dir, \"CT-23\", x) for x in os.listdir(os.path.join(output_dir, \"CT-23\"))]\n",
"\n",
"print(\"CT scans with normal lung tissue: \" + str(len(normal_scan_paths)))\n",
"print(\"CT scans with abnormal lung tissue: \" + str(len(abnormal_scan_paths)))\n",
"\n",
"# Process scans\n",
"abnormal_scans = [process_scan(path) for path in abnormal_scan_paths]\n",
"normal_scans = [process_scan(path) for path in normal_scan_paths]\n",
"\n",
"abnormal_labels = [1 for _ in range(len(abnormal_scans))]\n",
"normal_labels = [0 for _ in range(len(normal_scans))]\n",
"\n",
"# Split data\n",
"x_train = abnormal_scans[:70] + normal_scans[:70]\n",
"y_train = abnormal_labels[:70] + normal_labels[:70]\n",
"x_val = abnormal_scans[70:] + normal_scans[70:]\n",
"y_val = abnormal_labels[70:] + normal_labels[70:]\n",
"\n",
"# Define transforms\n",
"train_transform = Compose([\n",
" ToTensor(),\n",
" Lambda(lambda x: x.unsqueeze(0)) # Add channel dimension\n",
"])\n",
"\n",
"val_transform = Compose([\n",
" ToTensor(),\n",
" Lambda(lambda x: x.unsqueeze(0)) # Add channel dimension\n",
"])\n",
"\n",
"# Create datasets\n",
"train_dataset = CTScanDataset(x_train, y_train, transform=train_transform)\n",
"val_dataset = CTScanDataset(x_val, y_val, transform=val_transform)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-06T23:31:56.579817Z",
"iopub.status.busy": "2024-06-06T23:31:56.579602Z",
"iopub.status.idle": "2024-06-06T23:31:56.585224Z",
"shell.execute_reply": "2024-06-06T23:31:56.584347Z",
"shell.execute_reply.started": "2024-06-06T23:31:56.579798Z"
},
"id": "9sZRwCDnMrq0"
},
"outputs": [],
"source": [
"# # Define data loaders.\n",
"\n",
"batch_size = 2\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n",
"val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OUvROCKOMrq2"
},
"source": [
"## Define a 3D convolutional neural network\n",
"\n",
"To make the model easier to understand, we structure it into blocks.\n",
"The architecture of the 3D CNN used in this example\n",
"is based on [this paper](https://arxiv.org/abs/2007.13224)."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.execute_input": "2024-06-07T00:49:47.520209Z",
"iopub.status.busy": "2024-06-07T00:49:47.519573Z",
"iopub.status.idle": "2024-06-07T00:49:47.578952Z",
"shell.execute_reply": "2024-06-07T00:49:47.578312Z",
"shell.execute_reply.started": "2024-06-07T00:49:47.520182Z"
},
"id": "fag8KMXJMrq2",
"outputId": "35fe5edd-3ddf-4b0c-f3aa-8057e6c0b818"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"torch.cuda.empty_cache()\n",
"torch.set_grad_enabled(True)\n",
"class CNN3D(nn.Module):\n",
" def __init__(self):\n",
" super(CNN3D, self).__init__()\n",
" self.conv1 = nn.Conv3d(1, 64, kernel_size=3, padding=1)\n",
" self.pool1 = nn.MaxPool3d(kernel_size=2)\n",
" self.bn1 = nn.BatchNorm3d(64)\n",
"\n",
" self.conv2 = nn.Conv3d(64, 64, kernel_size=3, padding=1)\n",
" self.pool2 = nn.MaxPool3d(kernel_size=2)\n",
" self.bn2 = nn.BatchNorm3d(64)\n",
"\n",
" self.conv3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)\n",
" self.pool3 = nn.MaxPool3d(kernel_size=2)\n",
" self.bn3 = nn.BatchNorm3d(128)\n",
"\n",
" self.conv4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)\n",
" self.pool4 = nn.MaxPool3d(kernel_size=2)\n",
" self.bn4 = nn.BatchNorm3d(256)\n",
"\n",
" self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))\n",
" self.fc1 = nn.Linear(256, 512)\n",
" self.dropout = nn.Dropout(0.3)\n",
" self.fc2 = nn.Linear(512, 2)\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.bn1(self.pool1(self.conv1(x))))\n",
" x = F.relu(self.bn2(self.pool2(self.conv2(x))))\n",
" x = F.relu(self.bn3(self.pool3(self.conv3(x))))\n",
" x = F.relu(self.bn4(self.pool4(self.conv4(x))))\n",
" x = self.global_pool(x)\n",
" x = x.view(-1, 256)\n",
" x = F.relu(self.fc1(x))\n",
" x = self.dropout(x)\n",
" x = self.fc2(x)\n",
" return x\n",
"\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Initialize model\n",
"model = CNN3D().to(device)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aZYZDwGpMrq2"
},
"source": [
"## Train model"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.execute_input": "2024-06-07T01:26:31.132179Z",
"iopub.status.busy": "2024-06-07T01:26:31.131897Z",
"iopub.status.idle": "2024-06-07T01:27:55.260805Z",
"shell.execute_reply": "2024-06-07T01:27:55.259727Z",
"shell.execute_reply.started": "2024-06-07T01:26:31.132157Z"
},
"id": "o_i5hPGeMrq2",
"outputId": "315b6b5f-ec48-4080-d7b6-ec0d9ff2635d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10, Train Loss: 0.6650, Train Acc: 0.6071, Val Loss: 0.6503, Val Acc: 0.6167\n",
"Epoch 2/10, Train Loss: 0.6673, Train Acc: 0.6071, Val Loss: 0.6487, Val Acc: 0.6167\n",
"Epoch 3/10, Train Loss: 0.6382, Train Acc: 0.6500, Val Loss: 0.6504, Val Acc: 0.6333\n",
"Epoch 4/10, Train Loss: 0.6605, Train Acc: 0.5857, Val Loss: 0.6535, Val Acc: 0.6500\n",
"Epoch 5/10, Train Loss: 0.6209, Train Acc: 0.6857, Val Loss: 0.6565, Val Acc: 0.6333\n",
"Epoch 6/10, Train Loss: 0.6540, Train Acc: 0.6286, Val Loss: 0.6532, Val Acc: 0.6500\n",
"Epoch 7/10, Train Loss: 0.6535, Train Acc: 0.6071, Val Loss: 0.6544, Val Acc: 0.6500\n",
"Epoch 8/10, Train Loss: 0.6450, Train Acc: 0.6429, Val Loss: 0.6557, Val Acc: 0.6667\n",
"Epoch 9/10, Train Loss: 0.6336, Train Acc: 0.6143, Val Loss: 0.6551, Val Acc: 0.6500\n",
"Epoch 10/10, Train Loss: 0.6414, Train Acc: 0.6214, Val Loss: 0.6590, Val Acc: 0.6000\n"
]
}
],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
"scheduler = ExponentialLR(optimizer, gamma=0.96)\n",
"\n",
"num_epochs = 10\n",
"best_val_acc = 0\n",
"\n",
"history = {\n",
" 'train_loss': [],\n",
" 'val_loss': [],\n",
" 'train_acc': [],\n",
" 'val_acc': []\n",
"}\n",
"\n",
"for epoch in range(num_epochs):\n",
" model.train()\n",
" train_loss = 0.0\n",
" train_correct = 0\n",
"\n",
" for inputs, labels in train_loader:\n",
" inputs, labels = inputs.to(device), labels.to(device, dtype=torch.long)\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" scheduler.step()\n",
" train_loss += loss.item() * inputs.size(0)\n",
" train_correct += (outputs.argmax(dim=1) == labels).sum().item()\n",
"\n",
" train_loss = train_loss / len(train_loader.dataset)\n",
" train_acc = train_correct / len(train_loader.dataset)\n",
"\n",
" model.eval()\n",
" val_loss = 0.0\n",
" val_correct = 0\n",
"\n",
" # with torch.no_grad():\n",
" for inputs, labels in val_loader:\n",
" inputs, labels = inputs.to(device), labels.to(device, dtype=torch.long)\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" val_loss += loss.item() * inputs.size(0)\n",
" val_correct += (outputs.argmax(dim=1) == labels).sum().item()\n",
"\n",
" val_loss = val_loss / len(val_loader.dataset)\n",
" val_acc = val_correct / len(val_loader.dataset)\n",
"\n",
" if val_acc > best_val_acc:\n",
" best_val_acc = val_acc\n",
" torch.save(model.state_dict(), \"3d_image_classification.pth\")\n",
"\n",
" # After training loop, append the performance metrics to history for visualization\n",
" history['train_loss'].append(train_loss)\n",
" history['val_loss'].append(val_loss)\n",
" history['train_acc'].append(train_acc)\n",
" history['val_acc'].append(val_acc)\n",
"\n",
" print(f\"Epoch {epoch+1}/{num_epochs}, \"\n",
" f\"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, \"\n",
" f\"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QGcSetZQMrq2"
},
"source": [
"It is important to note that the number of samples is very small (only 200) and we don't\n",
"specify a random seed. As such, you can expect significant variance in the results. The full dataset\n",
"which consists of over 1000 CT scans can be found [here](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1). Using the full\n",
"dataset, an accuracy of 83% was achieved. A variability of 6-7% in the classification\n",
"performance is observed in both cases."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AkNhKZMLMrq2"
},
"source": [
"## Visualizing model performance\n",
"\n",
"Here the model accuracy and loss for the training and the validation sets are plotted.\n",
"Since the validation set is class-balanced, accuracy provides an unbiased representation\n",
"of the model's performance."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-07T01:29:02.638537Z",
"iopub.status.busy": "2024-06-07T01:29:02.638212Z",
"iopub.status.idle": "2024-06-07T01:29:02.923931Z",
"shell.execute_reply": "2024-06-07T01:29:02.923107Z",
"shell.execute_reply.started": "2024-06-07T01:29:02.638516Z"
},
"id": "xwMdwWJUMrq3"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 2000x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(20, 5))\n",
"ax[0].plot(history['train_loss'], label='Train Loss')\n",
"ax[0].plot(history['val_loss'], label='Val Loss')\n",
"ax[0].set_title('Loss')\n",
"ax[0].legend()\n",
"\n",
"ax[1].plot(history['train_acc'], label='Train Acc')\n",
"ax[1].plot(history['val_acc'], label='Val Acc')\n",
"ax[1].set_title('Accuracy')\n",
"ax[1].legend()\n",
"\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WUeJTtXPMrq3"
},
"source": [
"## Make predictions on a single CT scan"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-07T01:29:06.370849Z",
"iopub.status.busy": "2024-06-07T01:29:06.370249Z",
"iopub.status.idle": "2024-06-07T01:29:06.474105Z",
"shell.execute_reply": "2024-06-07T01:29:06.473481Z",
"shell.execute_reply.started": "2024-06-07T01:29:06.370821Z"
},
"id": "g3YeOerGMrq3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logits: tensor([[-0.5714, 0.4936]], device='cuda:0', grad_fn=<AddmmBackward0>)\n",
"probabilities: tensor([[0.2564, 0.7436]], device='cuda:0', grad_fn=<SoftmaxBackward0>)\n",
"[[('normal', 25.636297464370728), ('abnormal', 74.36370253562927)]]\n"
]
}
],
"source": [
"model.load_state_dict(torch.load(\"3d_image_classification.pth\"))\n",
"\n",
"model.eval()\n",
"\n",
"class_names = [\"normal\", \"abnormal\"]\n",
"predictions= []\n",
"# with torch.no_grad():\n",
"sample = torch.tensor(x_val[0]).unsqueeze(0).unsqueeze(0).to(device, dtype=torch.float32)\n",
"logits = model(sample)\n",
"print(\"logits: \", logits)\n",
"probabilities = F.softmax(logits, dim=1)\n",
"print(\"probabilities: \", probabilities)\n",
"label_confidences = [(class_names[i], prob.item() * 100) for i, prob in enumerate(probabilities.squeeze())]\n",
"predictions.append(label_confidences)\n",
"print(predictions)\n"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-07T01:53:40.889576Z",
"iopub.status.busy": "2024-06-07T01:53:40.888978Z",
"iopub.status.idle": "2024-06-07T01:53:41.050619Z",
"shell.execute_reply": "2024-06-07T01:53:41.049796Z",
"shell.execute_reply.started": "2024-06-07T01:53:40.889550Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"Input tensor requires grad: torch.Size([1, 1, 128, 128, 64])\n"
]
},
{
"ename": "error",
"evalue": "OpenCV(4.8.0) :-1: error: (-5:Bad argument) in function 'resize'\n> Overload resolution failed:\n> - Can't parse 'dsize'. Expected sequence length 2, got 3\n> - Can't parse 'dsize'. Expected sequence length 2, got 3\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31merror\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[51], line 74\u001b[0m\n\u001b[1;32m 72\u001b[0m sample \u001b[38;5;241m=\u001b[39m x_val[i]\n\u001b[1;32m 73\u001b[0m label \u001b[38;5;241m=\u001b[39m y_val[i]\n\u001b[0;32m---> 74\u001b[0m \u001b[43mvisualize_gradcam\u001b[49m\u001b[43m(\u001b[49m\u001b[43msample\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcam\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclass_names\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[51], line 54\u001b[0m, in \u001b[0;36mvisualize_gradcam\u001b[0;34m(sample, label, cam, model, class_names)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput tensor requires grad:\u001b[39m\u001b[38;5;124m\"\u001b[39m, sample\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# Generate CAM\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m grayscale_cam \u001b[38;5;241m=\u001b[39m \u001b[43mcam\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtargets\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m, :]\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# Convert sample to numpy array for visualization\u001b[39;00m\n\u001b[1;32m 57\u001b[0m sample_np \u001b[38;5;241m=\u001b[39m sample\u001b[38;5;241m.\u001b[39msqueeze()\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n",
"File \u001b[0;32m/usr/local/lib/python3.11/dist-packages/pytorch_grad_cam/base_cam.py:186\u001b[0m, in \u001b[0;36mBaseCAM.__call__\u001b[0;34m(self, input_tensor, targets, aug_smooth, eigen_smooth)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m aug_smooth \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward_augmentation_smoothing(input_tensor, targets, eigen_smooth)\n\u001b[0;32m--> 186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meigen_smooth\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/local/lib/python3.11/dist-packages/pytorch_grad_cam/base_cam.py:110\u001b[0m, in \u001b[0;36mBaseCAM.forward\u001b[0;34m(self, input_tensor, targets, eigen_smooth)\u001b[0m\n\u001b[1;32m 99\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward(retain_graph\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 101\u001b[0m \u001b[38;5;66;03m# In most of the saliency attribution papers, the saliency is\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;66;03m# computed with a single target layer.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;66;03m# Commonly it is the last convolutional layer.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;66;03m# use all conv layers for example, all Batchnorm layers,\u001b[39;00m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;66;03m# or something else.\u001b[39;00m\n\u001b[0;32m--> 110\u001b[0m cam_per_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_cam_per_layer\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meigen_smooth\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maggregate_multi_layers(cam_per_layer)\n",
"File \u001b[0;32m/usr/local/lib/python3.11/dist-packages/pytorch_grad_cam/base_cam.py:143\u001b[0m, in \u001b[0;36mBaseCAM.compute_cam_per_layer\u001b[0;34m(self, input_tensor, targets, eigen_smooth)\u001b[0m\n\u001b[1;32m 141\u001b[0m cam \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)\n\u001b[1;32m 142\u001b[0m cam \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmaximum(cam, \u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m--> 143\u001b[0m scaled \u001b[38;5;241m=\u001b[39m \u001b[43mscale_cam_image\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcam\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 144\u001b[0m cam_per_target_layer\u001b[38;5;241m.\u001b[39mappend(scaled[:, \u001b[38;5;28;01mNone\u001b[39;00m, :])\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cam_per_target_layer\n",
"File \u001b[0;32m/usr/local/lib/python3.11/dist-packages/pytorch_grad_cam/utils/image.py:172\u001b[0m, in \u001b[0;36mscale_cam_image\u001b[0;34m(cam, target_size)\u001b[0m\n\u001b[1;32m 169\u001b[0m img \u001b[38;5;241m=\u001b[39m zoom(np\u001b[38;5;241m.\u001b[39mfloat32(img), [\n\u001b[1;32m 170\u001b[0m (t_s \u001b[38;5;241m/\u001b[39m i_s) \u001b[38;5;28;01mfor\u001b[39;00m i_s, t_s \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(img\u001b[38;5;241m.\u001b[39mshape, target_size[::\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])])\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 172\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mcv2\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 174\u001b[0m result\u001b[38;5;241m.\u001b[39mappend(img)\n\u001b[1;32m 175\u001b[0m result \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mfloat32(result)\n",
"\u001b[0;31merror\u001b[0m: OpenCV(4.8.0) :-1: error: (-5:Bad argument) in function 'resize'\n> Overload resolution failed:\n> - Can't parse 'dsize'. Expected sequence length 2, got 3\n> - Can't parse 'dsize'. Expected sequence length 2, got 3\n"
]
}
],
"source": [
"import torch.nn.functional as F\n",
"from pytorch_grad_cam import GradCAM\n",
"from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
"from pytorch_grad_cam.utils.image import show_cam_on_image\n",
"import matplotlib.pyplot as plt\n",
"import nibabel as nib\n",
"import numpy as np\n",
"from scipy import ndimage\n",
"torch.cuda.empty_cache()\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"torch.set_grad_enabled(True)\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Create the model\n",
"model = CNN3D().to(device)\n",
"model.load_state_dict(torch.load('3d_image_classification.pth', map_location=device))\n",
"model.eval()\n",
"\n",
"\n",
"# def reshape_transform(tensor, height=8, width=8):\n",
"# result = tensor.permute(0, 2, 1, 3, 4).reshape(tensor.size(0), height * width, tensor.size(1), tensor.size(3), tensor.size(4))\n",
"# return result\n",
"\n",
"def reshape_transform(tensor):\n",
" # Permute and reshape tensor for GradCAM visualization\n",
" return tensor.permute(0, 2, 3, 4, 1).contiguous()\n",
"\n",
"\n",
"# Create the Grad-CAM object\n",
"target_layer = model.conv4 # Target the last convolutional layer\n",
"cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform)\n",
"\n",
"\n",
"# Load the input 3D image\n",
"# Load the NIfTI file\n",
"# output_dir = os.path.join(os.getcwd(), \"MosMedData\")\n",
"# input_nifti_path = os.path.join(output_dir, \"CT-23\", \"study_0939.nii.gz\")\n",
"# output_nifti_path =os.path.join(output_dir, \"output\", \"study_0939.nii.gz\")\n",
"\n",
"\n",
"def visualize_gradcam(sample, label, cam, model, class_names):\n",
" model.eval()\n",
" sample = torch.tensor(sample).unsqueeze(0).unsqueeze(0).to(device, dtype=torch.float32)\n",
" sample.requires_grad = True\n",
" sample.retain_grad()\n",
" # Forward pass\n",
" output = model(sample)\n",
" probs = torch.sigmoid(output).squeeze()\n",
" predicted_class = torch.argmax(probs).item()\n",
" \n",
" targets = [ClassifierOutputTarget(predicted_class)]\n",
" print(\"Input tensor requires grad:\", sample.shape)\n",
" # Generate CAM\n",
" grayscale_cam = cam(input_tensor=sample, targets=targets)[0, :]\n",
" \n",
" # Convert sample to numpy array for visualization\n",
" sample_np = sample.squeeze().cpu().numpy()\n",
" \n",
" # Show CAM on image\n",
" visualization = show_cam_on_image(sample_np, grayscale_cam, use_rgb=True)\n",
" \n",
" plt.imshow(visualization)\n",
" plt.title(f'Predicted: {class_names[predicted_class]}, Label: {class_names[label]}')\n",
" plt.axis('off')\n",
" plt.show()\n",
"\n",
" \n",
"class_names = [\"normal\", \"abnormal\"]\n",
"\n",
"#with torch.no_grad():\n",
"for i in range(5): # Visualize the first 5 samples\n",
" sample = x_val[i]\n",
" label = y_val[i]\n",
" visualize_gradcam(sample, label, cam, model, class_names)\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"eYck_2P1Mrqz",
"965Iih1gMrq0",
"OUvROCKOMrq2"
],
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment