Skip to content

Instantly share code, notes, and snippets.

@TheDevPanda
Created April 16, 2022 10:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TheDevPanda/14d9e1ad53ec54490a22bb397dd08206 to your computer and use it in GitHub Desktop.
Save TheDevPanda/14d9e1ad53ec54490a22bb397dd08206 to your computer and use it in GitHub Desktop.
Convolutional neural network for D1C3 recognition
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "2bc9456b-8eef-4dfc-a838-7129f5d992a0",
"metadata": {},
"source": [
"Links/Inspiration:\n",
"- https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py\n",
"- https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html\n",
"- https://github.com/mlugs/machine-learning-workshop/blob/main/part2/blurred-voted.ipynb"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5bf5d4ae-eeab-4685-9a66-7dc8e33f74d3",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pytorch_lightning as lightning\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as T\n",
"from torchmetrics import Accuracy\n",
"\n",
"# Seed the global RNGs once, up front, so the random parts of the notebook\n",
"# (augmentation, random_split, weight initialisation) are repeatable on Restart & Run All.\n",
"SEED = 42\n",
"lightning.seed_everything(SEED)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bbc63b6-acd3-4e53-aea0-5dca6532cb67",
"metadata": {},
"outputs": [],
"source": [
"# Training pipeline: PIL -> tensor, shorter side resized to 64 and centre-cropped to 64x64,\n",
"# autocontrast, single grey channel, then random rotation / shift / slight down-scaling.\n",
"# RandomAffine's `translate` is (max horizontal fraction, max vertical fraction), not a range:\n",
"# (0.15, 0.15) allows up to 15% shift on both axes, whereas (0, x) would disable horizontal shifts.\n",
"train_transform = T.Compose([\n",
"    T.ToTensor(),\n",
"    T.Resize(64),\n",
"    T.CenterCrop(64),\n",
"    T.functional.autocontrast,\n",
"    T.Grayscale(),\n",
"    T.RandomAffine(degrees=(0, 360), translate=(0.15, 0.15), scale=(0.95, 1)),\n",
"])\n",
"train_ds = torchvision.datasets.ImageFolder(root='data/D1C3/train', transform=train_transform)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1ce4da4-4d50-4ec8-a3b5-2faf33c9f76f",
"metadata": {},
"outputs": [],
"source": [
"# Sizes for a 70/30 train/validation split of the labelled training folder.\n",
"TRAIN_FRACTION = 0.7\n",
"train_size = int(len(train_ds) * TRAIN_FRACTION)\n",
"val_size = len(train_ds) - train_size\n",
"print('{} labelled images: {} train / {} val'.format(len(train_ds), train_size, val_size))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ff811c3-ad78-455f-8893-7841f80d5763",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation pipeline: the training preprocessing without any random augmentation.\n",
"test_transform = T.Compose([\n",
"    T.ToTensor(),\n",
"    T.Resize(64),\n",
"    T.CenterCrop(64),\n",
"    T.functional.autocontrast,\n",
"    T.Grayscale(),\n",
"])\n",
"test_ds = torchvision.datasets.ImageFolder(root='data/D1C3/test', transform=test_transform)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95806693-bf24-4ece-8ee8-4aa7ef5b1e90",
"metadata": {},
"outputs": [],
"source": [
"# Size of the held-out test set.\n",
"print(f'{len(test_ds)} test images')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f1fe6ef-2d43-4517-b3a8-d02913dfc3cd",
"metadata": {},
"outputs": [],
"source": [
"# Split indices, not the dataset: the validation subset is taken from a second ImageFolder over\n",
"# the same directory that uses the deterministic test-time transform, so validation metrics are\n",
"# not computed on randomly rotated/shifted images. random_split over a range draws the same\n",
"# permutation it would draw over the dataset itself.\n",
"# If this cell is re-run, train_ds is already a Subset; split its underlying folder again.\n",
"full_train_ds = train_ds.dataset if isinstance(train_ds, torch.utils.data.Subset) else train_ds\n",
"train_split, val_split = torch.utils.data.random_split(range(len(full_train_ds)), [train_size, val_size])\n",
"eval_view_ds = torchvision.datasets.ImageFolder(root=full_train_ds.root, transform=test_ds.transform)\n",
"train_ds = torch.utils.data.Subset(full_train_ds, train_split.indices)\n",
"val_ds = torch.utils.data.Subset(eval_view_ds, val_split.indices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6f2f1f4-a754-419d-ba99-e3cd1a687d07",
"metadata": {},
"outputs": [],
"source": [
"# Class balance of the training subset. Labels are read from ImageFolder.targets through the\n",
"# Subset indices instead of iterating the dataset, which would decode and augment every image.\n",
"np.bincount([train_ds.dataset.targets[i] for i in train_ds.indices],\n",
"            minlength=len(train_ds.dataset.classes))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0133a78d-8f6e-48ab-8354-e6c85eb36724",
"metadata": {},
"outputs": [],
"source": [
"# Notebook-level loader for the batch preview and shape check; trainer.fit uses\n",
"# DiceModel.train_dataloader() instead.\n",
"train_dataloader = torch.utils.data.DataLoader(\n",
"    train_ds,\n",
"    shuffle=True,\n",
"    batch_size=32,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a795f006-c44b-46c8-8e79-edba22beccfb",
"metadata": {},
"outputs": [],
"source": [
"# Preview one augmented training batch; make_grid tiles it into a single (C, H, W) tensor,\n",
"# transposed to (H, W, C) for matplotlib.\n",
"images, labels = next(iter(train_dataloader))\n",
"grid = torchvision.utils.make_grid(images)\n",
"plt.imshow(grid.numpy().transpose((1, 2, 0)), cmap='gray')\n",
"plt.title('Augmented training batch')\n",
"plt.axis('off');"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d92132-5233-4d2f-8eb3-16ee4a14dc0f",
"metadata": {},
"outputs": [],
"source": [
"# Authenticate with Weights & Biases before the WandbLogger used for training is created.\n",
"import wandb\n",
"\n",
"wandb.login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a95042d-e8b4-4d80-9ae1-7377fe622e8a",
"metadata": {},
"outputs": [],
"source": [
"class DiceModel(lightning.LightningModule):\n",
"    '''\n",
"    Small CNN that classifies 1x64x64 die images into `num_classes` faces.\n",
"\n",
"    learning_rate: Adam learning rate.\n",
"    batch_size: batch size of the train/val/test dataloaders.\n",
"    leaky: negative slope of every LeakyReLU.\n",
"    dropout: dropout probability applied to the last feature map before the classifier head.\n",
"    num_classes: number of output logits (defaults to the six die faces).\n",
"\n",
"    The dataloader hooks read the notebook-level `train_ds`, `val_ds` and `test_ds`.\n",
"    '''\n",
"\n",
"    def __init__(self, learning_rate=1e-3, batch_size=32, leaky=0.032, dropout=0.6, num_classes=6):\n",
"        super().__init__()\n",
"        self.save_hyperparameters()\n",
"\n",
"        self.learning_rate = learning_rate\n",
"        self.batch_size = batch_size\n",
"        self.leaky = leaky\n",
"        self.dropout = dropout\n",
"\n",
"        # Spatial size for a 64x64 input: 64 -conv7-> 58 -pool-> 29 -conv3-> 27 -pool3/2-> 13\n",
"        # -conv3-> 11 -pool-> 5 -conv2-> 4 -pool2/1-> 3 -conv2-> 2 -pool2/1-> 1, so Flatten\n",
"        # yields exactly 96*4 features; other input sizes will not fit the first Linear layer.\n",
"        self.model = torch.nn.Sequential(\n",
"            torch.nn.Conv2d(in_channels=1, out_channels=24, kernel_size=7, stride=1, padding=0),\n",
"            torch.nn.LeakyReLU(self.leaky),\n",
"            torch.nn.MaxPool2d(2, 2),\n",
"            torch.nn.Conv2d(in_channels=24, out_channels=48, kernel_size=3, stride=1, padding=0),\n",
"            torch.nn.LeakyReLU(self.leaky),\n",
"            torch.nn.MaxPool2d(3, 2),\n",
"            torch.nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3, stride=1, padding=0),\n",
"            torch.nn.LeakyReLU(self.leaky),\n",
"            torch.nn.MaxPool2d(2),\n",
"            torch.nn.Conv2d(in_channels=96, out_channels=96*2, kernel_size=2, stride=1, padding=0),\n",
"            torch.nn.LeakyReLU(self.leaky),\n",
"            torch.nn.MaxPool2d(2, 1),\n",
"            torch.nn.Conv2d(in_channels=96*2, out_channels=96*4, kernel_size=2, stride=1, padding=0),\n",
"            torch.nn.LeakyReLU(self.leaky),\n",
"            torch.nn.MaxPool2d(2, 1),\n",
"            torch.nn.Dropout(self.dropout),\n",
"            torch.nn.Flatten(),\n",
"            torch.nn.Linear(96*4, 96*4),\n",
"            torch.nn.ReLU(),\n",
"            torch.nn.Linear(96*4, num_classes),\n",
"        )\n",
"\n",
"        self.loss = torch.nn.CrossEntropyLoss()\n",
"\n",
"        # One metric instance per stage so their accumulated state never mixes.\n",
"        acc = Accuracy()\n",
"        self.train_acc = acc.clone()\n",
"        self.valid_acc = acc.clone()\n",
"        self.test_acc = acc.clone()\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    def step(self, batch, batch_idx, name):\n",
"        '''Shared train/val/test logic: log `{name}/loss` and `{name}/acc`, return the loss.'''\n",
"        x, y = batch\n",
"        # One forward pass serves both the loss and the accuracy; a second pass would double the\n",
"        # cost and, with dropout active during training, score a different sample than the loss.\n",
"        logits = self(x)\n",
"        loss = self.loss(logits, y)\n",
"        self.log(f'{name}/loss', loss)\n",
"\n",
"        preds = torch.argmax(logits, dim=1)\n",
"        if name == 'train':\n",
"            metric = self.train_acc\n",
"        elif name == 'val':\n",
"            metric = self.valid_acc\n",
"        else:\n",
"            metric = self.test_acc\n",
"        metric(preds, y)\n",
"        self.log(f'{name}/acc', metric)\n",
"\n",
"        return {'loss': loss}\n",
"\n",
"    def training_step(self, batch, batch_nb):\n",
"        return self.step(batch, batch_nb, name='train')\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        return self.step(batch, batch_idx, name='val')\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        return self.step(batch, batch_idx, name='test')\n",
"\n",
"    def configure_optimizers(self):\n",
"        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(train_ds, batch_size=self.batch_size, shuffle=True)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(val_ds, batch_size=self.batch_size)\n",
"\n",
"    def test_dataloader(self):\n",
"        return torch.utils.data.DataLoader(test_ds, batch_size=self.batch_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81e96c38-6db9-4513-941e-dae059697cc0",
"metadata": {},
"outputs": [],
"source": [
"# Hyper-parameters for this run; batch_size also drives the model's own dataloaders.\n",
"hparams = dict(learning_rate=1e-3, batch_size=64, leaky=0.032, dropout=0.6)\n",
"model = DiceModel(**hparams)\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcb5a499-01a9-49e2-b8bd-c0b56d4c51b1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Sanity-check the architecture on one preview batch: expect one logit per class per image.\n",
"sample_images, _ = next(iter(train_dataloader))\n",
"with torch.no_grad():\n",
"    sample_logits = model(sample_images)\n",
"assert sample_logits.shape == (sample_images.shape[0], len(test_ds.classes)), sample_logits.shape\n",
"sample_logits.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf7215bc-fe86-42b5-a514-ce9692120650",
"metadata": {},
"outputs": [],
"source": [
"# Keep the checkpoint with the lowest validation loss. Without a monitored ModelCheckpoint,\n",
"# Lightning only keeps the latest epoch, so ckpt_path='best' in trainer.test would mean 'last'.\n",
"checkpoint_cb = lightning.callbacks.ModelCheckpoint(monitor='val/loss', mode='min')\n",
"wandb_logger = lightning.loggers.WandbLogger(project=\"workshop-p3--D1C3\")\n",
"wandb_logger.watch(model)\n",
"trainer = lightning.Trainer(gpus=1, max_epochs=100, logger=wandb_logger, log_every_n_steps=10,\n",
"                            callbacks=[checkpoint_cb])\n",
"trainer.fit(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cd44b93-f5d2-4bf4-8a30-be29424cc5e1",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate on the held-out test set with the weights the trainer's checkpoint callback tracks as best.\n",
"test_metrics = trainer.test(ckpt_path='best')\n",
"test_metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16a30cc8-55da-4840-9e23-c5dc42c2afc1",
"metadata": {},
"outputs": [],
"source": [
"# Close the W&B run opened by the logger above.\n",
"wandb.finish()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eff7de65-bdea-4a3a-8916-58397a790c08",
"metadata": {},
"outputs": [],
"source": [
"# Use the label mapping the model was trained with: ImageFolder assigns label i to the i-th\n",
"# sorted class folder, so a hard-coded list could silently drift from the actual labels.\n",
"classes = train_ds.dataset.classes\n",
"assert test_ds.classes == classes, 'train and test folders define different classes'\n",
"classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b670575-f517-4801-ab5d-f93e64237b2f",
"metadata": {},
"outputs": [],
"source": [
"def matplotlib_imshow(img, one_channel=False, unnormalize=False):\n",
"    '''\n",
"    Draw a (C, H, W) image tensor with pyplot.\n",
"\n",
"    one_channel: average over channels and draw the result as a grayscale image.\n",
"    unnormalize: undo a Normalize(mean=0.5, std=0.5) transform via x / 2 + 0.5. Off by default:\n",
"        the dataset transforms here produce [0, 1] tensors without a Normalize step, so the\n",
"        inverse would shift them into [0.5, 1].\n",
"    '''\n",
"    if one_channel:\n",
"        img = img.mean(dim=0)\n",
"    if unnormalize:\n",
"        img = img / 2 + 0.5\n",
"    npimg = img.numpy()\n",
"    if one_channel:\n",
"        # 'gray' maps 0 to black, matching the batch preview; 'Greys' would invert the dice.\n",
"        plt.imshow(npimg, cmap='gray')\n",
"    else:\n",
"        plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
"\n",
"\n",
"def images_to_probs(net, images):\n",
"    '''\n",
"    Predict a class for each image in a batch.\n",
"\n",
"    The network is run in eval mode under torch.no_grad() (its previous train/eval mode is\n",
"    restored afterwards), so dropout does not randomise the predictions.\n",
"\n",
"    Returns (preds, probs): a 1-D numpy array of predicted class indices (1-D even for a\n",
"    batch of one) and a list with the softmax probability of each predicted class.\n",
"    '''\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            output = net(images)\n",
"    finally:\n",
"        net.train(was_training)\n",
"    probs, preds_tensor = F.softmax(output, dim=1).max(dim=1)\n",
"    return preds_tensor.numpy(), probs.tolist()\n",
"\n",
"\n",
"def plot_classes_preds(net, images, labels, rows, cols):\n",
"    '''\n",
"    Plot up to rows * cols images from a batch in a rows x cols grid, each titled with the\n",
"    network's top prediction, its probability and the true label; titles are green when the\n",
"    prediction is correct and red otherwise. Class names come from the notebook-level\n",
"    `classes` list. The figure is drawn with pyplot and not returned.\n",
"    '''\n",
"    preds, probs = images_to_probs(net, images)\n",
"    fig = plt.figure(figsize=(10, 10))\n",
"    for idx in range(min(len(images), rows * cols)):\n",
"        ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])\n",
"        matplotlib_imshow(images[idx], one_channel=True)\n",
"        label = int(labels[idx])\n",
"        ax.set_title('{0}, {1:.1f}%\\n(label: {2})'.format(\n",
"                         classes[preds[idx]],\n",
"                         probs[idx] * 100.0,\n",
"                         classes[label]),\n",
"                     color='green' if preds[idx] == label else 'red')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eea98740-9ce4-49d8-be0b-587776bd212b",
"metadata": {},
"outputs": [],
"source": [
"# Predictions on one random test batch (titles: green = correct, red = wrong).\n",
"test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=True)\n",
"images, labels = next(iter(test_loader))\n",
"plot_classes_preds(model, images, labels, 5, 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b99f232-2810-486a-9de7-63da58ee50fe",
"metadata": {},
"outputs": [],
"source": [
"# Switch to inference mode and list the layers of the nn.Sequential backbone in execution\n",
"# order for the feature-map walk below. Reading model.model directly (rather than the\n",
"# children of every child of the LightningModule) keeps the loss and metric modules out.\n",
"model.eval()\n",
"layers = list(model.model.children())\n",
"layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "559b6204-364b-47a5-a452-1ef2f07bea3d",
"metadata": {},
"outputs": [],
"source": [
"# Show the first N_FILTERS activation maps of every layer before the Flatten, for one training image.\n",
"N_FILTERS = 14\n",
"\n",
"img = next(iter(model.train_dataloader()))[0][2]\n",
"plt.figure(figsize=(2, 2))\n",
"plt.imshow(img[0], cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()\n",
"plt.close()\n",
"\n",
"# Feed the image through the layers one at a time, keeping every intermediate activation.\n",
"outputs = []\n",
"activation = img.unsqueeze(0)\n",
"with torch.no_grad():\n",
"    for layer in layers:\n",
"        activation = layer(activation)\n",
"        outputs.append(activation)\n",
"\n",
"for num_layer, output in enumerate(outputs):\n",
"    # Stop at the first non-image (flattened) activation before opening a figure, so no\n",
"    # empty figure is left behind.\n",
"    if output.dim() != 4:\n",
"        break\n",
"    print('Layer', num_layer + 1, layers[num_layer])\n",
"    fig = plt.figure(figsize=(25, 20))\n",
"    for i, feature_map in enumerate(output[0, :N_FILTERS]):\n",
"        plt.subplot(1, N_FILTERS, i + 1)\n",
"        plt.imshow(feature_map)\n",
"        plt.axis('off')\n",
"    plt.show()\n",
"    plt.close(fig)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment