DataLoader and Dataset Tutorial: PyTorch and PyTorch Lightning
{
"cells": [
{
"cell_type": "markdown",
"id": "7c305431-652e-4104-86fa-9353649b03c0",
"metadata": {},
"source": [
"# Dataset and imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "79907d77-7e41-4cea-8332-5803921c441c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Temperature</th>\n",
" <th>L</th>\n",
" <th>R</th>\n",
" <th>A_M</th>\n",
" <th>Color</th>\n",
" <th>Spectral_Class</th>\n",
" <th>Type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3068</td>\n",
" <td>0.002400</td>\n",
" <td>0.1700</td>\n",
" <td>16.12</td>\n",
" <td>Red</td>\n",
" <td>M</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3042</td>\n",
" <td>0.000500</td>\n",
" <td>0.1542</td>\n",
" <td>16.60</td>\n",
" <td>Red</td>\n",
" <td>M</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2600</td>\n",
" <td>0.000300</td>\n",
" <td>0.1020</td>\n",
" <td>18.70</td>\n",
" <td>Red</td>\n",
" <td>M</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2800</td>\n",
" <td>0.000200</td>\n",
" <td>0.1600</td>\n",
" <td>16.65</td>\n",
" <td>Red</td>\n",
" <td>M</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1939</td>\n",
" <td>0.000138</td>\n",
" <td>0.1030</td>\n",
" <td>20.06</td>\n",
" <td>Red</td>\n",
" <td>M</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>235</th>\n",
" <td>38940</td>\n",
" <td>374830.000000</td>\n",
" <td>1356.0000</td>\n",
" <td>-9.93</td>\n",
" <td>Blue</td>\n",
" <td>O</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>236</th>\n",
" <td>30839</td>\n",
" <td>834042.000000</td>\n",
" <td>1194.0000</td>\n",
" <td>-10.63</td>\n",
" <td>Blue</td>\n",
" <td>O</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>237</th>\n",
" <td>8829</td>\n",
" <td>537493.000000</td>\n",
" <td>1423.0000</td>\n",
" <td>-10.73</td>\n",
" <td>White</td>\n",
" <td>A</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>238</th>\n",
" <td>9235</td>\n",
" <td>404940.000000</td>\n",
" <td>1112.0000</td>\n",
" <td>-11.23</td>\n",
" <td>White</td>\n",
" <td>A</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>239</th>\n",
" <td>37882</td>\n",
" <td>294903.000000</td>\n",
" <td>1783.0000</td>\n",
" <td>-7.80</td>\n",
" <td>Blue</td>\n",
" <td>O</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>240 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" Temperature L R A_M Color Spectral_Class Type\n",
"0 3068 0.002400 0.1700 16.12 Red M 0\n",
"1 3042 0.000500 0.1542 16.60 Red M 0\n",
"2 2600 0.000300 0.1020 18.70 Red M 0\n",
"3 2800 0.000200 0.1600 16.65 Red M 0\n",
"4 1939 0.000138 0.1030 20.06 Red M 0\n",
".. ... ... ... ... ... ... ...\n",
"235 38940 374830.000000 1356.0000 -9.93 Blue O 5\n",
"236 30839 834042.000000 1194.0000 -10.63 Blue O 5\n",
"237 8829 537493.000000 1423.0000 -10.73 White A 5\n",
"238 9235 404940.000000 1112.0000 -11.23 White A 5\n",
"239 37882 294903.000000 1783.0000 -7.80 Blue O 5\n",
"\n",
"[240 rows x 7 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import pandas as pd\n",
"import numpy as np\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torchmetrics import Accuracy\n",
"import numpy as np\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"from sklearn import preprocessing\n",
"from enum import Enum \n",
"import pandas as pd\n",
"import copy\n",
"from pytorch_lightning import LightningModule, Trainer\n",
"from torch import nn\n",
"import torch\n",
"from torch.nn import functional as F\n",
"from torchmetrics import Accuracy\n",
"\"\"\"\n",
"TARGET:\n",
"Type\n",
"\n",
"\n",
"from 0 to 5\n",
"\n",
"Red Dwarf - 0\n",
"Brown Dwarf - 1\n",
"White Dwarf - 2\n",
"Main Sequence - 3\n",
"Super Giants - 4\n",
"Hyper Giants - 5\n",
"\"\"\"\n",
"df=pd.read_csv('Stars.csv')\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "1843e6e8-a08b-4ff3-a5a7-a0ab7728ac1f",
"metadata": {},
"source": [
"# Why Batching?"
]
},
{
"cell_type": "markdown",
"id": "659b005e-ca4c-4804-a0f5-dfb02c45be17",
"metadata": {},
"source": [
"## You could load your entire dataset but usign batches is more efficient for training ML models"
]
},
{
"cell_type": "markdown",
"id": "90bfbe17-686e-40ec-87d5-bbb1da876480",
"metadata": {},
"source": [
"# What is Batching?\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "86ad4ea7-54ca-42ea-9c6e-ab519e3b8a63",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Batch size 1\n",
"7\n",
"7\n",
"7\n",
"\n",
"Batch size Max\n",
"240 7\n",
"\n",
"Proper Batching with size 8\n",
"8 7\n",
"8 7\n",
"8 7\n"
]
}
],
"source": [
"def fany_ml_logic(sample):\n",
" pass\n",
"df=pd.read_csv('Stars.csv')\n",
"\n",
"print(\"\\nBatch size 1\")\n",
"# Instead of iterating over the entire dataset\n",
"# this is baych size 1\n",
"for epoch in [1,2,3]:\n",
" for idx,row in df.iterrows():\n",
" print(row.size)\n",
" fany_ml_logic(row)\n",
" break\n",
"print(\"\\nBatch size Max\")\n",
"# This would be ideal if you have infinite memory\n",
"# -> You dont\n",
"# Load all data at once into ml model\n",
"fany_ml_logic(df)\n",
"print(len(df),df.iloc[0].size)\n",
"print(\"\\nProper Batching with size 8\")\n",
"# You iterate over the datset in chunks\n",
"# good trade off between waiting for memory loading and processing\n",
"chunk_size=8\n",
"for epoch in [1,2,3]:\n",
" for chunk_start in range(0,len(df),chunk_size):\n",
" chunk=df.iloc[chunk_start:chunk_start+chunk_size]\n",
" fany_ml_logic(chunk)\n",
" print(len(chunk),chunk.iloc[0].size)\n",
" break\n"
]
},
{
"cell_type": "markdown",
"id": "2a4fd20f-994c-4fa2-b479-785f7094f8e8",
"metadata": {},
"source": [
"# HOW, to do properly?\n",
"## -> Pytorch Dataloader and Dataset"
]
},
{
"cell_type": "markdown",
"id": "54defbb5-733a-4ef0-9cf8-878555010c22",
"metadata": {},
"source": [
"# MVP EXAMPLE"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a671615f-2d60-4619-8242-c7ceb650533b",
"metadata": {},
"outputs": [],
"source": [
"# Custom Dataset Class\n",
"## Needs ATLEAST 3 class methods\n",
"## __init__, __len__, __getitem__\n",
"\n",
"class CustomStarDataset(Dataset):\n",
" # This loads the data and converts it, make data rdy\n",
" def __init__(self):\n",
" # load data\n",
" self.df=pd.read_csv(\"Stars.csv\")\n",
" # extract labels\n",
" self.df_labels=df[['Type']]\n",
" # drop non numeric columns to make tutorial simpler, in real life do categorical encoding\n",
" self.df=df.drop(columns=['Type','Color','Spectral_Class'])\n",
" # conver to torch dtypes\n",
" self.dataset=torch.tensor(self.df.to_numpy()).float()\n",
"\n",
" self.labels=torch.tensor(self.df_labels.to_numpy().reshape(-1)).long()\n",
" \n",
" # This returns the total amount of samples in your Dataset\n",
" def __len__(self):\n",
" return len(self.dataset)\n",
" \n",
" # This returns given an index the i-th sample and label\n",
" def __getitem__(self, idx):\n",
" return self.dataset[idx],self.labels[idx]\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "b523de31-8a22-4393-b953-69528c267076",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<__main__.CustomStarDataset at 0x7f7c24db7090>"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create our dataset\n",
"ds=CustomStarDataset()\n",
"ds"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "734801d7-623e-4a41-82bb-268a44505d25",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(tensor([3.0680e+03, 2.4000e-03, 1.7000e-01, 1.6120e+01]), tensor(0))\n",
"(tensor([3.0420e+03, 5.0000e-04, 1.5420e-01, 1.6600e+01]), tensor(0))\n",
"(tensor([2.6000e+03, 3.0000e-04, 1.0200e-01, 1.8700e+01]), tensor(0))\n",
"(tensor([2.8000e+03, 2.0000e-04, 1.6000e-01, 1.6650e+01]), tensor(0))\n",
"(tensor([1.9390e+03, 1.3800e-04, 1.0300e-01, 2.0060e+01]), tensor(0))\n"
]
}
],
"source": [
"# what can it do?\n",
"# for now only return the ith sample and label\n",
"for i in range(min(5,ds.__len__())):\n",
" # access our ith sample\n",
" print(ds[i])"
]
},
{
"cell_type": "markdown",
"id": "1cf8d410-fcd3-4996-a27e-cd3651fb86eb",
"metadata": {},
"source": [
"## DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,\n",
"### batch_sampler=None, num_workers=0, collate_fn=None,\n",
"### pin_memory=False, drop_last=False, timeout=0,\n",
"### worker_init_fn=None, *, prefetch_factor=2,\n",
"### persistent_workers=False)"
]
},
{
"cell_type": "markdown",
"id": "941476f1-7cb1-47d5-9271-2d427e8467f6",
"metadata": {},
"source": [
"\n",
"# The arguments we rly care about\n",
"### dataset, -> what we created above\n",
"### batch_size=1 -> what influences the size of the batchs returned\n",
"### shuffle=False -> default is false!! Controls if we shuffle after each epoch\n",
"### drop_last=False -> deafault is false, comoon reason for errors!! Simply removes the last uneven batch\n",
"### num_workers -> controls parallelism, but uses more memory. Increase this if you see dataloading taking too much time in model\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "395b3fef-a203-4ea0-986c-3bd2e4ac0238",
"metadata": {},
"outputs": [],
"source": [
"# create a dataloader\n",
"dl=DataLoader(ds,batch_size=32, num_workers=2, shuffle=True,drop_last=True)"
]
},
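{
"cell_type": "markdown",
"id": "2f6a1c10-4d2e-4b5a-9a01-aaaaaaaa0001",
"metadata": {},
"source": [
"With 240 samples, `batch_size=32` and `drop_last=True`, the loader above yields `240 // 32 = 7` full batches per epoch and silently drops the remaining 16 samples. The small check below is an illustrative addition (it assumes the `ds` and `dl` just created), not part of the original tutorial."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f6a1c10-4d2e-4b5a-9a01-aaaaaaaa0002",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (assumes ds and dl from the cells above)\n",
"# drop_last=True keeps only full batches, drop_last=False adds a final partial batch\n",
"import math\n",
"n_samples, batch_size = len(ds), 32\n",
"print('full batches per epoch (drop_last=True):', n_samples // batch_size)\n",
"print('batches per epoch (drop_last=False):', math.ceil(n_samples / batch_size))\n",
"print('len(dl) matches the drop_last=True count:', len(dl) == n_samples // batch_size)"
]
},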
{
"cell_type": "code",
"execution_count": 30,
"id": "913e572f-937b-456c-b80e-2bbea9bf1fb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 3.7800e+03, 2.0000e+05, 1.3240e+03, -1.0700e+01],\n",
" [ 3.4190e+04, 1.9820e+05, 6.3900e+00, -4.5700e+00],\n",
" [ 1.2984e+04, 8.8000e-04, 9.9600e-03, 1.1230e+01],\n",
" [ 3.6000e+03, 2.4000e+05, 1.1900e+03, -7.8900e+00],\n",
" [ 3.1420e+03, 1.3200e-03, 2.5800e-01, 1.4120e+01],\n",
" [ 7.7200e+03, 7.9200e+00, 1.3400e+00, 2.4400e+00],\n",
" [ 5.7520e+03, 2.4500e+05, 9.7000e+01, -6.6300e+00],\n",
" [ 3.1920e+03, 3.6200e-03, 1.9670e-01, 1.3530e+01],\n",
" [ 5.8000e+03, 8.1000e-01, 9.0000e-01, 5.0500e+00],\n",
" [ 1.5680e+04, 1.2200e-03, 1.1400e-02, 1.1920e+01],\n",
" [ 1.9920e+04, 1.5600e-03, 1.4200e-02, 1.1340e+01],\n",
" [ 7.7230e+03, 1.4000e-04, 8.7800e-03, 1.4810e+01],\n",
" [ 3.3420e+03, 1.5000e-03, 3.0700e-01, 1.1870e+01],\n",
" [ 3.3990e+03, 1.1700e+05, 1.4860e+03, -1.0920e+01],\n",
" [ 1.4520e+04, 8.2000e-04, 9.7200e-03, 1.1920e+01],\n",
" [ 2.2012e+04, 6.7480e+03, 6.6400e+00, -2.5500e+00],\n",
" [ 1.7383e+04, 3.4290e+05, 3.0000e+01, -6.0900e+00],\n",
" [ 2.9140e+03, 6.3100e-04, 1.1600e-01, 1.8390e+01],\n",
" [ 9.3730e+03, 4.2452e+05, 2.4000e+01, -5.9900e+00],\n",
" [ 1.9360e+04, 1.2500e-03, 9.9800e-03, 1.1620e+01],\n",
" [ 2.9940e+03, 7.2000e-03, 2.8000e-01, 1.3450e+01],\n",
" [ 3.3450e+03, 2.1000e-02, 2.7300e-01, 1.2300e+01],\n",
" [ 2.1904e+04, 7.4849e+05, 1.1300e+03, -7.6700e+00],\n",
" [ 2.8310e+03, 2.3100e-04, 9.1500e-02, 1.6210e+01],\n",
" [ 3.5310e+03, 9.3000e-04, 9.7600e-02, 1.9940e+01],\n",
" [ 5.1120e+03, 6.3000e-01, 8.7600e-01, 4.6800e+00],\n",
" [ 9.3830e+03, 3.4294e+05, 9.8000e+01, -6.9800e+00],\n",
" [ 3.6500e+03, 3.1000e+05, 1.3240e+03, -7.7900e+00],\n",
" [ 2.8170e+03, 9.8000e-04, 9.1100e-02, 1.6450e+01],\n",
" [ 3.0910e+03, 8.1000e-03, 2.4000e-01, 1.1430e+01],\n",
" [ 3.1340e+03, 4.0000e-04, 1.9600e-01, 1.3210e+01],\n",
" [ 1.9860e+04, 1.1000e-03, 1.3100e-02, 1.1340e+01]]) tensor([5, 3, 2, 5, 1, 3, 4, 1, 3, 2, 2, 2, 1, 5, 2, 3, 4, 0, 4, 2, 1, 1, 5, 0,\n",
" 0, 3, 4, 5, 0, 1, 1, 2])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n",
"torch.Size([32, 4]) torch.Size([32])\n"
]
}
],
"source": [
"# use dataloader\n",
"counter=0\n",
"for epoch in [1,2,3]:\n",
" for batch,label in dl:\n",
" if counter==0:\n",
" print(batch,label)\n",
" counter+=1\n",
" print(batch.size(),label.size())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2e2f561e-1d01-46db-8ce9-2e7f9eb3364a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result is bad, we will see why in a minute!\n",
"tensor(0.4042)\n"
]
}
],
"source": [
"# Minmal Example using Pytorch\n",
"# Define some simple model to get started\n",
"class SimpleModel(torch.nn.Module):\n",
" def __init__(self):\n",
" super(SimpleModel, self).__init__()\n",
" self.model=nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(4, 128),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(128, 32),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(32, 6),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" z = self.model(x)\n",
" return F.log_softmax(z, dim=1)\n",
"\n",
"# Create model\n",
"model=SimpleModel().to('cpu')\n",
"# Create dataloader with sensible parameters\n",
"dl=DataLoader(ds,batch_size=4, num_workers=1, shuffle=True,drop_last=True)\n",
"# optimizer creation for training\n",
"optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)\n",
"model.train()\n",
"# training loop\n",
"for epoch in range(20):\n",
" for batch, label in dl:\n",
" # for each batch do a forward pass\n",
" optimizer.zero_grad()\n",
" oupt = model(batch)\n",
" # calculate the loss\n",
" loss_obj = F.nll_loss(oupt, label)\n",
" # updates\n",
" loss_obj.backward()\n",
" optimizer.step()\n",
"#Do evaluation and the rest\n",
"model.eval()\n",
"# We simply evaluate on the training set, this is bad\n",
"# But also the results will be bad, we will make it better in the Full Example version\n",
"acc=Accuracy()\n",
"# evaluation simple, quick to make tutorial fast\n",
"dl=DataLoader(ds,batch_size=240, num_workers=2, shuffle=False,drop_last=True)\n",
"for batch, label in dl:\n",
" print(\"Result is bad, we will see why in a minute!\")\n",
" print(acc(torch.argmax(model(batch),dim=1),label))\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "9effdf36-dcb3-4d82-96e3-d4ae354951a3",
"metadata": {},
"source": [
"# Full Example"
]
},
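{
"cell_type": "markdown",
"id": "3a7b2d20-5e3f-4c6b-8b02-aaaaaaaa0003",
"metadata": {},
"source": [
"The Full Example adds three things on top of the MVP: a shuffled train/val/test split, feature scaling with a `StandardScaler` fitted on the train split only, and a `LightningModule` that owns its dataloaders. The weak accuracy above is most likely caused by the unscaled features (temperatures in the thousands next to luminosities near zero). The next cell is an illustrative sketch of the fit-on-train-only scaling pattern; the row counts and variable names are assumptions for the demo, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a7b2d20-5e3f-4c6b-8b02-aaaaaaaa0004",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of fit-on-train-only scaling (names and row counts are hypothetical):\n",
"# the scaler statistics (mean/std) come from the train rows only,\n",
"# then the same transform is applied to every split to avoid data leakage\n",
"import pandas as pd\n",
"from sklearn import preprocessing\n",
"\n",
"raw = pd.read_csv('Stars.csv').drop(columns=['Type', 'Color', 'Spectral_Class'])\n",
"train_rows, other_rows = raw.iloc[:144], raw.iloc[144:]\n",
"scaler = preprocessing.StandardScaler().fit(train_rows)   # statistics from train only\n",
"train_scaled = scaler.transform(train_rows)\n",
"other_scaled = scaler.transform(other_rows)               # same mean/std reused\n",
"print(train_scaled.mean(axis=0).round(3), train_scaled.std(axis=0).round(3))"
]
},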
{
"cell_type": "code",
"execution_count": 31,
"id": "bc409d73-574d-474a-91bb-e68efc9311f7",
"metadata": {},
"outputs": [],
"source": [
"# Make simple Enum for code clarity\n",
"class DatasetType(Enum):\n",
" TRAIN = 1\n",
" TEST = 2\n",
" VAL = 3\n",
"\n",
"# Again create a Dataset but this time, do the split in train test val\n",
"class CustomStarDataset(Dataset):\n",
" def __init__(self):\n",
" # load data and shuffle, befor splitting\n",
" self.df=pd.read_csv(\"Stars.csv\").sample(frac=1, random_state=27)\n",
" train_split=0.6\n",
" val_split=0.8\n",
" self.df_labels=df[['Type']]\n",
" # drop non numeric columns, in real life do categorical encoding\n",
" self.df=df.drop(columns=['Type','Color','Spectral_Class'])\n",
" # split pointf.df\n",
" self.train, self.val, self.test = np.split(self.df, [int(train_split*len(self.df)), int(val_split*len(self.df))])\n",
" self.train_labels, self.val_labels, self.test_labels = np.split(self.df_labels, [int(train_split*len(self.df)), int(val_split*len(self.df))])\n",
" # do the feature scaling only on the train set!\n",
" self.scaler=preprocessing.StandardScaler().fit(self.train)\n",
" for data_split in [ self.train, self.val, self.test]:\n",
" data_split=self.scaler.transform(data_split)\n",
" # convet labels to 1 hot\n",
"\n",
" \n",
" def __len__(self):\n",
" return len(self.dataset)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.dataset[idx],self.labels[idx]\n",
" \n",
" def set_fold(self,set_type):\n",
" # Make sure to call this befor using the dataset\n",
" if set_type==DatasetType.TRAIN:\n",
" self.dataset,self.labels=self.train,self.train_labels\n",
" if set_type==DatasetType.TEST:\n",
" self.dataset,self.labels=self.test,self.test_labels\n",
" if set_type==DatasetType.VAL:\n",
" self.dataset,self.labels=self.val,self.val_labels\n",
" # Convert the datasets and the labels to pytorch format\n",
" # Also use the StdScaler on the training set\n",
" self.dataset=torch.tensor(self.scaler.transform(self.dataset)).float()\n",
" self.labels=torch.tensor(self.labels.to_numpy().reshape(-1)).long()\n",
"\n",
" return self\n",
" \n",
"dataset=CustomStarDataset()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "817a6ce1-cd0c-457a-adcd-f7922fb2a720",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0)\n",
"tensor(1)\n",
"tensor(2)\n",
"CPU times: user 8.59 ms, sys: 4.22 ms, total: 12.8 ms\n",
"Wall time: 10.1 ms\n"
]
}
],
"source": [
"%%time\n",
"# This is faster then recreating the dataset\n",
"# but you could also simply create it 3 times\n",
"# Or use PyTorchs build in train,test,val split\n",
"train=copy.deepcopy(dataset).set_fold(DatasetType.TRAIN)\n",
"test=copy.deepcopy(dataset).set_fold(DatasetType.TEST)\n",
"val=copy.deepcopy(dataset).set_fold(DatasetType.VAL)\n",
"for i in [train,test,val]:\n",
" print(i.__getitem__(0)[1])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "697e8ab5-f965-4556-b328-cfe0a31b6a91",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"144"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train)"
]
},
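{
"cell_type": "markdown",
"id": "4c8d3e30-6f40-4d7c-9c03-aaaaaaaa0005",
"metadata": {},
"source": [
"The 144 above follows from the split points: `int(0.6 * 240) = 144` rows for train and `int(0.8 * 240) = 192` as the second cut, which leaves 48 rows each for val and test. The check below is an illustrative addition, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c8d3e30-6f40-4d7c-9c03-aaaaaaaa0006",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the split sizes implied by train_split=0.6 and val_split=0.8\n",
"n = 240\n",
"first_cut, second_cut = int(0.6 * n), int(0.8 * n)\n",
"print('train:', first_cut)               # 144\n",
"print('val  :', second_cut - first_cut)  # 48\n",
"print('test :', n - second_cut)          # 48\n",
"print(len(train), len(val), len(test))   # should agree with the numbers above"
]
},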
{
"cell_type": "code",
"execution_count": 18,
"id": "cca5f337-aef9-48a4-915f-b6698ed70ffc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"124\n",
"10\n",
"10\n"
]
}
],
"source": [
"# Easy alternative, using Pytorchs build in method\n",
"train_example, val_example, test_example = torch.utils.data.random_split(train, [124 ,10, 10])\n",
"for i in [train_example, val_example, test_example]:\n",
" print(len(i))"
]
},
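{
"cell_type": "markdown",
"id": "5d9e4f40-7051-4e8d-8d04-aaaaaaaa0007",
"metadata": {},
"source": [
"One caveat: `random_split` draws a new random split on every run unless you pass a seeded generator. A reproducible variant could look like the sketch below (the seed value is arbitrary and this cell is an addition, not part of the original notebook)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d9e4f40-7051-4e8d-8d04-aaaaaaaa0008",
"metadata": {},
"outputs": [],
"source": [
"# Reproducible alternative: pass a seeded torch.Generator to random_split\n",
"# (the seed value 42 here is arbitrary)\n",
"gen = torch.Generator().manual_seed(42)\n",
"train_example, val_example, test_example = torch.utils.data.random_split(train, [124, 10, 10], generator=gen)\n",
"print([len(s) for s in (train_example, val_example, test_example)])"
]
},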
{
"cell_type": "code",
"execution_count": 34,
"id": "7581ed4d-6f02-4ae0-8916-084ee14b8d48",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:97: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.\n",
" f\"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and\"\n",
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------\n",
"0 | model | Sequential | 5.0 K \n",
"1 | accuracy | Accuracy | 0 \n",
"----------------------------------------\n",
"5.0 K Trainable params\n",
"0 Non-trainable params\n",
"5.0 K Total params\n",
"0.020 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" category=PossibleUserWarning,\n",
"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" category=PossibleUserWarning,\n",
"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py:1931: PossibleUserWarning: The number of training batches (36) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
" category=PossibleUserWarning,\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36968e15e3d441ceb096297281c294b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py:1445: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `test(ckpt_path='best')` to use and best model checkpoint and avoid this warning or `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model.\n",
" f\"`.{fn}(ckpt_path=None)` was called without a model.\"\n",
"Restoring states from the checkpoint path at /home/jupyter/youtubeDataset/lightning_logs/version_25/checkpoints/epoch=9-step=360.ckpt\n",
"Loaded model weights from checkpoint at /home/jupyter/youtubeDataset/lightning_logs/version_25/checkpoints/epoch=9-step=360.ckpt\n",
"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" category=PossibleUserWarning,\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c20ed5fffe9a4a6a948c1f08fd9d4f57",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" test_acc 0.8541666865348816\n",
" test_loss 0.44730350375175476\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
"[{'test_loss': 0.44730350375175476, 'test_acc': 0.8541666865348816}]"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# Define Batch Size\n",
"BATCH_SIZE=4\n",
"\n",
"# Defin a SimpleLightning Model\n",
"class SimpleModel(LightningModule):\n",
" def __init__(self,train,test,val):\n",
" super().__init__()\n",
" self.train_ds=train\n",
" self.val_ds=val\n",
" self.test_ds=test\n",
" # Define PyTorch model\n",
" classes=6\n",
" features=4\n",
" self.model = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(features, 128),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(128, 32),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(32, classes),\n",
" )\n",
" self.accuracy = Accuracy()\n",
" # Same as above\n",
" def forward(self, x):\n",
" x = self.model(x)\n",
" return F.log_softmax(x, dim=1)\n",
" \n",
" # Same as above\n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" \n",
" return loss\n",
" \n",
" # Make use of the validation set\n",
" def validation_step(self, batch, batch_idx, print_str=\"val\"):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" self.accuracy(preds, y)\n",
"\n",
" # Calling self.log will surface up scalars for you in TensorBoard\n",
" self.log(f\"{print_str}_loss\", loss, prog_bar=True)\n",
" self.log(f\"{print_str}_acc\", self.accuracy, prog_bar=True)\n",
" return loss\n",
" \n",
" def test_step(self, batch, batch_idx):\n",
" # Here we just reuse the validation_step for testing\n",
" return self.validation_step(batch, batch_idx,print_str='test')\n",
" \n",
" def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=0.001)\n",
" #\n",
" # HERE: We define the 3 Dataloaders, only train needs to be shuffled\n",
" # This will then directly be usable with Pytorch Lightning to make a super quick model\n",
" def train_dataloader(self):\n",
" return DataLoader(self.train_ds, batch_size=BATCH_SIZE,shuffle=True)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.val_ds, batch_size=BATCH_SIZE,shuffle=False)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.test_ds, batch_size=BATCH_SIZE,shuffle=False)\n",
"\n",
"# Start the Trainer\n",
"trainer = Trainer(\n",
" max_epochs=10,\n",
" progress_bar_refresh_rate=1,\n",
")\n",
"# Define the Model\n",
"model=SimpleModel(train,test,val)\n",
"# Train the Model\n",
"trainer.fit(model)\n",
"# Test on the Test SET, it will print validation\n",
"trainer.test()"
]
},
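{
"cell_type": "markdown",
"id": "6eaf5a50-8162-4f9e-9e05-aaaaaaaa0009",
"metadata": {},
"source": [
"After `trainer.fit`, the `LightningModule` is still an ordinary `nn.Module`, so single-sample inference works the usual PyTorch way. A minimal illustrative sketch (an addition to the original notebook), assuming the `test` fold from above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6eaf5a50-8162-4f9e-9e05-aaaaaaaa0010",
"metadata": {},
"outputs": [],
"source": [
"# Minimal inference sketch: predict the class of one sample from the test fold\n",
"model.eval()\n",
"x, y = test[0]\n",
"with torch.no_grad():\n",
"    pred = torch.argmax(model(x.unsqueeze(0)), dim=1)\n",
"print('predicted:', pred.item(), 'true:', y.item())"
]
},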
{
"cell_type": "code",
"execution_count": null,
"id": "6c97f36b-4e9a-4e74-abe2-813d99d59342",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "conda-root-py",
"name": "pytorch-gpu.1-9.m82",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-9:m82"
},
"kernelspec": {
"display_name": "Python [conda env:root] *",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}