{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "6a5b65df", | |
"metadata": {}, | |
"source": [ | |
"# LEVIR-CD+ change detection example notebook\n", | |
"\n", | |
"We start off by installing torchgeo. If you are running this on Colab, then you will need to restart your runtime after this step." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4627b902", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!pip install torchgeo" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "475f3715", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"\n", | |
"import torchgeo\n", | |
"from torchgeo.datasets import LEVIRCDPlus\n", | |
"from torchgeo.datasets.utils import unbind_samples\n", | |
"from torchgeo.trainers import SemanticSegmentationTask\n", | |
"from torchgeo.datamodules.utils import dataset_split\n", | |
"\n", | |
"import lightning.pytorch as pl\n", | |
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", | |
"from lightning.pytorch import Trainer, seed_everything\n", | |
"from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger\n", | |
"from lightning.pytorch import LightningDataModule\n", | |
"\n", | |
"import torch\n", | |
"from torch.utils.data import DataLoader\n", | |
"import kornia.augmentation as K\n", | |
"\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import torchvision\n", | |
"from torchvision.transforms import Compose\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"from sklearn.metrics import precision_score, recall_score" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "2ae75c6f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('0.5.1', '2.1.3', '2.0.1+cu117')" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torchgeo.__version__, pl.__version__, torch.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "daedd8ce", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.cuda.is_available()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "0b012728", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# some experiment parameters\n", | |
"\n", | |
"experiment_name = \"experiment_test\"\n", | |
"experiment_dir = f\"results/{experiment_name}\"\n", | |
"os.makedirs(experiment_dir, exist_ok=True)\n", | |
"\n", | |
"batch_size = 8\n", | |
"learning_rate = 0.0001\n", | |
"gpu_id = 0\n", | |
"device = torch.device(f\"cuda:{gpu_id}\")\n", | |
"num_dataloader_workers = 2\n", | |
"patch_size = 256\n", | |
"val_split_pct = 0.1 # how much of our training set to hold out as a validation set" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "ca211445", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Files already downloaded and verified\n", | |
"Files already downloaded and verified\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(637, 348)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Download the dataset and see how many images are in the train and test splits\n", | |
"\n", | |
"train_dataset = LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=\"train\", download=True, checksum=True)\n", | |
"test_dataset = LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=\"test\", download=True, checksum=True)\n", | |
"len(train_dataset), len(test_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8d7e6981", | |
"metadata": {}, | |
"source": [ | |
"## Exercise 1\n", | |
"\n", | |
"Plot some examples from the `train_dataset` (note: torchgeo will help you out here)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8127d129", | |
"metadata": {}, | |
"source": [ | |
"## Define a PyTorch Lightning module and datamodule\n", | |
"\n", | |
"PyTorch Lightning organizes the steps required for training deep learning models in `LightningModules`, and organizes dataset handling, from loading the data through creating dataloaders, in `LightningDataModules`. TorchGeo provides pre-built LightningDataModules for a handful of datasets, and pre-built \"trainers\" (i.e. LightningModules) for a variety of tasks.\n", | |
"\n", | |
"For this tutorial, we will lightly extend TorchGeo's `SemanticSegmentationTask` (just to add some custom plotting code) and create a new LightningDataModule for the LEVIR-CD+ dataset." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "26f62ac5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n", | |
" \n", | |
" def plot(self, sample):\n", | |
" image1 = sample[\"image\"][:3]\n", | |
" image2 = sample[\"image\"][3:]\n", | |
" mask = sample[\"mask\"]\n", | |
" prediction = sample[\"prediction\"]\n", | |
"\n", | |
" fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(4 * 5, 5))\n", | |
" axs[0].imshow(image1.permute(1, 2, 0))\n", | |
" axs[0].axis(\"off\")\n", | |
" axs[1].imshow(image2.permute(1, 2, 0))\n", | |
" axs[1].axis(\"off\")\n", | |
" axs[2].imshow(mask)\n", | |
" axs[2].axis(\"off\")\n", | |
" axs[3].imshow(prediction)\n", | |
" axs[3].axis(\"off\")\n", | |
"\n", | |
" axs[0].set_title(\"Image 1\")\n", | |
" axs[1].set_title(\"Image 2\")\n", | |
" axs[2].set_title(\"Mask\")\n", | |
" axs[3].set_title(\"Prediction\")\n", | |
"\n", | |
" plt.tight_layout()\n", | |
" \n", | |
" return fig\n", | |
"\n", | |
" # The only difference between this code and the same from SemanticSegmentationTask is our redirect to use our own plotting function\n", | |
" def training_step(self, *args, **kwargs):\n", | |
" batch = args[0]\n", | |
" batch_idx = args[1]\n", | |
" \n", | |
" x = batch[\"image\"]\n", | |
" y = batch[\"mask\"]\n", | |
" y_hat = self.forward(x)\n", | |
" y_hat_hard = y_hat.argmax(dim=1)\n", | |
"\n", | |
" loss = self.criterion(y_hat, y)\n", | |
"\n", | |
" self.log(\"train_loss\", loss, on_step=True, on_epoch=False)\n", | |
" self.train_metrics(y_hat_hard, y)\n", | |
"\n", | |
" if batch_idx < 10:\n", | |
" batch[\"prediction\"] = y_hat_hard\n", | |
" for key in [\"image\", \"mask\", \"prediction\"]:\n", | |
" batch[key] = batch[key].cpu()\n", | |
" sample = unbind_samples(batch)[0]\n", | |
" fig = self.plot(sample)\n", | |
" summary_writer = self.logger.experiment\n", | |
" summary_writer.add_figure(\n", | |
" f\"image/train/{batch_idx}\", fig, global_step=self.global_step\n", | |
" )\n", | |
" plt.close()\n", | |
" \n", | |
" return loss\n", | |
" \n", | |
" # The only difference between this code and the same from SemanticSegmentationTask is our redirect to use our own plotting function\n", | |
" def validation_step(self, *args, **kwargs):\n", | |
" batch = args[0]\n", | |
" batch_idx = args[1]\n", | |
" x = batch[\"image\"]\n", | |
" y = batch[\"mask\"]\n", | |
" y_hat = self.forward(x)\n", | |
" y_hat_hard = y_hat.argmax(dim=1)\n", | |
"\n", | |
" loss = self.criterion(y_hat, y)\n", | |
"\n", | |
" self.log(\"val_loss\", loss, on_step=False, on_epoch=True)\n", | |
" self.val_metrics(y_hat_hard, y)\n", | |
"\n", | |
" if batch_idx < 10:\n", | |
" batch[\"prediction\"] = y_hat_hard\n", | |
" for key in [\"image\", \"mask\", \"prediction\"]:\n", | |
" batch[key] = batch[key].cpu()\n", | |
" sample = unbind_samples(batch)[0]\n", | |
" fig = self.plot(sample)\n", | |
" summary_writer = self.logger.experiment\n", | |
" summary_writer.add_figure(\n", | |
" f\"image/val/{batch_idx}\", fig, global_step=self.global_step\n", | |
" )\n", | |
" plt.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "f420887f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class LEVIRCDPlusDataModule(pl.LightningDataModule):\n", | |
"\n", | |
" def __init__(\n", | |
" self,\n", | |
" batch_size=32,\n", | |
" num_workers=0,\n", | |
" val_split_pct=0.2,\n", | |
" patch_size=(256, 256),\n", | |
" **kwargs,\n", | |
" ):\n", | |
" super().__init__()\n", | |
" self.batch_size = batch_size\n", | |
" self.num_workers = num_workers\n", | |
" self.val_split_pct = val_split_pct\n", | |
" self.patch_size = patch_size\n", | |
" self.kwargs = kwargs\n", | |
"\n", | |
" def on_after_batch_transfer(\n", | |
" self, batch, batch_idx\n", | |
" ):\n", | |
" if (\n", | |
" hasattr(self, \"trainer\")\n", | |
" and self.trainer is not None\n", | |
" and hasattr(self.trainer, \"training\")\n", | |
" and self.trainer.training\n", | |
" ):\n", | |
" # Kornia expects masks to be floats with a channel dimension\n", | |
" x = batch[\"image\"]\n", | |
" y = batch[\"mask\"].float().unsqueeze(1)\n", | |
"\n", | |
" train_augmentations = K.AugmentationSequential(\n", | |
" K.RandomRotation(p=0.5, degrees=90),\n", | |
" K.RandomHorizontalFlip(p=0.5),\n", | |
" K.RandomVerticalFlip(p=0.5),\n", | |
" K.RandomCrop(self.patch_size),\n", | |
" K.RandomSharpness(p=0.5),\n", | |
" data_keys=[\"input\", \"mask\"],\n", | |
" )\n", | |
" x, y = train_augmentations(x, y)\n", | |
"\n", | |
" # torchmetrics expects masks to be longs without a channel dimension\n", | |
" batch[\"image\"] = x\n", | |
" batch[\"mask\"] = y.squeeze(1).long()\n", | |
"\n", | |
" return batch\n", | |
" \n", | |
" def preprocess(self, sample):\n", | |
" sample[\"image\"] = (sample[\"image\"] / 255.0).float()\n", | |
" sample[\"image\"] = torch.flatten(sample[\"image\"], 0, 1)\n", | |
" sample[\"mask\"] = sample[\"mask\"].long()\n", | |
" return sample\n", | |
"\n", | |
" def prepare_data(self):\n", | |
" LEVIRCDPlus(split=\"train\", **self.kwargs)\n", | |
"\n", | |
" def setup(self, stage=None):\n", | |
" train_transforms = Compose([self.preprocess])\n", | |
" test_transforms = Compose([self.preprocess])\n", | |
"\n", | |
" train_dataset = LEVIRCDPlus(\n", | |
" split=\"train\", transforms=train_transforms, **self.kwargs\n", | |
" )\n", | |
"\n", | |
" if self.val_split_pct > 0.0:\n", | |
" self.train_dataset, self.val_dataset, _ = dataset_split(\n", | |
" train_dataset, val_pct=self.val_split_pct, test_pct=0.0\n", | |
" )\n", | |
" else:\n", | |
" self.train_dataset = train_dataset\n", | |
" self.val_dataset = train_dataset\n", | |
"\n", | |
" self.test_dataset = LEVIRCDPlus(\n", | |
" split=\"test\", transforms=test_transforms, **self.kwargs\n", | |
" )\n", | |
"\n", | |
" def train_dataloader(self):\n", | |
" return DataLoader(\n", | |
" self.train_dataset,\n", | |
" batch_size=self.batch_size,\n", | |
" num_workers=self.num_workers,\n", | |
" shuffle=True,\n", | |
" )\n", | |
"\n", | |
" def val_dataloader(self):\n", | |
" return DataLoader(\n", | |
" self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False\n", | |
" )\n", | |
"\n", | |
" def test_dataloader(self):\n", | |
" return DataLoader(\n", | |
" self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d221e5db", | |
"metadata": {}, | |
"source": [ | |
"## Setting up a training run" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "97a5ff80", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"datamodule = LEVIRCDPlusDataModule(\n", | |
" root=\"data/LEVIRCDPlus\",\n", | |
" batch_size=batch_size,\n", | |
" num_workers=num_dataloader_workers,\n", | |
" val_split_pct=val_split_pct,\n", | |
" patch_size=(patch_size, patch_size),\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "82b472f5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"task = CustomSemanticSegmentationTask(\n", | |
" model=\"unet\",\n", | |
" backbone=\"resnet18\",\n", | |
" weights=True,\n", | |
" in_channels=6,\n", | |
" num_classes=2,\n", | |
" loss=\"ce\",\n", | |
" ignore_index=None,\n", | |
" lr=learning_rate,\n", | |
" patience=10\n", | |
")\n", | |
"\n", | |
"checkpoint_callback = ModelCheckpoint(\n", | |
" monitor=\"val_loss\",\n", | |
" dirpath=experiment_dir,\n", | |
" save_top_k=1,\n", | |
" save_last=True,\n", | |
")\n", | |
"\n", | |
"early_stopping_callback = EarlyStopping(\n", | |
" monitor=\"val_loss\",\n", | |
" min_delta=0.00,\n", | |
" patience=10,\n", | |
")\n", | |
"\n", | |
"tb_logger = TensorBoardLogger(\n", | |
" save_dir=\"logs/\",\n", | |
" name=experiment_name\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e54642fd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%load_ext tensorboard" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "94fe9c6d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%tensorboard --logdir logs/" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "6fc5259c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"GPU available: True (cuda), used: True\n", | |
"TPU available: False, using: 0 TPU cores\n", | |
"IPU available: False, using: 0 IPUs\n", | |
"HPU available: False, using: 0 HPUs\n", | |
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", | |
"\n", | |
" | Name | Type | Params\n", | |
"---------------------------------------------------\n", | |
"0 | model | Unet | 14.3 M\n", | |
"1 | loss | CrossEntropyLoss | 0 \n", | |
"2 | train_metrics | MetricCollection | 0 \n", | |
"3 | val_metrics | MetricCollection | 0 \n", | |
"4 | test_metrics | MetricCollection | 0 \n", | |
"---------------------------------------------------\n", | |
"14.3 M Trainable params\n", | |
"0 Non-trainable params\n", | |
"14.3 M Total params\n", | |
"57.351 Total estimated model params size (MB)\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Sanity Checking: 0it [00:00, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/calebrobinson/.conda/envs/test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (18) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", | |
" rank_zero_warn(\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a15abdbc468b44d0b1a43a18e285ae95", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Training: 0it [00:00, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Validation: 0it [00:00, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f2c8514f2d4a41e09b3bc88af1b40887", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Validation: 0it [00:00, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"trainer = pl.Trainer(\n", | |
" callbacks=[checkpoint_callback, early_stopping_callback],\n", | |
" logger=[tb_logger],\n", | |
" default_root_dir=experiment_dir,\n", | |
" min_epochs=10,\n", | |
" max_epochs=200,\n", | |
" accelerator='gpu',\n", | |
" devices=[gpu_id]\n", | |
")\n", | |
"\n", | |
"_ = trainer.fit(model=task, datamodule=datamodule)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2cfacd81", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainer.test(model=task, datamodule=datamodule)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "346e4afe", | |
"metadata": {}, | |
"source": [ | |
"## Custom test step to compute the precision, recall, and F1 metrics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "b61db9fb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Example of how to load a trained task from a checkpoint file\n", | |
"# task = CustomSemanticSegmentationTask.load_from_checkpoint(\"results/...\")\n", | |
"# datamodule.setup(\"test\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "c9b7a93c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = task.model.to(device).eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "0e545e06", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 44/44 [00:21<00:00, 2.04it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
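"# Subsample every 500th pixel ([::500]) of each batch's masks and predictions before computing metrics\n", | |
"y_preds = []\n", | |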
"y_trues = []\n", | |
"for batch in tqdm(datamodule.test_dataloader()):\n", | |
" images = batch[\"image\"].to(device)\n", | |
" y_trues.append(batch[\"mask\"].numpy().ravel()[::500])\n", | |
" with torch.inference_mode():\n", | |
" y_pred = model(images).argmax(dim=1).cpu().numpy().ravel()[::500]\n", | |
" y_preds.append(y_pred)\n", | |
"\n", | |
"y_preds = np.concatenate(y_preds)\n", | |
"y_trues = np.concatenate(y_trues)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "8b5a6975", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"precision = precision_score(y_trues, y_preds)\n", | |
"recall = recall_score(y_trues, y_preds)\n", | |
"f1 = 2 * (precision * recall) / (precision + recall)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "bf25b1d4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.7234695667426767, 0.5552638664512655, 0.6283037550460812)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"precision, recall, f1" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "python3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
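The last cells of the notebook compute precision, recall, and F1 with scikit-learn on the subsampled pixels. As a rough sketch of the same arithmetic in plain Python, assuming binary labels where 1 means "change", the function below counts true positives, false positives, and false negatives directly. `precision_recall_f1` is a hypothetical helper for illustration, not part of torchgeo or scikit-learn; unlike the notebook's F1 line, it returns 0.0 instead of dividing by zero when no positives are present.

```python
def precision_recall_f1(y_true, y_pred, positive=1):
    """Compute precision, recall and F1 for one positive class from two label sequences."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)

    # Guard every division so an empty positive class yields 0.0 rather than an error.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1


# Toy labels only; these are not results from the notebook.
print(precision_recall_f1([1, 1, 0, 0], [1, 0, 0, 0]))
```

For the notebook's arrays, the call would presumably be `precision_recall_f1(y_trues.tolist(), y_preds.tolist())`.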
Then on running trainer.fit you will get:
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `CustomSemanticSegmentationTask`
It is necessary to use the new lightning format for imports:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch import LightningDataModule
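One way to check for the namespace mix-up behind that TypeError is to look at which Lightning package each object's class hierarchy comes from. This is only a diagnostic sketch: `lightning_packages` is a hypothetical helper, not part of Lightning or torchgeo, and the stand-in classes below merely imitate the two module paths so the snippet runs without Lightning installed.

```python
def lightning_packages(obj):
    """Return the Lightning top-level packages found in obj's class hierarchy."""
    known = {"lightning", "pytorch_lightning"}
    roots = {cls.__module__.split(".")[0] for cls in type(obj).__mro__}
    return sorted(roots & known)


# Stand-in classes that only imitate the two import paths (assumption for the demo).
OldModule = type("OldModule", (), {"__module__": "pytorch_lightning.core.module"})
NewModule = type("NewModule", (), {"__module__": "lightning.pytorch.core.module"})
OldTask = type("OldTask", (OldModule,), {"__module__": "my_notebook"})
NewTask = type("NewTask", (NewModule,), {"__module__": "my_notebook"})

print(lightning_packages(OldTask()))  # ['pytorch_lightning']
print(lightning_packages(NewTask()))  # ['lightning']
```

In the real notebook, calling it on both `task` and `trainer` should give the same single package; two different answers would point to mixed imports.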
You will next get:
AttributeError: 'CustomSemanticSegmentationTask' object has no attribute 'loss'
It is necessary in the custom trainer to use:
loss: Tensor = self.criterion(y_hat, y)
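If the same custom task has to run against both the older attribute name (`loss`) and the newer one (`criterion`), a small lookup can hide the difference. This is a sketch under that assumption only; `get_loss_fn` and the two dummy task classes are hypothetical and exist purely for illustration.

```python
def get_loss_fn(task):
    """Return the task's loss callable, whichever attribute name it is stored under."""
    for name in ("criterion", "loss"):
        fn = getattr(task, name, None)
        # Skip non-callables, e.g. a plain string such as "ce" kept as a setting.
        if callable(fn):
            return fn
    raise AttributeError("task has neither a callable 'criterion' nor 'loss'")


class NewStyleTask:
    """Dummy task imitating the newer attribute name."""
    def __init__(self):
        self.criterion = lambda y_hat, y: abs(y_hat - y)


class OldStyleTask:
    """Dummy task imitating the older attribute name."""
    def __init__(self):
        self.loss = lambda y_hat, y: abs(y_hat - y)


print(get_loss_fn(NewStyleTask())(3, 1))  # 2
print(get_loss_fn(OldStyleTask())(3, 1))  # 2
```

Inside `training_step` and `validation_step` this would replace the direct attribute access with `get_loss_fn(self)(y_hat, y)`.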
| Name | Type | Params
0 | criterion | CrossEntropyLoss | 0
1 | train_metrics | MetricCollection | 0
2 | val_metrics | MetricCollection | 0
3 | test_metrics | MetricCollection | 0
4 | model | Unet | 14.3 M
14.3 M Trainable params
0 Non-trainable params
14.3 M Total params
57.351 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]
It stays stuck at this stage. What could be the problem?
Hey @mustafaemre2 -- are you running on the GPU?
Updated with @robmarkcole's fixes (and ensured that the notebook runs end-to-end) for torchgeo 0.5.1 (thanks Robin!)
Yes, I used your code exactly.
My GPU is an RTX 3060 Laptop, @calebrob6.
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high')
which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
I get this message, @calebrob6.
# Download the dataset and see how many images are in the train and test splits
train_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="train", download=True, checksum=True)
test_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="test", download=True, checksum=True)
len(train_dataset), len(test_dataset)
It gives this error:
RuntimeError: The MD5 checksum of the download file data/LEVIRCDPlus/LEVIR-CD+.zip does not match the one on record.Please delete the file and try again. If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues.
I get this message, @calebrob6.
@nadeem-git-coder Have you found any solution for that?
@ProtikBose I downloaded the dataset manually and used that.
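When working from a manually downloaded archive, it can help to compute the file's MD5 locally and compare it with whatever value you trust. The sketch below uses only the standard library; `file_md5` is a hypothetical helper, and the path in the usage note is simply the one from the error message above.

```python
import hashlib


def file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in chunks to limit memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops once read() returns an empty bytes object.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, `file_md5("data/LEVIRCDPlus/LEVIR-CD+.zip")` gives the digest of the archive named in the error. A value that changes between downloads would suggest a truncated or corrupted file rather than a problem in the notebook.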
Have you encountered the error?
What is the error?
Running with torchgeo 0.5.0 will give:
The necessary update: