Skip to content

Instantly share code, notes, and snippets.

@3sdd
Created March 25, 2020 13:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 3sdd/5e33cdcd9dfbb7529332968f83de7697 to your computer and use it in GitHub Desktop.
Save 3sdd/5e33cdcd9dfbb7529332968f83de7697 to your computer and use it in GitHub Desktop.
training_tqdm_pytorch.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "training_tqdm_pytorch.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOSBXuXW8L/V9jXBmEuIzlR",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/3sdd/5e33cdcd9dfbb7529332968f83de7697/training_tqdm_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ss2BD1_8r5mR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "f6485cad-87df-4b0e-c715-b23bdced0310"
},
"source": [
"!pip install tqdm==4.43.0"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: tqdm==4.43.0 in /usr/local/lib/python3.6/dist-packages (4.43.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0IS4hNW20Rtb",
"colab_type": "code",
"outputId": "3cb39b38-28ee-4d41-c1bc-49fb74d4a077",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 312
}
},
"source": [
"!nvidia-smi"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Wed Mar 25 10:28:50 2020 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 440.64.00 Driver Version: 418.67 CUDA Version: 10.1 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 80C P0 30W / 75W | 0MiB / 7611MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dAfu2uKTsHst",
"colab_type": "code",
"outputId": "bb187f7c-1ac7-4590-ebf9-6e7debbd9ea9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 69
}
},
"source": [
"# Report the library versions this notebook was run with.\n",
"# The tqdm package is imported under an alias so this cell never rebinds the\n",
"# global name `tqdm`, which the imports cell binds to the progress-bar class\n",
"# (`from tqdm import tqdm`). Re-running this cell therefore cannot break the\n",
"# training loop.\n",
"import tqdm as tqdm_pkg\n",
"import torch\n",
"import torchvision\n",
"\n",
"print(tqdm_pkg.__version__)\n",
"print(torch.__version__)\n",
"print(torchvision.__version__)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"4.43.0\n",
"1.4.0\n",
"0.5.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T_CnDfrbsNPb",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"from tqdm import tqdm\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aLW2c1fDsRNW",
"colab_type": "code",
"colab": {}
},
"source": [
"# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Training hyperparameters shared by the cells below.\n",
"batch_size = 128  # mini-batch size used by both the train and test loaders\n",
"num_classes = 10  # output width of the replaced final layer\n",
"num_epochs = 20  # number of train+test passes run by train()\n",
"# NOTE(review): this value is passed to Adam, whose documented default is 1e-3;\n",
"# 0.1 is 100x larger - confirm it is intentional.\n",
"lr = 0.1"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "UAPCJxwlsZLR",
"colab_type": "code",
"outputId": "be45355c-c569-41f3-9a35-482f2aaccadd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"# Both splits get the same preprocessing: conversion to a tensor only.\n",
"transform = {\n",
"    split: transforms.Compose([transforms.ToTensor()])\n",
"    for split in (\"train\", \"test\")\n",
"}\n",
"\n",
"# CIFAR10 is downloaded into ./CIFAR10 if it is not already there; the train\n",
"# split is built first, then the test split.\n",
"dataset = {\n",
"    split: torchvision.datasets.CIFAR10(\n",
"        root=\"CIFAR10\",\n",
"        train=(split == \"train\"),\n",
"        transform=transform[split],\n",
"        download=True,\n",
"    )\n",
"    for split in (\"train\", \"test\")\n",
"}\n",
"\n",
"# Only the training loader shuffles; both use two worker processes.\n",
"dataloader = {\n",
"    split: torch.utils.data.DataLoader(\n",
"        dataset[split],\n",
"        batch_size=batch_size,\n",
"        shuffle=(split == \"train\"),\n",
"        num_workers=2,\n",
"    )\n",
"    for split in (\"train\", \"test\")\n",
"}"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wPcdOAPBtcZT",
"colab_type": "code",
"outputId": "e34879b7-6e6e-4404-944e-b489a7691707",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# Randomly initialised ResNet-18 (no pretrained weights).\n",
"model = torchvision.models.resnet18(pretrained=False)\n",
"\n",
"# Swap the final fully connected layer for one with num_classes outputs,\n",
"# keeping its input width, then move the whole network to the chosen device.\n",
"fc_in_features = model.fc.in_features\n",
"model.fc = nn.Linear(fc_in_features, num_classes)\n",
"model = model.to(device)\n",
"print(model)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): BasicBlock(\n",
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): BasicBlock(\n",
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Linear(in_features=512, out_features=10, bias=True)\n",
")\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Iu9_D9Rft5Ry",
"colab_type": "code",
"colab": {}
},
"source": [
"# Adam over every model parameter, using the learning rate from the config cell.\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"# Cross-entropy loss with its default settings.\n",
"criterion = nn.CrossEntropyLoss()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AXT9ulZd3RXM",
"colab_type": "code",
"colab": {}
},
"source": [
"def train(model, dataloader, optimizer, criterion, num_epochs, device):\n",
"    \"\"\"Train and evaluate a model, showing one tqdm progress bar per phase.\n",
"\n",
"    Every epoch runs a train phase followed by a test phase. Each bar's\n",
"    postfix shows the running mean loss and running accuracy of that phase.\n",
"\n",
"    Args:\n",
"        model: module called as model(imgs); switched between train() and\n",
"            eval() mode per phase.\n",
"        dataloader: mapping with 'train' and 'test' entries, each iterable\n",
"            over (imgs, labels) batches and supporting len().\n",
"        optimizer: optimizer whose zero_grad() and step() are called once per\n",
"            training batch.\n",
"        criterion: loss called as criterion(output, labels). The running\n",
"            loss assumes it returns the batch mean (the default\n",
"            reduction='mean' of nn.CrossEntropyLoss).\n",
"        num_epochs: number of epochs to run; epochs are numbered from 1.\n",
"        device: device every batch is moved to before the forward pass.\n",
"\n",
"    Returns:\n",
"        None. Metrics are reported only through the progress bars.\n",
"    \"\"\"\n",
"    for epoch in range(1, num_epochs + 1):\n",
"\n",
"        for phase in [\"train\", \"test\"]:\n",
"\n",
"            if phase == \"train\":\n",
"                model.train()\n",
"            elif phase == \"test\":\n",
"                model.eval()\n",
"\n",
"            # Per-phase accumulators, kept as plain Python numbers.\n",
"            loss_sum = 0.0\n",
"            corrects = 0\n",
"            total = 0\n",
"\n",
"            # Gradients are only tracked during the train phase.\n",
"            with torch.set_grad_enabled(phase == \"train\"):\n",
"                with tqdm(total=len(dataloader[phase]), unit=\"batch\") as pbar:\n",
"                    pbar.set_description(f\"Epoch[{epoch}/{num_epochs}]({phase})\")\n",
"                    for imgs, labels in dataloader[phase]:\n",
"                        imgs, labels = imgs.to(device), labels.to(device)\n",
"                        output = model(imgs)\n",
"                        loss = criterion(output, labels)\n",
"\n",
"                        if phase == \"train\":\n",
"                            optimizer.zero_grad()\n",
"                            loss.backward()\n",
"                            optimizer.step()\n",
"\n",
"                        n_samples = imgs.size(0)\n",
"\n",
"                        # Assumes output is (batch, num_classes), so dim=1 is the\n",
"                        # class axis and argmax gives one predicted label per sample.\n",
"                        predicted = torch.argmax(output, dim=1)\n",
"                        corrects += (predicted == labels).sum().item()\n",
"                        total += n_samples\n",
"\n",
"                        # The criterion returns the batch mean, so multiply by the\n",
"                        # batch size to recover the batch total before accumulating.\n",
"                        # .item() yields a plain float; accumulating the loss tensor\n",
"                        # itself would keep autograd history alive for the whole epoch.\n",
"                        loss_sum += loss.item() * n_samples\n",
"\n",
"                        accuracy = corrects / total\n",
"                        running_loss = loss_sum / total\n",
"                        pbar.set_postfix({\"loss\": running_loss, \"accuracy\": accuracy})\n",
"                        pbar.update(1)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xi-ncaX476P7",
"colab_type": "code",
"outputId": "7ca4f345-b328-4025-c87f-1d43c227a4d0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 712
}
},
"source": [
"train(model,dataloader,optimizer,criterion,num_epochs,device)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch[1/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.29batch/s, loss=2.42, accuracy=0.141]\n",
"Epoch[1/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.38batch/s, loss=1.98, accuracy=0.188]\n",
"Epoch[2/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.14batch/s, loss=1.92, accuracy=0.202]\n",
"Epoch[2/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.69batch/s, loss=2.16, accuracy=0.197]\n",
"Epoch[3/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.23batch/s, loss=1.85, accuracy=0.239]\n",
"Epoch[3/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.42batch/s, loss=1.84, accuracy=0.246]\n",
"Epoch[4/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.14batch/s, loss=1.8, accuracy=0.268]\n",
"Epoch[4/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.47batch/s, loss=1.96, accuracy=0.272]\n",
"Epoch[5/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.02batch/s, loss=1.77, accuracy=0.288]\n",
"Epoch[5/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.67batch/s, loss=1.99, accuracy=0.273]\n",
"Epoch[6/20](train): 100%|██████████| 391/391 [00:23<00:00, 16.81batch/s, loss=1.68, accuracy=0.349]\n",
"Epoch[6/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.79batch/s, loss=1.59, accuracy=0.42]\n",
"Epoch[7/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.51batch/s, loss=1.46, accuracy=0.451]\n",
"Epoch[7/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.48batch/s, loss=1.48, accuracy=0.46]\n",
"Epoch[8/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.17batch/s, loss=1.38, accuracy=0.49]\n",
"Epoch[8/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.09batch/s, loss=1.46, accuracy=0.481]\n",
"Epoch[9/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.08batch/s, loss=1.29, accuracy=0.532]\n",
"Epoch[9/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.89batch/s, loss=1.38, accuracy=0.495]\n",
"Epoch[10/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.09batch/s, loss=1.22, accuracy=0.558]\n",
"Epoch[10/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.81batch/s, loss=1.38, accuracy=0.529]\n",
"Epoch[11/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.39batch/s, loss=1.15, accuracy=0.59]\n",
"Epoch[11/20](test): 100%|██████████| 79/79 [00:01<00:00, 40.72batch/s, loss=1.29, accuracy=0.557]\n",
"Epoch[12/20](train): 100%|██████████| 391/391 [00:23<00:00, 16.95batch/s, loss=1.09, accuracy=0.616]\n",
"Epoch[12/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.31batch/s, loss=1.37, accuracy=0.566]\n",
"Epoch[13/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.27batch/s, loss=1.04, accuracy=0.633]\n",
"Epoch[13/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.33batch/s, loss=2.13, accuracy=0.415]\n",
"Epoch[14/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.26batch/s, loss=0.995, accuracy=0.651]\n",
"Epoch[14/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.26batch/s, loss=1.36, accuracy=0.583]\n",
"Epoch[15/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.16batch/s, loss=0.949, accuracy=0.669]\n",
"Epoch[15/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.80batch/s, loss=1.58, accuracy=0.584]\n",
"Epoch[16/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.37batch/s, loss=1.05, accuracy=0.64]\n",
"Epoch[16/20](test): 100%|██████████| 79/79 [00:01<00:00, 43.20batch/s, loss=1.24, accuracy=0.605]\n",
"Epoch[17/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.13batch/s, loss=0.895, accuracy=0.693]\n",
"Epoch[17/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.94batch/s, loss=1.27, accuracy=0.612]\n",
"Epoch[18/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.21batch/s, loss=0.82, accuracy=0.72]\n",
"Epoch[18/20](test): 100%|██████████| 79/79 [00:01<00:00, 40.86batch/s, loss=1.5, accuracy=0.594]\n",
"Epoch[19/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.58batch/s, loss=0.79, accuracy=0.734]\n",
"Epoch[19/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.85batch/s, loss=1.31, accuracy=0.64]\n",
"Epoch[20/20](train): 100%|██████████| 391/391 [00:22<00:00, 17.29batch/s, loss=0.745, accuracy=0.75]\n",
"Epoch[20/20](test): 100%|██████████| 79/79 [00:01<00:00, 42.48batch/s, loss=1.43, accuracy=0.64]\n"
],
"name": "stderr"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment