Created
March 25, 2020 13:00
-
-
Save 3sdd/5e33cdcd9dfbb7529332968f83de7697 to your computer and use it in GitHub Desktop.
training_tqdm_pytorch.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "training_tqdm_pytorch.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOSBXuXW8L/V9jXBmEuIzlR", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/3sdd/5e33cdcd9dfbb7529332968f83de7697/training_tqdm_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ss2BD1_8r5mR", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "f6485cad-87df-4b0e-c715-b23bdced0310" | |
}, | |
"source": [ | |
"!pip install tqdm==4.43.0" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: tqdm==4.43.0 in /usr/local/lib/python3.6/dist-packages (4.43.0)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0IS4hNW20Rtb", | |
"colab_type": "code", | |
"outputId": "3cb39b38-28ee-4d41-c1bc-49fb74d4a077", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 312 | |
} | |
}, | |
"source": [ | |
"!nvidia-smi" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Wed Mar 25 10:28:50 2020 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 440.64.00 Driver Version: 418.67 CUDA Version: 10.1 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla P4 Off | 00000000:00:04.0 Off | 0 |\n", | |
"| N/A 80C P0 30W / 75W | 0MiB / 7611MiB | 0% Default |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: GPU Memory |\n", | |
"| GPU PID Type Process name Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dAfu2uKTsHst", | |
"colab_type": "code", | |
"outputId": "bb187f7c-1ac7-4590-ebf9-6e7debbd9ea9", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 69 | |
} | |
}, | |
"source": [ | |
"import tqdm\n", | |
"import torch\n", | |
"import torchvision\n", | |
"\n", | |
"print(tqdm.__version__)\n", | |
"print(torch.__version__)\n", | |
"print(torchvision.__version__)" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"4.43.0\n", | |
"1.4.0\n", | |
"0.5.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "T_CnDfrbsNPb", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"import torchvision\n", | |
"import torchvision.transforms as transforms\n", | |
"\n", | |
"from tqdm import tqdm\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aLW2c1fDsRNW", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
"batch_size=128\n", | |
"num_classes=10\n", | |
"num_epochs=20\n", | |
"lr=0.1" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "UAPCJxwlsZLR", | |
"colab_type": "code", | |
"outputId": "be45355c-c569-41f3-9a35-482f2aaccadd", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 52 | |
} | |
}, | |
"source": [ | |
"transform={\n", | |
" \"train\":transforms.Compose([\n", | |
" transforms.ToTensor(),\n", | |
" ]),\n", | |
" \"test\":transforms.Compose([\n", | |
" transforms.ToTensor(),\n", | |
" ])\n", | |
"}\n", | |
"\n", | |
"dataset={\n", | |
" \"train\":torchvision.datasets.CIFAR10(root=\"CIFAR10\",train=True,transform=transform[\"train\"],download=True),\n", | |
" \"test\":torchvision.datasets.CIFAR10(root=\"CIFAR10\",train=False,transform=transform[\"test\"],download=True)\n", | |
"}\n", | |
"\n", | |
"dataloader={\n", | |
" \"train\":torch.utils.data.DataLoader(dataset[\"train\"],batch_size=batch_size,shuffle=True,num_workers=2),\n", | |
" \"test\":torch.utils.data.DataLoader(dataset[\"test\"],batch_size=batch_size,shuffle=False,num_workers=2)\n", | |
"}" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Files already downloaded and verified\n", | |
"Files already downloaded and verified\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wPcdOAPBtcZT", | |
"colab_type": "code", | |
"outputId": "e34879b7-6e6e-4404-944e-b489a7691707", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"model=torchvision.models.resnet18(pretrained=False)\n", | |
"\n", | |
"model.fc=nn.Linear(model.fc.in_features,num_classes)\n", | |
"model=model.to(device)\n", | |
"print(model)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"ResNet(\n", | |
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", | |
" (layer1): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer2): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer3): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (layer4): Sequential(\n", | |
" (0): BasicBlock(\n", | |
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): BasicBlock(\n", | |
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (relu): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", | |
" (fc): Linear(in_features=512, out_features=10, bias=True)\n", | |
")\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Iu9_D9Rft5Ry", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"optimizer=optim.Adam(model.parameters(),lr=lr)\n", | |
"criterion=nn.CrossEntropyLoss()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AXT9ulZd3RXM", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def train(model,dataloader,otpimizer,criterion,num_epochs,device):\n", | |
"\n", | |
" for epoch in range(1,num_epochs+1):\n", | |
"\n", | |
" for phase in [\"train\",\"test\"]:\n", | |
"\n", | |
" if phase==\"train\":\n", | |
" model.train()\n", | |
" elif phase==\"test\":\n", | |
" model.eval()\n", | |
"\n", | |
" with torch.set_grad_enabled(phase==\"train\"):\n", | |
" loss_sum=0\n", | |
" corrects=0\n", | |
" total=0\n", | |
"\n", | |
" with tqdm(total=len(dataloader[phase]),unit=\"batch\") as pbar:\n", | |
" pbar.set_description(f\"Epoch[{epoch}/{num_epochs}]({phase})\")\n", | |
" for imgs,labels in dataloader[phase]: \n", | |
" imgs,labels=imgs.to(device),labels.to(device)\n", | |
" output=model(imgs)\n", | |
" loss=criterion(output,labels)\n", | |
"\n", | |
" if phase==\"train\":\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" predicted=torch.argmax(output,dim=1) ## dimあってる?\n", | |
" corrects+=(predicted==labels).sum()\n", | |
" total+=imgs.size(0)\n", | |
" \n", | |
" #loss関数で通してでてきたlossは平均化されているCrossEntropyLossのreduction=\"mean\"なので\n", | |
" #batch sizeをかけることで、batch全体での合計を今までのloss_sumに足し合わせる\n", | |
" loss_sum+=loss*imgs.size(0) #imgs.size(0)batch sizeの\n", | |
"\n", | |
" accuracy=corrects.item()/total\n", | |
" running_loss=loss_sum/total\n", | |
" pbar.set_postfix({\"loss\":running_loss.item(),\"accuracy\":accuracy })\n", | |
" pbar.update(1)\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Xi-ncaX476P7", | |
"colab_type": "code", | |
"outputId": "7ca4f345-b328-4025-c87f-1d43c227a4d0", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 712 | |
} | |
}, | |
"source": [ | |
"train(model,dataloader,optimizer,criterion,num_epochs,device)" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch[1/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.29batch/s, loss=2.42, accuracy=0.141]\n", | |
"Epoch[1/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.38batch/s, loss=1.98, accuracy=0.188]\n", | |
"Epoch[2/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.14batch/s, loss=1.92, accuracy=0.202]\n", | |
"Epoch[2/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.69batch/s, loss=2.16, accuracy=0.197]\n", | |
"Epoch[3/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.23batch/s, loss=1.85, accuracy=0.239]\n", | |
"Epoch[3/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.42batch/s, loss=1.84, accuracy=0.246]\n", | |
"Epoch[4/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.14batch/s, loss=1.8, accuracy=0.268]\n", | |
"Epoch[4/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.47batch/s, loss=1.96, accuracy=0.272]\n", | |
"Epoch[5/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.02batch/s, loss=1.77, accuracy=0.288]\n", | |
"Epoch[5/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.67batch/s, loss=1.99, accuracy=0.273]\n", | |
"Epoch[6/20](train): 100%|██████████| 391/391 [00:23<00:00, 16.81batch/s, loss=1.68, accuracy=0.349]\n", | |
"Epoch[6/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.79batch/s, loss=1.59, accuracy=0.42]\n", | |
"Epoch[7/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.51batch/s, loss=1.46, accuracy=0.451]\n", | |
"Epoch[7/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.48batch/s, loss=1.48, accuracy=0.46]\n", | |
"Epoch[8/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.17batch/s, loss=1.38, accuracy=0.49]\n", | |
"Epoch[8/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.09batch/s, loss=1.46, accuracy=0.481]\n", | |
"Epoch[9/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.08batch/s, loss=1.29, accuracy=0.532]\n", | |
"Epoch[9/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.89batch/s, loss=1.38, accuracy=0.495]\n", | |
"Epoch[10/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.09batch/s, loss=1.22, accuracy=0.558]\n", | |
"Epoch[10/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.81batch/s, loss=1.38, accuracy=0.529]\n", | |
"Epoch[11/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.39batch/s, loss=1.15, accuracy=0.59]\n", | |
"Epoch[11/20](test): 100%|██████████| 79/79 [00:01<00:00, 40.72batch/s, loss=1.29, accuracy=0.557]\n", | |
"Epoch[12/20](train): 100%|██████████| 391/391 [00:23<00:00, 16.95batch/s, loss=1.09, accuracy=0.616]\n", | |
"Epoch[12/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.31batch/s, loss=1.37, accuracy=0.566]\n", | |
"Epoch[13/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.27batch/s, loss=1.04, accuracy=0.633]\n", | |
"Epoch[13/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.33batch/s, loss=2.13, accuracy=0.415]\n", | |
"Epoch[14/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.26batch/s, loss=0.995, accuracy=0.651]\n", | |
"Epoch[14/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.26batch/s, loss=1.36, accuracy=0.583]\n", | |
"Epoch[15/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.16batch/s, loss=0.949, accuracy=0.669]\n", | |
"Epoch[15/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.80batch/s, loss=1.58, accuracy=0.584]\n", | |
"Epoch[16/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.37batch/s, loss=1.05, accuracy=0.64]\n", | |
"Epoch[16/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.20batch/s, loss=1.24, accuracy=0.605]\n", | |
"Epoch[17/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.13batch/s, loss=0.895, accuracy=0.693]\n", | |
"Epoch[17/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.94batch/s, loss=1.27, accuracy=0.612]\n", | |
"Epoch[18/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.21batch/s, loss=0.82, accuracy=0.72]\n", | |
"Epoch[18/20](test): 100%|██████████| 79/79 [00:01<00:00, 40.86batch/s, loss=1.5, accuracy=0.594]\n", | |
"Epoch[19/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.58batch/s, loss=0.79, accuracy=0.734]\n", | |
"Epoch[19/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.85batch/s, loss=1.31, accuracy=0.64]\n", | |
"Epoch[20/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.29batch/s, loss=0.745, accuracy=0.75]\n", | |
"Epoch[20/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.48batch/s, loss=1.43, accuracy=0.64]\n" | |
], | |
"name": "stderr" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment