{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "COVIDNet.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"mount_file_id": "1LAgYVoEBGzFSuv2Go2PxrPvYh6euAXvQ",
"authorship_tag": "ABX9TyNeiqeyhfUq0UP8iMDtT4u7",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/IliasPap/598e93ec50fe84f7953eef359d715916/covidnet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-TJJOOlgfUS5",
"colab_type": "text"
},
"source": [
"# Covidx dataset extraction and training with dl models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_ZM9zRc_hfZ",
"colab_type": "text"
},
"source": [
"## clone githubs and unzip the two datasets"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kG5w1SXpEutA",
"colab_type": "code",
"outputId": "7eeb968b-f9e9-4360-80d8-efcf80c0862f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 235
}
},
"source": [
"! git clone https://github.com/ieee8023/covid-chestxray-dataset.git\n",
"! git clone https://github.com/IliasPap/COVID-Net.git\n",
"\n",
"COPY_FILE = True\n",
"\n",
"# # !pip install pydicom\n",
"# ! pip install -q kaggle\n",
"# ! mkdir ~/.kaggle\n",
"\n",
"# ! pip install kaggle==1.5.6\n",
"# ! cp kaggle.json ~/.kaggle/\n",
"# ! chmod 600 ~/.kaggle/kaggle.json\n",
"\n",
"\n",
"# ! kaggle competitions download -c rsna-pneumonia-detection-challenge"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'covid-chestxray-dataset'...\n",
"remote: Enumerating objects: 1079, done.\u001b[K\n",
"remote: Total 1079 (delta 0), reused 0 (delta 0), pack-reused 1079\u001b[K\n",
"Receiving objects: 100% (1079/1079), 188.60 MiB | 32.86 MiB/s, done.\n",
"Resolving deltas: 100% (492/492), done.\n",
"Checking out files: 100% (270/270), done.\n",
"Cloning into 'COVID-Net'...\n",
"remote: Enumerating objects: 7, done.\u001b[K\n",
"remote: Counting objects: 100% (7/7), done.\u001b[K\n",
"remote: Compressing objects: 100% (6/6), done.\u001b[K\n",
"remote: Total 202 (delta 1), reused 5 (delta 1), pack-reused 195\u001b[K\n",
"Receiving objects: 100% (202/202), 3.20 MiB | 6.86 MiB/s, done.\n",
"Resolving deltas: 100% (111/111), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_rxYBDNsXL1Z",
"colab_type": "text"
},
"source": [
"## KAGGLE dataset from google drive"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1cavJSZ5XOtG",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "3598dfc5-4cea-49d2-b817-95bb86aeea77"
},
"source": [
"! mkdir /content/rsna_dataset\n",
"! unzip '/content/drive/My Drive/MEDICAL/rsna-pneumonia-detection-challenge.zip' -d /content/rsna_dataset/"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[1;30;43mStreaming output truncated to the last 5000 lines.\u001b[0m\n",
" inflating: /content/rsna_dataset/stage_2_train_images/34bf2fcd-131a-428c-9a21-cd2fa9041f9b.dcm "
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "z5wJrNmlhvjT",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install pydicom\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import random \n",
"from shutil import copyfile\n",
"import pydicom as dicom\n",
"import cv2"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SWIYPPL9h1xZ",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"seed = 0\n",
"np.random.seed(seed) # Reset the seed so all runs are the same.\n",
"random.seed(seed)\n",
"MAXVAL = 255 # Range [0 255]\n",
"root = '/content/covid-chestxray-dataset'\n",
"\n",
"if (COPY_FILE):\n",
" savepath = root + '/data'\n",
" if(not os.path.exists(savepath)):\n",
" os.makedirs(savepath)\n",
" savepath = root + '/data/train'\n",
" if(not os.path.exists(savepath)):\n",
" os.makedirs(savepath)\n",
" savepath = root + '/data/test'\n",
" if(not os.path.exists(savepath)):\n",
" os.makedirs(savepath)\n",
"\n",
"savepath = root + '/data'\n",
"# path to covid-19 dataset from https://github.com/ieee8023/covid-chestxray-dataset\n",
"imgpath = root + '/images' \n",
"csvpath = root + '/metadata.csv'\n",
"\n",
"# path to https://www.kaggle.com/c/rsna-pneumonia-detection-challenge\n",
"kaggle_datapath = '/content/rsna_kaggle_dataset'\n",
"kaggle_csvname = 'stage_2_detailed_class_info.csv' # get all the normal from here\n",
"kaggle_csvname2 = 'stage_2_train_labels.csv' # get all the 1s from here since 1 indicate pneumonia\n",
"kaggle_imgpath = 'stage_2_train_images'\n",
"\n",
"# parameters for COVIDx dataset\n",
"train = []\n",
"test = []\n",
"test_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}\n",
"train_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}\n",
"\n",
"mapping = dict()\n",
"mapping['COVID-19'] = 'COVID-19'\n",
"mapping['SARS'] = 'pneumonia'\n",
"mapping['MERS'] = 'pneumonia'\n",
"mapping['Streptococcus'] = 'pneumonia'\n",
"mapping['Normal'] = 'normal'\n",
"mapping['Lung Opacity'] = 'pneumonia'\n",
"mapping['1'] = 'pneumonia'\n",
"\n",
"# train/test split\n",
"split = 0.1"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0A40xmYriH6t",
"colab_type": "code",
"colab": {}
},
"source": [
"# adapted from https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/datasets.py#L814\n",
"csv = pd.read_csv(csvpath, nrows=None)\n",
"idx_pa = csv[\"view\"] == \"PA\" # Keep only the PA view\n",
"csv = csv[idx_pa]\n",
"\n",
"pneumonias = [\"COVID-19\", \"SARS\", \"MERS\", \"ARDS\", \"Streptococcus\"]\n",
"pathologies = [\"Pneumonia\",\"Viral Pneumonia\", \"Bacterial Pneumonia\", \"No Finding\"] + pneumonias\n",
"pathologies = sorted(pathologies)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "65fgnk8oXaDd",
"colab_type": "text"
},
"source": [
"## Data distribution covid-chestxray-dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3oCGULcBibv-",
"colab_type": "code",
"colab": {}
},
"source": [
"# get non-COVID19 viral, bacteria, and COVID-19 infections from covid-chestxray-dataset\n",
"# stored as patient id, image filename and label\n",
"filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}\n",
"count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}\n",
"print(csv.keys())\n",
"for index, row in csv.iterrows():\n",
" f = row['finding']\n",
" if f in mapping:\n",
" count[mapping[f]] += 1\n",
" entry = [int(row['patientid']), row['filename'], mapping[f]]\n",
" filename_label[mapping[f]].append(entry)\n",
"\n",
"print('Data distribution from covid-chestxray-dataset:')\n",
"print(count)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "LaN5fKLLXji2",
"colab_type": "text"
},
"source": [
"## add covid-chestxray-dataset into COVIDx datase"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QZXOLpHZiz3q",
"colab_type": "code",
"colab": {}
},
"source": [
"# add covid-chestxray-dataset into COVIDx dataset\n",
"# since covid-chestxray-dataset doesn't have test dataset\n",
"# split into train/test by patientid\n",
"# for COVIDx:\n",
"# patient 8 is used as non-COVID19 viral test\n",
"# patient 31 is used as bacterial test\n",
"# patients 19, 20, 36, 42, 86 are used as COVID-19 viral test\n",
"\n",
"for key in filename_label.keys():\n",
" arr = np.array(filename_label[key])\n",
" if arr.size == 0:\n",
" continue\n",
" # split by patients\n",
" # num_diff_patients = len(np.unique(arr[:,0]))\n",
" # num_test = max(1, round(split*num_diff_patients))\n",
" # select num_test number of random patients\n",
" if key == 'pneumonia':\n",
" test_patients = ['8', '31']\n",
" elif key == 'COVID-19':\n",
" test_patients = ['19', '20', '36', '42', '86'] # random.sample(list(arr[:,0]), num_test)\n",
" else: \n",
" test_patients = []\n",
" print('Key: ', key)\n",
" print('Test patients: ', test_patients)\n",
" # go through all the patients\n",
" for patient in arr:\n",
" if patient[0] in test_patients:\n",
" if (COPY_FILE):\n",
" copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'test', patient[1]))\n",
" test.append(patient)\n",
" test_count[patient[2]] += 1\n",
" else:\n",
" print(\"WARNING : passing copy file !!!!!!!!!!!!!!!!!!!!!!\")\n",
" break\n",
" else:\n",
" if (COPY_FILE):\n",
" copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'train', patient[1]))\n",
" train.append(patient)\n",
" train_count[patient[2]] += 1\n",
"\n",
" else:\n",
" print(\"WARNING : passing copy file !!!!!!!!!!!!!!!!!!!!!!\")\n",
" break\n",
"\n",
"print('test count: ', test_count)\n",
"print('train count: ', train_count)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HW2HHFi1Dra2",
"colab_type": "text"
},
"source": [
"## Copy kaggle data to train and test folders"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PG-qk6gXjZIB",
"colab_type": "code",
"colab": {}
},
"source": [
"# add normal and rest of pneumonia cases from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge\n",
"\n",
"\n",
"kaggle_datapath = '/content/rsna_dataset'\n",
"\n",
"print(kaggle_datapath)\n",
"csv_normal = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname), nrows=None)\n",
"csv_pneu = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname2), nrows=None)\n",
"patients = {'normal': [], 'pneumonia': []}\n",
"\n",
"for index, row in csv_normal.iterrows():\n",
" if row['class'] == 'Normal':\n",
" patients['normal'].append(row['patientId'])\n",
"\n",
"for index, row in csv_pneu.iterrows():\n",
" if int(row['Target']) == 1:\n",
" patients['pneumonia'].append(row['patientId'])\n",
"\n",
"for key in patients.keys():\n",
" arr = np.array(patients[key])\n",
" if arr.size == 0:\n",
" continue\n",
" # split by patients \n",
" # num_diff_patients = len(np.unique(arr))\n",
" # num_test = max(1, round(split*num_diff_patients))\n",
" #'/content/COVID-Net/'\n",
" test_patients = np.load('/content/COVID-Net/rsna_test_patients_{}.npy'.format(key)) # random.sample(list(arr), num_test)\n",
" # np.save('rsna_test_patients_{}.npy'.format(key), np.array(test_patients))\n",
" for patient in arr:\n",
" ds = dicom.dcmread(os.path.join(kaggle_datapath, kaggle_imgpath, patient + '.dcm'))\n",
" pixel_array_numpy = ds.pixel_array\n",
" imgname = patient + '.png'\n",
" if patient in test_patients:\n",
" if (COPY_FILE):\n",
" cv2.imwrite(os.path.join(savepath, 'test', imgname), pixel_array_numpy)\n",
" test.append([patient, imgname, key])\n",
" test_count[key] += 1\n",
" else:\n",
" print(\"WARNING : passing copy file !!!!!!!!!!!!!!!!!!!!!!\")\n",
" break\n",
" else:\n",
" if (COPY_FILE):\n",
" cv2.imwrite(os.path.join(savepath, 'train', imgname), pixel_array_numpy)\n",
" train.append([patient, imgname, key])\n",
" train_count[key] += 1\n",
" else:\n",
" print(\"WARNING : passing copy file !!!!!!!!!!!!!!!!!!!!!!\")\n",
" break\n",
"print('test count: ', test_count)\n",
"print('train count: ', train_count)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9DQ4gYQkDxt_",
"colab_type": "text"
},
"source": [
"## Final data stats"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rofg8ddlX28e",
"colab_type": "code",
"colab": {}
},
"source": [
"# final stats\n",
"print('Final stats')\n",
"print('Train count: ', train_count)\n",
"print('Test count: ', test_count)\n",
"print('Total length of train: ', len(train))\n",
"print('Total length of test: ', len(test))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zruYtB-8D1y4",
"colab_type": "text"
},
"source": [
"## Train and test file extraction"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xh_E3s-AX3sm",
"colab_type": "code",
"colab": {}
},
"source": [
"# export to train and test csv\n",
"# format as patientid, filename, label, separated by a space\n",
"train_file = open(\"train_split_v2.txt\",\"w\") \n",
"for sample in train:\n",
" info = str(sample[0]) + ' ' + sample[1] + ' ' + sample[2] + '\\n'\n",
" train_file.write(info)\n",
"\n",
"train_file.close()\n",
"\n",
"test_file = open(\"test_split_v2.txt\", \"w\")\n",
"for sample in test:\n",
" info = str(sample[0]) + ' ' + sample[1] + ' ' + sample[2] + '\\n'\n",
" test_file.write(info)\n",
"\n",
"test_file.close()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QuSxkmuebDuv",
"colab_type": "code",
"colab": {}
},
"source": [
"# import glob\n",
"\n",
"# images = glob.glob('/content/drive/My Drive/MEDICAL/data/*/*')\n",
"# print(len(images))\n",
"\n",
"# train = glob.glob('/content/drive/My Drive/MEDICAL/data/train/*')\n",
"# test = glob.glob('/content/drive/My Drive/MEDICAL/data/test/*')\n",
"# print(len(train))\n",
"# print(len(test))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YrMN6spNXyEy",
"colab_type": "text"
},
"source": [
"# Training on Covidx dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tJqeEfwh0CH1",
"colab_type": "text"
},
"source": [
"## Training imports\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AvCMyZ_m7LN7",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install torchsummaryX\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "bef7LpK60EsE",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import Dataset\n",
"import torch.optim as optim\n",
"import torch.nn as nn\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"from torchvision import transforms,models\n",
"from torchsummaryX import summary\n",
"import numpy as np\n",
"\n",
"\n",
"import argparse\n",
"import csv\n",
"from PIL import Image\n",
"\n",
"\n",
"import os\n",
"import shutil\n",
"import time\n",
"from collections import OrderedDict\n",
"import json"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "d_pu8yV5AN1i",
"colab_type": "text"
},
"source": [
"## Utils "
]
},
{
"cell_type": "code",
"metadata": {
"id": "1S1MSMhTAPkz",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"def write_score(writer, iter, mode, metrics):\n",
" writer.add_scalar(mode + '/loss', metrics.data['loss'], iter)\n",
" writer.add_scalar(mode + '/acc', metrics.data['correct'] / metrics.data['total'], iter)\n",
"\n",
"\n",
"def write_train_val_score(writer, epoch, train_stats, val_stats):\n",
" writer.add_scalars('Loss', {'train': train_stats[0],\n",
" 'val': val_stats[0],\n",
" }, epoch)\n",
" writer.add_scalars('Coeff', {'train': train_stats[1],\n",
" 'val': val_stats[1],\n",
" }, epoch)\n",
"\n",
" writer.add_scalars('Air', {'train': train_stats[2],\n",
" 'val': val_stats[2],\n",
" }, epoch)\n",
"\n",
" writer.add_scalars('CSF', {'train': train_stats[3],\n",
" 'val': val_stats[3],\n",
" }, epoch)\n",
" writer.add_scalars('GM', {'train': train_stats[4],\n",
" 'val': val_stats[4],\n",
" }, epoch)\n",
" writer.add_scalars('WM', {'train': train_stats[5],\n",
" 'val': val_stats[5],\n",
" }, epoch)\n",
" return\n",
"\n",
"\n",
"def showgradients(model):\n",
" for param in model.parameters():\n",
" print(type(param.data), param.size())\n",
" print(\"GRADS= \\n\", param.grad)\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"def datestr():\n",
" now = time.gmtime()\n",
" return '{}{:02}{:02}_{:02}{:02}'.format(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min)\n",
"\n",
"\n",
"def save_checkpoint(state, is_best, path, filename='last'):\n",
"\n",
" name = os.path.join(path, filename+'_checkpoint.pth.tar')\n",
" print(name)\n",
" torch.save(state, name)\n",
"\n",
"\n",
"\n",
"def save_model(model,optimizer, args, metrics, epoch, best_pred_loss,confusion_matrix):\n",
" loss = metrics.data['loss']\n",
" save_path = args.save\n",
" make_dirs(save_path)\n",
" \n",
" with open(save_path + '/training_arguments.txt', 'w') as f:\n",
" json.dump(args.__dict__, f, indent=2)\n",
" \n",
" is_best = False\n",
" if loss < best_pred_loss:\n",
" is_best = True\n",
" best_pred_loss = loss\n",
" save_checkpoint({'epoch': epoch,\n",
" 'state_dict': model.state_dict(),\n",
" 'optimizer': optimizer.state_dict(),\n",
" 'metrics': metrics.data },\n",
" is_best, save_path, args.model + \"_best\")\n",
" np.save(save_path + '/best_confusion_matrix.npy',confusion_matrix.cpu().numpy())\n",
" \n",
" else:\n",
" save_checkpoint({'epoch': epoch,\n",
" 'state_dict': model.state_dict(),\n",
" 'optimizer': optimizer.state_dict(),\n",
" 'metrics': metrics.data},\n",
" False, save_path, args.model + \"_last\")\n",
"\n",
" return best_pred_loss\n",
"\n",
"\n",
"def make_dirs(path):\n",
" if not os.path.exists(path):\n",
"\n",
" os.makedirs(path)\n",
"\n",
"\n",
"def create_stats_files(path):\n",
" train_f = open(os.path.join(path, 'train.csv'), 'w')\n",
" val_f = open(os.path.join(path, 'val.csv'), 'w')\n",
" return train_f, val_f\n",
"\n",
"\n",
"def read_json_file(fname):\n",
" with open(fname, 'r') as handle:\n",
" return json.load(handle, object_hook=OrderedDict)\n",
"\n",
"\n",
"def write_json_file(content, fname):\n",
" with open(fname, 'w') as handle:\n",
" json.dump(content, handle, indent=4, sort_keys=False)\n",
"\n",
"\n",
"def read_filepaths(file):\n",
" paths, labels = [], []\n",
" with open(file, 'r') as f:\n",
" lines = f.read().splitlines()\n",
"\n",
" for idx, line in enumerate(lines):\n",
" if ('/ c o' in line):\n",
" break\n",
" subjid, path, label = line.split(' ')\n",
"\n",
" paths.append(path)\n",
" labels.append(label)\n",
" return paths, labels\n",
"\n",
"\n",
"\n",
"def select_model(args):\n",
" if args.model == 'COVIDNet_small':\n",
" return CovidNet('small', n_classes=args.classes)\n",
"\n",
" elif args.model == 'COVIDNet_large':\n",
" return CovidNet('large', n_classes=args.classes)\n",
" elif args.model == 'resnet18':\n",
" return CNN(args.classes, 'resnet18')\n",
"\n",
"\n",
"def select_optimizer(args, model):\n",
" if args.opt == 'sgd':\n",
" return optim.SGD(model.parameters(), lr=args.lr, momentum=0.5, weight_decay=args.weight_decay)\n",
" elif args.opt == 'adam':\n",
" return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
" elif args.opt == 'rmsprop':\n",
" return optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
"\n",
"\n",
"def print_stats(args, epoch, num_samples, trainloader, metrics):\n",
" if (num_samples % args.log_interval == 1):\n",
" print(\"Epoch:{:2d}\\tSample:{:5d}/{:5d}\\tLoss:{:.4f}\\tAccuracy:{:.2f}\".format(epoch,\n",
" num_samples,\n",
" len(\n",
" trainloader) * args.batch_size,\n",
" metrics.data[\n",
" 'loss'] / num_samples,\n",
" metrics.data[\n",
" 'correct'] /\n",
" metrics.data[\n",
" 'total']))\n",
"\n",
"\n",
"def print_summary(args, epoch, num_samples, metrics, mode=''):\n",
" print(mode + \"\\n SUMMARY EPOCH:{:2d}\\tSample:{:5d}/{:5d}\\tLoss:{:.4f}\\tAccuracy:{:.2f}\\n\".format(epoch,\n",
" num_samples,\n",
" num_samples ,\n",
" metrics.data[\n",
" 'loss'] / num_samples, \n",
" metrics.data[\n",
" 'correct'] /\n",
" metrics.data[\n",
" 'total']))\n",
"\n",
"\n",
"def confusion_matrix(nb_classes):\n",
"\n",
"\n",
"\n",
" confusion_matrix = torch.zeros(nb_classes, nb_classes)\n",
" with torch.no_grad():\n",
" for i, (inputs, classes) in enumerate(dataloaders['val']):\n",
" inputs = inputs.to(device)\n",
" classes = classes.to(device)\n",
" outputs = model_ft(inputs)\n",
" _, preds = torch.max(outputs, 1)\n",
" for t, p in zip(classes.view(-1), preds.view(-1)):\n",
" confusion_matrix[t.long(), p.long()] += 1\n",
"\n",
" print(confusion_matrix)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q0_fFFfn2qU2",
"colab_type": "text"
},
"source": [
"## METRICS"
]
},
{
"cell_type": "code",
"metadata": {
"id": "gQath5FI2yVd",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"\n",
"\n",
"class Metrics:\n",
" def __init__(self, path, keys=None, writer=None):\n",
" self.writer = writer\n",
"\n",
" self.data = {'correct': 0,\n",
" 'total': 0,\n",
" 'loss': 0,\n",
" 'accuracy': 0,\n",
" }\n",
" self.save_path = path\n",
"\n",
" def reset(self):\n",
" for key in self.data:\n",
" self.data[key] = 0\n",
"\n",
" def update_key(self, key, value, n=1):\n",
" if self.writer is not None:\n",
" self.writer.add_scalar(key, value)\n",
" self.data[key] += value\n",
"\n",
" def update(self, values):\n",
" for key in self.data:\n",
" self.data[key] += values[key]\n",
"\n",
" def avg_acc(self):\n",
" return self.data['correct'] / self.data['total']\n",
"\n",
" def avg_loss(self):\n",
" return self.data['loss'] / self.data['total']\n",
"\n",
" def save(self):\n",
" with open(self.save_path, 'w') as save_file:\n",
" a = 0 # csv.writer()\n",
" # TODO\n",
"\n",
"\n",
"def accuracy(output, target):\n",
" with torch.no_grad():\n",
" pred = torch.argmax(output, dim=1)\n",
" assert pred.shape[0] == len(target)\n",
" correct = 0\n",
" correct += torch.sum(pred == target).item()\n",
" return correct, len(target), correct / len(target)\n",
"\n",
"\n",
"def top_k_acc(output, target, k=3):\n",
" with torch.no_grad():\n",
" pred = torch.topk(output, k, dim=1)[1]\n",
" assert pred.shape[0] == len(target)\n",
" correct = 0\n",
" for i in range(k):\n",
" correct += torch.sum(pred[:, i] == target).item()\n",
" return correct / len(target)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ChF-utz429PK",
"colab_type": "text"
},
"source": [
"## LOSS"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_3lCmUJ32-yl",
"colab_type": "code",
"colab": {}
},
"source": [
"def nll_loss(output, target):\n",
" return F.nll_loss(output, target)\n",
"\n",
"\n",
"def crossentropy_loss(output, target):\n",
" return F.cross_entropy(output, target)\n",
"\n",
"def focal_loss(output,target):\n",
" ce_loss = F.cross_entropy(output, target, reduction='none')\n",
" #print(ce_loss.shape)\n",
" pt = torch.exp(-ce_loss)\n",
" alpha = 0.25\n",
" gamma = 2\n",
" focal_loss = (alpha * (1-pt)**gamma * ce_loss).mean() \n",
" return focal_loss"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "oej84gWUdaqx",
"colab_type": "text"
},
"source": [
"## CNN models"
]
},
{
"cell_type": "code",
"metadata": {
"id": "V5ABEsVidfCV",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"\n",
"class CNN(nn.Module):\n",
" def __init__(self,classes,model='resnet18'):\n",
" super(CNN,self).__init__()\n",
" if(model == 'resnet18'):\n",
" self.cnn = models.resnet18(pretrained=True)\n",
" self.cnn.fc = nn.Linear(512,classes)\n",
" elif (model == 'mobilenet2'):\n",
"\n",
" self.cnn = models.resnext50_32x4d(pretrained=True)\n",
" self.cnn.classifier = nn.Linear(1280,classes)\n",
" def forward (self,x):\n",
" return self.cnn(x)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiiBN8qXfOoA",
"colab_type": "text"
},
"source": [
"## COVID-NET"
]
},
{
"cell_type": "code",
"metadata": {
"id": "9X99y49bfODF",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"\n",
"\n",
"class Flatten(nn.Module):\n",
" def forward(self, input):\n",
" return input.view(input.size(0), -1)\n",
"\n",
"\n",
"class PEXP(nn.Module):\n",
" def __init__(self, n_input, n_out):\n",
" super(PEXP, self).__init__()\n",
"\n",
" '''\n",
" • First-stage Projection: 1×1 convolutions for projecting input features to a lower dimension,\n",
"\n",
" • Expansion: 1×1 convolutions for expanding features\n",
" to a higher dimension that is different than that of the\n",
" input features,\n",
"\n",
"\n",
" • Depth-wise Representation: efficient 3×3 depthwise convolutions for learning spatial characteristics to\n",
" minimize computational complexity while preserving\n",
" representational capacity,\n",
"\n",
" • Second-stage Projection: 1×1 convolutions for projecting features back to a lower dimension, and\n",
"\n",
" • Extension: 1×1 convolutions that finally extend channel dimensionality to a higher dimension to produce\n",
" the final features.\n",
" \n",
" # self.first_stage = nn.Conv2d(in_channels = n_input, out_channels=n_input//2, kernel_size=1)\n",
" # self.expansion = nn.Conv2d(in_channels = n_input//2, out_channels=int(3*n_input/4), kernel_size=1)\n",
" # self.dwc = nn.Conv2d(in_channels = int(3*n_input/4), out_channels=int(3*n_input/4), kernel_size=3,groups=int(3*n_input/4))\n",
" # self.second_stage = nn.Conv2d(in_channels = int(3*n_input/4), out_channels=n_input//2, kernel_size=1)\n",
" # self.expansion = nn.Conv2d(in_channels = n_input//2, out_channels=n_out, kernel_size=1)\n",
" self.network = nn.Sequential(nn.Conv2d(in_channels=n_input, out_channels=n_input // 2, kernel_size=1),\n",
"\n",
" nn.Conv2d(in_channels=n_input // 2, out_channels=int(3 * n_input / 4),\n",
" kernel_size=1),\n",
"\n",
" nn.Conv2d(in_channels=int(3 * n_input / 4), out_channels=int(3 * n_input / 4),\n",
" kernel_size=3, groups=int(3 * n_input / 4), padding=1),\n",
"\n",
" nn.Conv2d(in_channels=int(3 * n_input / 4), out_channels=n_input // 2,\n",
" kernel_size=1),\n",
"\n",
" nn.Conv2d(in_channels=n_input // 2, out_channels=n_out, kernel_size=1))\n",
" '''\n",
"\n",
"\n",
" self.network = nn.Sequential(nn.Conv2d(in_channels=n_input, out_channels=n_input // 4, kernel_size=1),\n",
"\n",
" nn.Conv2d(in_channels=n_input // 4, out_channels=n_input // 2,\n",
" kernel_size=1),\n",
"\n",
" nn.Conv2d(in_channels=n_input // 2, out_channels=n_input // 2,\n",
" kernel_size=3, groups=n_input // 2, padding=1),\n",
"\n",
" nn.Conv2d(in_channels=n_input // 2, out_channels=n_input // 4,\n",
" kernel_size=1),\n",
"\n",
" nn.Conv2d(in_channels=n_input // 4, out_channels=n_out, kernel_size=1))\n",
"\n",
" def forward(self, x):\n",
" return self.network(x)\n",
"\n",
"\n",
"class CovidNet(nn.Module):\n",
" def __init__(self, model='small',n_classes=3):\n",
" super(CovidNet, self).__init__()\n",
" filters = {\n",
" 'pexp1_1': [64, 256],\n",
" 'pexp1_2': [256, 256],\n",
" 'pexp1_3': [256, 256],\n",
" 'pexp2_1': [256, 512],\n",
" 'pexp2_2': [512, 512],\n",
" 'pexp2_3': [512, 512],\n",
" 'pexp2_4': [512, 512],\n",
" 'pexp3_1': [512, 1024],\n",
" 'pexp3_2': [1024, 1024],\n",
" 'pexp3_3': [1024, 1024],\n",
" 'pexp3_4': [1024, 1024],\n",
" 'pexp3_5': [1024, 1024],\n",
" 'pexp3_6': [1024, 1024],\n",
" 'pexp4_1': [1024, 2048],\n",
" 'pexp4_2': [2048, 2048],\n",
" 'pexp4_3': [2048, 2048],\n",
" }\n",
"\n",
"\n",
" self.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3))\n",
" for key in filters:\n",
"\n",
" if ('pool' in key):\n",
" self.add_module(key, nn.MaxPool2d(filters[key][0], filters[key][1]))\n",
" else:\n",
" self.add_module(key, PEXP(filters[key][0], filters[key][1]))\n",
"\n",
"\n",
" if(model == 'large'):\n",
" \n",
" self.add_module('conv1_1x1', nn.Conv2d(in_channels=64, out_channels=256, kernel_size=1))\n",
" self.add_module('conv2_1x1', nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1))\n",
" self.add_module('conv3_1x1', nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1))\n",
" self.add_module('conv4_1x1', nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1))\n",
"\n",
" self.__forward__ = self.forward_large_net\n",
" else:\n",
" self.__forward__ = self.forward_small_net\n",
" self.add_module('flatten', Flatten())\n",
" self.add_module('fc1', nn.Linear(7 * 7 * 2048, 1024))\n",
"\n",
" self.add_module('fc2', nn.Linear(1024, 256))\n",
" self.add_module('classifier', nn.Linear(256, n_classes))\n",
"\n",
" def forward(self,x):\n",
" return self.__forward__(x)\n",
"\n",
"\n",
" def forward_large_net(self, x):\n",
" x = F.max_pool2d(F.relu(self.conv1(x)),2)\n",
" out_conv1_1x1 = self.conv1_1x1(x)\n",
"\n",
" pepx11 = self.pexp1_1(x)\n",
" pepx12 = self.pexp1_2(pepx11 + out_conv1_1x1)\n",
" pepx13 = self.pexp1_3(pepx12 + pepx11 + out_conv1_1x1)\n",
"\n",
" out_conv2_1x1 = F.max_pool2d(self.conv2_1x1(pepx12 + pepx11 + pepx13 + out_conv1_1x1),2)\n",
"\n",
" pepx21 = self.pexp2_1(F.max_pool2d(pepx13, 2) + F.max_pool2d(pepx11, 2) + F.max_pool2d(pepx12, 2) + F.max_pool2d(out_conv1_1x1,2))\n",
" pepx22 = self.pexp2_2(pepx21 + out_conv2_1x1)\n",
" pepx23 = self.pexp2_3(pepx22 + pepx21 + out_conv2_1x1)\n",
" pepx24 = self.pexp2_4(pepx23 + pepx21 + pepx22 + out_conv2_1x1)\n",
"\n",
" out_conv3_1x1 = F.max_pool2d(self.conv3_1x1(pepx22 + pepx21 + pepx23 + pepx24 + out_conv2_1x1),2)\n",
"\n",
" pepx31 = self.pexp3_1(F.max_pool2d(pepx24, 2) + F.max_pool2d(pepx21, 2) + F.max_pool2d(pepx22,2) + F.max_pool2d(pepx23, 2) + F.max_pool2d(out_conv2_1x1,2))\n",
" pepx32 = self.pexp3_2(pepx31 + out_conv3_1x1)\n",
" pepx33 = self.pexp3_3(pepx31 + pepx32 + out_conv3_1x1)\n",
" pepx34 = self.pexp3_4(pepx31 + pepx32 + pepx33 + out_conv3_1x1)\n",
" pepx35 = self.pexp3_5(pepx31 + pepx32 + pepx33 + pepx34 + out_conv3_1x1)\n",
" pepx36 = self.pexp3_6(pepx31 + pepx32 + pepx33 + pepx34 + pepx35 + out_conv3_1x1)\n",
"\n",
" out_conv4_1x1 = F.max_pool2d(self.conv4_1x1(pepx31 + pepx32 + pepx33 + pepx34 + pepx35+ pepx36 + out_conv3_1x1),2)\n",
"\n",
" pepx41 = self.pexp4_1(F.max_pool2d(pepx31, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx34, 2)+ F.max_pool2d(pepx35, 2)+ F.max_pool2d(pepx36, 2)+ F.max_pool2d(out_conv3_1x1,2))\n",
" pepx42 = self.pexp4_2(pepx41 + out_conv4_1x1)\n",
" pepx43 = self.pexp4_3(pepx41 + pepx42 + out_conv4_1x1)\n",
" flattened = self.flatten(pepx41 + pepx42 + pepx43 + out_conv4_1x1)\n",
"\n",
" fc1out = F.relu(self.fc1(flattened))\n",
" fc2out = F.relu(self.fc2(fc1out))\n",
" logits = self.classifier(fc2out)\n",
" return logits\n",
"\n",
" def forward_small_net(self, x):\n",
" x = F.max_pool2d(F.relu(self.conv1(x)),2)\n",
"\n",
"\n",
" pepx11 = self.pexp1_1(x)\n",
" pepx12 = self.pexp1_2(pepx11 )\n",
" pepx13 = self.pexp1_3(pepx12 + pepx11 )\n",
"\n",
" \n",
"\n",
" pepx21 = self.pexp2_1(F.max_pool2d(pepx13, 2) + F.max_pool2d(pepx11, 2) + F.max_pool2d(pepx12, 2) )\n",
" pepx22 = self.pexp2_2(pepx21 )\n",
" pepx23 = self.pexp2_3(pepx22 + pepx21)\n",
" pepx24 = self.pexp2_4(pepx23 + pepx21 + pepx22 )\n",
"\n",
" \n",
"\n",
" pepx31 = self.pexp3_1(F.max_pool2d(pepx24, 2) + F.max_pool2d(pepx21, 2) + F.max_pool2d(pepx22,2) + F.max_pool2d(pepx23, 2) )\n",
" pepx32 = self.pexp3_2(pepx31)\n",
" pepx33 = self.pexp3_3(pepx31 + pepx32)\n",
" pepx34 = self.pexp3_4(pepx31 + pepx32 + pepx33)\n",
" pepx35 = self.pexp3_5(pepx31 + pepx32 + pepx33 + pepx34)\n",
" pepx36 = self.pexp3_6(pepx31 + pepx32 + pepx33 + pepx34 + pepx35)\n",
"\n",
"\n",
"\n",
" pepx41 = self.pexp4_1(F.max_pool2d(pepx31, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx34, 2)+ F.max_pool2d(pepx35, 2)+ F.max_pool2d(pepx36, 2))\n",
" pepx42 = self.pexp4_2(pepx41 )\n",
" pepx43 = self.pexp4_3(pepx41 + pepx42)\n",
" flattened = self.flatten(pepx41 + pepx42 + pepx43)\n",
"\n",
" fc1out = F.relu(self.fc1(flattened))\n",
" fc2out = F.relu(self.fc2(fc1out))\n",
" logits = self.classifier(fc2out)\n",
" return logits\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"'''\n",
" FORWARD ONLY WITH SKIP CONNECTIONS\n",
"\n",
" def forward(self, x):\n",
" x = self.pool1(self.conv1(x))\n",
" out_conv1_1x1 = self.conv1_1x1(x)\n",
"\n",
" pepx11 = self.pexp1_1(x)\n",
" pepx12 = self.pexp1_2(pepx11)\n",
" pepx13 = self.pexp1_3(pepx12 + pepx11)\n",
"\n",
" pepx21 = self.pexp2_1(F.max_pool2d(pepx13, 2) + F.max_pool2d(pepx11, 2) + F.max_pool2d(pepx12, 2))\n",
" pepx22 = self.pexp2_2(pepx21)\n",
" pepx23 = self.pexp2_3(pepx22 + pepx21)\n",
" pepx24 = self.pexp2_4(pepx23 + pepx21 + pepx22)\n",
"\n",
" pepx31 = self.pexp3_1(F.max_pool2d(pepx24, 2) + F.max_pool2d(pepx21, 2) + F.max_pool2d(pepx22,2) + F.max_pool2d(pepx23, 2))\n",
" pepx32 = self.pexp3_2(pepx31)\n",
" pepx33 = self.pexp3_3(pepx31 + pepx32)\n",
" pepx34 = self.pexp3_4(pepx31 + pepx32 + pepx33)\n",
" pepx35 = self.pexp3_5(pepx31 + pepx32 + pepx33 + pepx34)\n",
" pepx36 = self.pexp3_6(pepx31 + pepx32 + pepx33 + pepx34 + pepx35)\n",
"\n",
" pepx41 = self.pexp4_1(F.max_pool2d(pepx31, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx34, 2)+ F.max_pool2d(pepx35, 2)+ F.max_pool2d(pepx36, 2))\n",
" pepx42 = self.pexp4_2(pepx41)\n",
" pepx43 = self.pexp4_3(pepx41 + pepx42)\n",
" flattened = self.flatten(pepx41 + pepx42 + pepx43)\n",
"\n",
" fc1out = self.fc1(flattened)\n",
" fc2out = self.fc2(fc1out)\n",
" logits = self.classifier(fc2out)\n",
" return x\n",
"\n",
"\n",
"'''"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lYX88x25vF3h",
"colab_type": "text"
},
"source": [
"## Dataloader\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "0g6lfu1-wtau",
"colab": {}
},
"source": [
"\n",
"\n",
"class COVIDxDataset(Dataset):\n",
" \"\"\"\n",
" Code for reading the COVIDxDataset\n",
" \"\"\"\n",
"\n",
" def __init__(self, mode, n_classes=3, dataset_path='./datasets', dim=(224, 224)):\n",
" self.root = str(dataset_path)+'/'+mode+'/'\n",
" \n",
" \n",
" self.CLASSES = n_classes\n",
" self.dim = dim\n",
" self.COVIDxDICT = {'pneumonia': 0, 'normal': 1, 'COVID-19': 2}\n",
" testfile = '/content/test_split_v2.txt'\n",
" trainfile = '/content/train_split_v2.txt'\n",
" if (mode == 'train'):\n",
" self.paths, self.labels = read_filepaths(trainfile)\n",
" elif (mode == 'test'):\n",
" self.paths, self.labels = read_filepaths(testfile)\n",
" print(\"{} examples = {}\".format(mode,len(self.paths)))\n",
" self.mode = mode\n",
"\n",
" def __len__(self):\n",
" return len(self.paths)\n",
"\n",
" def __getitem__(self, index):\n",
"\n",
" image_tensor = self.load_image(self.root+self.paths[index], self.dim, augmentation=self.mode)\n",
" label_tensor = torch.tensor(self.COVIDxDICT[self.labels[index]],dtype=torch.long)\n",
"\n",
" return image_tensor,label_tensor\n",
"\n",
" def load_image(self, img_path, dim, augmentation='test'):\n",
" if not os.path.exists(img_path):\n",
" print(\"IMAGE DOES NOT EXIST {}\".format(img_path))\n",
" image = Image.open(img_path).convert('RGB') \n",
" image = image.resize(dim).convert('RGB') \n",
" \n",
" #image.convert('RGB')\n",
" t = transforms.ToTensor()\n",
" # print(t(image).shape)\n",
" normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
" norm = transforms.Normalize(mean=[0.5, 0.5,0.5 ],\n",
" std=[1, 1, 1])\n",
"\n",
" image_tensor = normalize(t(image))\n",
" \n",
" # if(image_tensor.size(0)>1):\n",
" # #print(img_path,\" > 1 channels\")\n",
" # image_tensor = image_tensor.mean(dim=0,keepdim=True)\n",
" return image_tensor\n"
],
"execution_count": 0,
"outputs": []
},
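{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"# Minimal smoke test for COVIDxDataset (commented out; a sketch assuming the\n",
"# split files and the data/train folder created earlier in this notebook exist):\n",
"# train_set = COVIDxDataset(mode='train', n_classes=3,\n",
"#                           dataset_path='/content/covid-chestxray-dataset/data')\n",
"# img, label = train_set[0]\n",
"# print(img.shape, label)  # torch.Size([3, 224, 224]) and a class-index tensor\n"
],
"execution_count": 0,
"outputs": []
},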
{
"cell_type": "markdown",
"metadata": {
"id": "ItoCpmm4vf0f",
"colab_type": "text"
},
"source": [
"# Trainer functions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PRkbalc4vhsf",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"def initialize(args):\n",
" if args.device is not None:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.device)\n",
" model = select_model(args)\n",
" \n",
" optimizer = select_optimizer(args,model)\n",
" if (args.cuda):\n",
" model.cuda()\n",
"\n",
" train_params = {'batch_size': args.batch_size,\n",
" 'shuffle': True,\n",
" 'num_workers': 2}\n",
"\n",
" test_params = {'batch_size': args.batch_size,\n",
" 'shuffle': False,\n",
" 'num_workers': 1}\n",
"\n",
" train_loader = COVIDxDataset(mode='train', n_classes=args.classes, dataset_path=args.dataset,\n",
" dim=(224, 224))\n",
" val_loader = COVIDxDataset(mode='test', n_classes=args.classes, dataset_path=args.dataset,\n",
" dim=(224, 224))\n",
" training_generator = DataLoader(train_loader, **train_params)\n",
" val_generator = DataLoader(val_loader, **test_params)\n",
" return model, optimizer,training_generator,val_generator\n",
"\n",
"\n",
"def train(args, model, trainloader, optimizer, epoch):\n",
" model.train()\n",
" criterion = nn.CrossEntropyLoss(reduction='mean')\n",
"\n",
" metrics = Metrics('')\n",
" metrics.reset()\n",
" for batch_idx, input_tensors in enumerate(trainloader):\n",
" optimizer.zero_grad()\n",
" input_data, target = input_tensors\n",
" if (args.cuda):\n",
" input_data = input_data.cuda()\n",
" target = target.cuda()\n",
"\n",
" output = model(input_data)\n",
"\n",
" loss = focal_loss(output, target)\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" correct, total, acc = accuracy(output, target)\n",
"\n",
" num_samples = batch_idx * args.batch_size + 1\n",
" metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc})\n",
" print_stats(args, epoch, num_samples, trainloader, metrics)\n",
"\n",
" print_summary(args, epoch, num_samples, metrics, mode=\"Training\")\n",
" return metrics\n",
"\n",
"\n",
"def validation(args, model, testloader, epoch):\n",
" model.eval()\n",
" criterion = nn.CrossEntropyLoss(reduction='mean')\n",
"\n",
" metrics = Metrics('')\n",
" metrics.reset()\n",
" confusion_matrix = torch.zeros(args.classes, args.classes)\n",
" with torch.no_grad():\n",
" for batch_idx, input_tensors in enumerate(testloader):\n",
"\n",
" input_data, target = input_tensors\n",
" if (args.cuda):\n",
" input_data = input_data.cuda()\n",
" target = target.cuda()\n",
"\n",
" output = model(input_data)\n",
"\n",
" loss = focal_loss(output, target)\n",
"\n",
" correct, total, acc = accuracy(output, target)\n",
" num_samples = batch_idx * args.batch_size + 1\n",
" _, preds = torch.max(output, 1)\n",
" for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):\n",
" confusion_matrix[t.long(), p.long()] += 1\n",
" metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc})\n",
" #print_stats(args, epoch, num_samples, testloader, metrics)\n",
"\n",
" print_summary(args, epoch, num_samples, metrics, mode=\"Validation\")\n",
" return metrics,confusion_matrix\n",
"\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "uz8BF21Lvl6a",
"colab_type": "text"
},
"source": [
"# MAIN"
]
},
{
"cell_type": "code",
"metadata": {
"id": "d9BF-bre0Y_M",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"\n",
"def main():\n",
"\n",
"\n",
"\n",
"\n",
" args = get_arguments()\n",
" SEED = args.seed\n",
" torch.manual_seed(SEED)\n",
" torch.backends.cudnn.deterministic = True\n",
" torch.backends.cudnn.benchmark = False\n",
" np.random.seed(SEED)\n",
" if(args.cuda):\n",
" torch.cuda.manual_seed(SEED)\n",
" model, optimizer,training_generator,val_generator = initialize(args)\n",
" \n",
" print(model)\n",
"\n",
" best_pred_loss = 1000.0\n",
" scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2, min_lr=1e-5, verbose=True)\n",
" print('Checkpoint folder ',args.save)\n",
" #writer = SummaryWriter(log_dir='../runs/' + args.model, comment=args.model)\n",
" for epoch in range(1, args.nEpochs + 1):\n",
" train(args, model, training_generator, optimizer, epoch)\n",
" val_metrics,confusion_matrix = validation(args, model, val_generator, epoch)\n",
" #confusion_matrix = torch.tensor([0.0])\n",
" #val_metrics = Metrics('')\n",
" best_pred_loss = save_model(model,optimizer, args,val_metrics, epoch, best_pred_loss,confusion_matrix)\n",
" #print('avg lpss ' ,val_metrics.avg_loss())\n",
" print(confusion_matrix.cpu().numpy())\n",
" scheduler.step(val_metrics.avg_loss())\n",
" \n",
"\n",
"\n",
"def get_arguments():\n",
" parser = argparse.ArgumentParser()\n",
" parser.add_argument('--batch_size', type=int, default=2)\n",
" parser.add_argument('--log_interval', type=int, default=1000)\n",
" parser.add_argument('--dataset_name', type=str, default=\"COVIDx\")\n",
" parser.add_argument('--nEpochs', type=int, default=250)\n",
" parser.add_argument('--device', type=int, default=0)\n",
" parser.add_argument('--seed', type=int, default=123)\n",
" parser.add_argument('--classes', type=int, default=3)\n",
" parser.add_argument('--inChannels', type=int, default=1)\n",
" parser.add_argument('--lr', default=2e-5, type=float,\n",
" help='learning rate (default: 1e-3)')\n",
" parser.add_argument('--weight_decay', default=1e-7, type=float,\n",
" help='weight decay (default: 1e-6)')\n",
" parser.add_argument('--cuda', action='store_true', default=True)\n",
" parser.add_argument('--resume', default='', type=str, metavar='PATH',\n",
" help='path to latest checkpoint (default: none)')\n",
" parser.add_argument('--model', type=str, default='COVIDNet_large',\n",
" choices=('COVIDNET'))\n",
" parser.add_argument('--opt', type=str, default='adam',\n",
" choices=('sgd', 'adam', 'rmsprop'))\n",
" parser.add_argument('--dataset', type=str, default='/content/covid-chestxray-dataset/data',\n",
" help='path to dataset ')\n",
" parser.add_argument('--save', type=str, default='/content/drive/My Drive/MEDICAL/saved/COVIDNet_small'+datestr() ,\n",
" help='path to checkpoint ')\n",
" args = parser.parse_args([])\n",
" return args\n",
"\n",
"\n",
"if __name__ == '__main__':\n",
" main()\n"
],
"execution_count": 0,
"outputs": []
}
]
}