@calebrob6
Last active March 7, 2024 00:20
{
"cells": [
{
"cell_type": "markdown",
"id": "6a5b65df",
"metadata": {},
"source": [
"# LEVIR-CD+ change detection example notebook\n",
"\n",
"We start off by installing torchgeo. If you are running this on Colab, then you will need to restart your runtime after this step."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4627b902",
"metadata": {},
"outputs": [],
"source": [
"!pip install torchgeo"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "475f3715",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torchgeo\n",
"from torchgeo.datasets import LEVIRCDPlus\n",
"from torchgeo.datasets.utils import unbind_samples\n",
"from torchgeo.trainers import SemanticSegmentationTask\n",
"from torchgeo.datamodules.utils import dataset_split\n",
"\n",
"import lightning.pytorch as pl\n",
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.pytorch import Trainer, seed_everything\n",
"from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger\n",
"from lightning.pytorch import LightningDataModule\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import kornia.augmentation as K\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torchvision\n",
"from torchvision.transforms import Compose\n",
"from tqdm import tqdm\n",
"\n",
"from sklearn.metrics import precision_score, recall_score"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2ae75c6f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('0.5.1', '2.1.3', '2.0.1+cu117')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torchgeo.__version__, pl.__version__, torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "daedd8ce",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0b012728",
"metadata": {},
"outputs": [],
"source": [
"# some experiment parameters\n",
"\n",
"experiment_name = \"experiment_test\"\n",
"experiment_dir = f\"results/{experiment_name}\"\n",
"os.makedirs(experiment_dir, exist_ok=True)\n",
"\n",
"batch_size = 8\n",
"learning_rate = 0.0001\n",
"gpu_id = 0\n",
"device = torch.device(f\"cuda:{gpu_id}\")\n",
"num_dataloader_workers = 2\n",
"patch_size = 256\n",
"val_split_pct = 0.1 # how much of our training set to hold out as a validation set"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ca211445",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
},
{
"data": {
"text/plain": [
"(637, 348)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Download the dataset and see how many images are in the train and test splits\n",
"\n",
"train_dataset = LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=\"train\", download=True, checksum=True)\n",
"test_dataset = LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=\"test\", download=True, checksum=True)\n",
"len(train_dataset), len(test_dataset)"
]
},
{
"cell_type": "markdown",
"id": "8d7e6981",
"metadata": {},
"source": [
"## Excersise 1\n",
"\n",
"Plot some examples from the `train_dataset` (note: torchgeo will help you out here)."
]
},
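{
"cell_type": "markdown",
"id": "3f1c2a10",
"metadata": {},
"source": [
"A minimal solution sketch, assuming the `plot` method that `LEVIRCDPlus` exposes in torchgeo 0.5.x (it draws the two images and the change mask for a sample):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f1c2a11",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: relies on LEVIRCDPlus.plot (present in torchgeo 0.5.x)\n",
"for i in range(3):\n",
" sample = train_dataset[i]\n",
" train_dataset.plot(sample)\n",
" plt.show()"
]
},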
{
"cell_type": "markdown",
"id": "8127d129",
"metadata": {},
"source": [
"## Define a PyTorch Lightning module and datamodule\n",
"\n",
"PyTorch Lightning organizes the steps required for training deep learning models in `LightningModules`, and organizes the dataset handling to creating dataloaders in `LightningDataModules`. TorchGeo provides pre-built LightningDataModules for a handful of datasets, and pre-built \"trainers\" (i.e. LightningModules) for a variety of different types of tasks.\n",
"\n",
"For this tutorial, we will lightly extend TorchGeo's `SemanticSegmentationTask` (just to add some custom plotting code) and create a new LightningDataModule for the LEVIR-CD+ dataset."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "26f62ac5",
"metadata": {},
"outputs": [],
"source": [
"class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n",
" \n",
" def plot(self, sample):\n",
" image1 = sample[\"image\"][:3]\n",
" image2 = sample[\"image\"][3:]\n",
" mask = sample[\"mask\"]\n",
" prediction = sample[\"prediction\"]\n",
"\n",
" fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(4 * 5, 5))\n",
" axs[0].imshow(image1.permute(1, 2, 0))\n",
" axs[0].axis(\"off\")\n",
" axs[1].imshow(image2.permute(1, 2, 0))\n",
" axs[1].axis(\"off\")\n",
" axs[2].imshow(mask)\n",
" axs[2].axis(\"off\")\n",
" axs[3].imshow(prediction)\n",
" axs[3].axis(\"off\")\n",
"\n",
" axs[0].set_title(\"Image 1\")\n",
" axs[1].set_title(\"Image 2\")\n",
" axs[2].set_title(\"Mask\")\n",
" axs[3].set_title(\"Prediction\")\n",
"\n",
" plt.tight_layout()\n",
" \n",
" return fig\n",
"\n",
" # The only difference between this code and the same from SemanticSegmentationTask is our redirect to use our own plotting function\n",
" def training_step(self, *args, **kwargs):\n",
" batch = args[0]\n",
" batch_idx = args[1]\n",
" \n",
" x = batch[\"image\"]\n",
" y = batch[\"mask\"]\n",
" y_hat = self.forward(x)\n",
" y_hat_hard = y_hat.argmax(dim=1)\n",
"\n",
" loss = self.criterion(y_hat, y)\n",
"\n",
" self.log(\"train_loss\", loss, on_step=True, on_epoch=False)\n",
" self.train_metrics(y_hat_hard, y)\n",
"\n",
" if batch_idx < 10:\n",
" batch[\"prediction\"] = y_hat_hard\n",
" for key in [\"image\", \"mask\", \"prediction\"]:\n",
" batch[key] = batch[key].cpu()\n",
" sample = unbind_samples(batch)[0]\n",
" fig = self.plot(sample)\n",
" summary_writer = self.logger.experiment\n",
" summary_writer.add_figure(\n",
" f\"image/train/{batch_idx}\", fig, global_step=self.global_step\n",
" )\n",
" plt.close()\n",
" \n",
" return loss\n",
" \n",
" # The only difference between this code and the same from SemanticSegmentationTask is our redirect to use our own plotting function\n",
" def validation_step(self, *args, **kwargs):\n",
" batch = args[0]\n",
" batch_idx = args[1]\n",
" x = batch[\"image\"]\n",
" y = batch[\"mask\"]\n",
" y_hat = self.forward(x)\n",
" y_hat_hard = y_hat.argmax(dim=1)\n",
"\n",
" loss = self.criterion(y_hat, y)\n",
"\n",
" self.log(\"val_loss\", loss, on_step=False, on_epoch=True)\n",
" self.val_metrics(y_hat_hard, y)\n",
"\n",
" if batch_idx < 10:\n",
" batch[\"prediction\"] = y_hat_hard\n",
" for key in [\"image\", \"mask\", \"prediction\"]:\n",
" batch[key] = batch[key].cpu()\n",
" sample = unbind_samples(batch)[0]\n",
" fig = self.plot(sample)\n",
" summary_writer = self.logger.experiment\n",
" summary_writer.add_figure(\n",
" f\"image/val/{batch_idx}\", fig, global_step=self.global_step\n",
" )\n",
" plt.close()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f420887f",
"metadata": {},
"outputs": [],
"source": [
"class LEVIRCDPlusDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(\n",
" self,\n",
" batch_size=32,\n",
" num_workers=0,\n",
" val_split_pct=0.2,\n",
" patch_size=(256, 256),\n",
" **kwargs,\n",
" ):\n",
" super().__init__()\n",
" self.batch_size = batch_size\n",
" self.num_workers = num_workers\n",
" self.val_split_pct = val_split_pct\n",
" self.patch_size = patch_size\n",
" self.kwargs = kwargs\n",
"\n",
" def on_after_batch_transfer(\n",
" self, batch, batch_idx\n",
" ):\n",
" if (\n",
" hasattr(self, \"trainer\")\n",
" and self.trainer is not None\n",
" and hasattr(self.trainer, \"training\")\n",
" and self.trainer.training\n",
" ):\n",
" # Kornia expects masks to be floats with a channel dimension\n",
" x = batch[\"image\"]\n",
" y = batch[\"mask\"].float().unsqueeze(1)\n",
"\n",
" train_augmentations = K.AugmentationSequential(\n",
" K.RandomRotation(p=0.5, degrees=90),\n",
" K.RandomHorizontalFlip(p=0.5),\n",
" K.RandomVerticalFlip(p=0.5),\n",
" K.RandomCrop(self.patch_size),\n",
" K.RandomSharpness(p=0.5),\n",
" data_keys=[\"input\", \"mask\"],\n",
" )\n",
" x, y = train_augmentations(x, y)\n",
"\n",
" # torchmetrics expects masks to be longs without a channel dimension\n",
" batch[\"image\"] = x\n",
" batch[\"mask\"] = y.squeeze(1).long()\n",
"\n",
" return batch\n",
" \n",
" def preprocess(self, sample):\n",
" sample[\"image\"] = (sample[\"image\"] / 255.0).float()\n",
" sample[\"image\"] = torch.flatten(sample[\"image\"], 0, 1)\n",
" sample[\"mask\"] = sample[\"mask\"].long()\n",
" return sample\n",
"\n",
" def prepare_data(self):\n",
" LEVIRCDPlus(split=\"train\", **self.kwargs)\n",
"\n",
" def setup(self, stage=None):\n",
" train_transforms = Compose([self.preprocess])\n",
" test_transforms = Compose([self.preprocess])\n",
"\n",
" train_dataset = LEVIRCDPlus(\n",
" split=\"train\", transforms=train_transforms, **self.kwargs\n",
" )\n",
"\n",
" if self.val_split_pct > 0.0:\n",
" self.train_dataset, self.val_dataset, _ = dataset_split(\n",
" train_dataset, val_pct=self.val_split_pct, test_pct=0.0\n",
" )\n",
" else:\n",
" self.train_dataset = train_dataset\n",
" self.val_dataset = train_dataset\n",
"\n",
" self.test_dataset = LEVIRCDPlus(\n",
" split=\"test\", transforms=test_transforms, **self.kwargs\n",
" )\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(\n",
" self.train_dataset,\n",
" batch_size=self.batch_size,\n",
" num_workers=self.num_workers,\n",
" shuffle=True,\n",
" )\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(\n",
" self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False\n",
" )\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(\n",
" self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "d221e5db",
"metadata": {},
"source": [
"## Setting up a training run"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "97a5ff80",
"metadata": {},
"outputs": [],
"source": [
"datamodule = LEVIRCDPlusDataModule(\n",
" root=\"data/LEVIRCDPlus\",\n",
" batch_size=batch_size,\n",
" num_workers=num_dataloader_workers,\n",
" val_split_pct=val_split_pct,\n",
" patch_size=(patch_size, patch_size),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "82b472f5",
"metadata": {},
"outputs": [],
"source": [
"task = CustomSemanticSegmentationTask(\n",
" model=\"unet\",\n",
" backbone=\"resnet18\",\n",
" weights=True,\n",
" in_channels=6,\n",
" num_classes=2,\n",
" loss=\"ce\",\n",
" ignore_index=None,\n",
" lr=learning_rate,\n",
" patience=10\n",
")\n",
"\n",
"checkpoint_callback = ModelCheckpoint(\n",
" monitor=\"val_loss\",\n",
" dirpath=experiment_dir,\n",
" save_top_k=1,\n",
" save_last=True,\n",
")\n",
"\n",
"early_stopping_callback = EarlyStopping(\n",
" monitor=\"val_loss\",\n",
" min_delta=0.00,\n",
" patience=10,\n",
")\n",
"\n",
"tb_logger = TensorBoardLogger(\n",
" save_dir=\"logs/\",\n",
" name=experiment_name\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e54642fd",
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94fe9c6d",
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir logs/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fc5259c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
"\n",
" | Name | Type | Params\n",
"---------------------------------------------------\n",
"0 | model | Unet | 14.3 M\n",
"1 | loss | CrossEntropyLoss | 0 \n",
"2 | train_metrics | MetricCollection | 0 \n",
"3 | val_metrics | MetricCollection | 0 \n",
"4 | test_metrics | MetricCollection | 0 \n",
"---------------------------------------------------\n",
"14.3 M Trainable params\n",
"0 Non-trainable params\n",
"14.3 M Total params\n",
"57.351 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/calebrobinson/.conda/envs/test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (18) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
" rank_zero_warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a15abdbc468b44d0b1a43a18e285ae95",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f2c8514f2d4a41e09b3bc88af1b40887",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer = pl.Trainer(\n",
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
" logger=[tb_logger],\n",
" default_root_dir=experiment_dir,\n",
" min_epochs=10,\n",
" max_epochs=200,\n",
" accelerator='gpu',\n",
" devices=[gpu_id]\n",
")\n",
"\n",
"_ = trainer.fit(model=task, datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cfacd81",
"metadata": {},
"outputs": [],
"source": [
"trainer.test(model=task, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"id": "346e4afe",
"metadata": {},
"source": [
"## Custom test step to compute the precision, recall, and F1 metrics"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b61db9fb",
"metadata": {},
"outputs": [],
"source": [
"# Example of how to load a trained task from a checkpoint file\n",
"# task = CustomSemanticSegmentationTask.load_from_checkpoint(\"results/...\")\n",
"# datamodule.setup(\"test\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c9b7a93c",
"metadata": {},
"outputs": [],
"source": [
"model = task.model.to(device).eval()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0e545e06",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 44/44 [00:21<00:00, 2.04it/s]\n"
]
}
],
"source": [
"y_preds = []\n",
"y_trues = []\n",
"for batch in tqdm(datamodule.test_dataloader()):\n",
" images = batch[\"image\"].to(device)\n",
" y_trues.append(batch[\"mask\"].numpy().ravel()[::500])\n",
" with torch.inference_mode():\n",
" y_pred = model(images).argmax(dim=1).cpu().numpy().ravel()[::500]\n",
" y_preds.append(y_pred)\n",
"\n",
"y_preds = np.concatenate(y_preds)\n",
"y_trues = np.concatenate(y_trues)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8b5a6975",
"metadata": {},
"outputs": [],
"source": [
"precision = precision_score(y_trues, y_preds)\n",
"recall = recall_score(y_trues, y_preds)\n",
"f1 = 2 * (precision * recall) / (precision + recall)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bf25b1d4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.7234695667426767, 0.5552638664512655, 0.6283037550460812)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"precision, recall, f1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@robmarkcole

Running with torchgeo 0.5.0 will give:

TypeError: SemanticSegmentationTask.__init__() got an unexpected keyword argument 'segmentation_model'

The necessary update:

task = CustomSemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=True,
    in_channels=6,
    num_classes=2,
    loss="ce",
    ignore_index=None,
    lr=learning_rate,
    patience=10
)

@robmarkcole

Then on running trainer.fit you will get

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `CustomSemanticSegmentationTask`

It is necessary to use the new lightning format for imports:

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch import LightningDataModule

@robmarkcole

You will next get:

AttributeError: 'CustomSemanticSegmentationTask' object has no attribute 'loss'

It is necessary in the custom trainer to use:

    loss: Tensor = self.criterion(y_hat, y)

@mustafaemre2

| Name | Type | Params

0 | criterion | CrossEntropyLoss | 0
1 | train_metrics | MetricCollection | 0
2 | val_metrics | MetricCollection | 0
3 | test_metrics | MetricCollection | 0
4 | model | Unet | 14.3 M

14.3 M Trainable params
0 Non-trainable params
14.3 M Total params
57.351 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]

It stays stuck at this stage; what could be the problem?

@calebrob6
Author

Hey @mustafaemre2 -- are you running on the GPU?

@calebrob6
Author

calebrob6 commented Jan 11, 2024

Updated with @robmarkcole's fixes (and ensured that the notebook runs end-to-end) for torchgeo 0.5.1 (thanks Robin!)

@mustafaemre2

Yes, I used your code exactly.
My GPU is an RTX 3060 Laptop GPU @calebrob6

@mustafaemre2

You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high') which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

I get this message @calebrob6

@nadeem-git-coder

nadeem-git-coder commented Feb 27, 2024

Download the dataset and see how many images are in the train and test splits

train_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="train", download=True, checksum=True)
test_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="test", download=True, checksum=True)
len(train_dataset), len(test_dataset)

It gives this error:
RuntimeError: The MD5 checksum of the download file data/LEVIRCDPlus/LEVIR-CD+.zip does not match the one on record.Please delete the file and try again. If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues.

I get this message @calebrob6

@ProtikBose

@nadeem-git-coder Have you found any solution for that?

@nadeem-git-coder

nadeem-git-coder commented Mar 7, 2024

@ProtikBose I downloaded the dataset manually and used it.
Have you encountered the error? What is the error?
