Skip to content

Instantly share code, notes, and snippets.

@okwrtdsh
Created March 1, 2020 13:39
Show Gist options
  • Save okwrtdsh/18d23228bd182cec5bb0eb96b13d4e19 to your computer and use it in GitHub Desktop.
Save okwrtdsh/18d23228bd182cec5bb0eb96b13d4e19 to your computer and use it in GitHub Desktop.
train_C3D.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "train_C3D.ipynb",
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"authorship_tag": "ABX9TyPgoHi+tDvvDPYO7QZp4NZN",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/okwrtdsh/18d23228bd182cec5bb0eb96b13d4e19/train_c3d.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "REzmuopT1Hrx",
"colab_type": "code",
"outputId": "932005bd-df91-43cb-9bc2-0b13494d29e8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 938
}
},
"source": [
"!pip install -U torchsummary catalyst"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already up-to-date: torchsummary in /usr/local/lib/python3.6/dist-packages (1.5.1)\n",
"Requirement already up-to-date: catalyst in /usr/local/lib/python3.6/dist-packages (20.2.4)\n",
"Requirement already satisfied, skipping upgrade: tqdm>=4.33.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (4.43.0)\n",
"Requirement already satisfied, skipping upgrade: ipython in /usr/local/lib/python3.6/dist-packages (from catalyst) (5.5.0)\n",
"Requirement already satisfied, skipping upgrade: tensorboardX in /usr/local/lib/python3.6/dist-packages (from catalyst) (2.0)\n",
"Requirement already satisfied, skipping upgrade: opencv-python in /usr/local/lib/python3.6/dist-packages (from catalyst) (4.1.2.30)\n",
"Requirement already satisfied, skipping upgrade: imageio in /usr/local/lib/python3.6/dist-packages (from catalyst) (2.4.1)\n",
"Requirement already satisfied, skipping upgrade: scikit-image>=0.14.2 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.16.2)\n",
"Requirement already satisfied, skipping upgrade: seaborn in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.10.0)\n",
"Requirement already satisfied, skipping upgrade: scikit-learn>=0.20 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.22.1)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.16.4 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.17.5)\n",
"Requirement already satisfied, skipping upgrade: GitPython>=2.1.11 in /usr/local/lib/python3.6/dist-packages (from catalyst) (3.1.0)\n",
"Requirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.6/dist-packages (from catalyst) (20.1)\n",
"Requirement already satisfied, skipping upgrade: pandas>=0.22 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.25.3)\n",
"Requirement already satisfied, skipping upgrade: Pillow<7 in /usr/local/lib/python3.6/dist-packages (from catalyst) (6.2.2)\n",
"Requirement already satisfied, skipping upgrade: safitty>=1.2.3 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.3)\n",
"Requirement already satisfied, skipping upgrade: crc32c>=1.7 in /usr/local/lib/python3.6/dist-packages (from catalyst) (2.0)\n",
"Requirement already satisfied, skipping upgrade: torch>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.4.0)\n",
"Requirement already satisfied, skipping upgrade: torchvision>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.5.0)\n",
"Requirement already satisfied, skipping upgrade: matplotlib in /usr/local/lib/python3.6/dist-packages (from catalyst) (3.1.3)\n",
"Requirement already satisfied, skipping upgrade: PyYAML in /usr/local/lib/python3.6/dist-packages (from catalyst) (3.13)\n",
"Requirement already satisfied, skipping upgrade: plotly>=4.1.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (4.4.1)\n",
"Requirement already satisfied, skipping upgrade: tensorboard>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.15.0)\n",
"Requirement already satisfied, skipping upgrade: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (1.0.18)\n",
"Requirement already satisfied, skipping upgrade: simplegeneric>0.8 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (0.8.1)\n",
"Requirement already satisfied, skipping upgrade: pickleshare in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (0.7.5)\n",
"Requirement already satisfied, skipping upgrade: decorator in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (4.4.1)\n",
"Requirement already satisfied, skipping upgrade: pygments in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (2.1.3)\n",
"Requirement already satisfied, skipping upgrade: pexpect; sys_platform != \"win32\" in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (4.8.0)\n",
"Requirement already satisfied, skipping upgrade: setuptools>=18.5 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (45.1.0)\n",
"Requirement already satisfied, skipping upgrade: traitlets>=4.2 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (4.3.3)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from tensorboardX->catalyst) (1.12.0)\n",
"Requirement already satisfied, skipping upgrade: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardX->catalyst) (3.10.0)\n",
"Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.14.2->catalyst) (1.1.1)\n",
"Requirement already satisfied, skipping upgrade: networkx>=2.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.14.2->catalyst) (2.4)\n",
"Requirement already satisfied, skipping upgrade: scipy>=0.19.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.14.2->catalyst) (1.4.1)\n",
"Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.20->catalyst) (0.14.1)\n",
"Requirement already satisfied, skipping upgrade: gitdb<5,>=4.0.1 in /usr/local/lib/python3.6/dist-packages (from GitPython>=2.1.11->catalyst) (4.0.2)\n",
"Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->catalyst) (2.4.6)\n",
"Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22->catalyst) (2.6.1)\n",
"Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22->catalyst) (2018.9)\n",
"Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->catalyst) (1.1.0)\n",
"Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->catalyst) (0.10.0)\n",
"Requirement already satisfied, skipping upgrade: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly>=4.1.0->catalyst) (1.3.3)\n",
"Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (1.0.0)\n",
"Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (3.2.1)\n",
"Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (0.9.0)\n",
"Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (0.34.2)\n",
"Requirement already satisfied, skipping upgrade: grpcio>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (1.27.1)\n",
"Requirement already satisfied, skipping upgrade: wcwidth in /usr/local/lib/python3.6/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython->catalyst) (0.1.8)\n",
"Requirement already satisfied, skipping upgrade: ptyprocess>=0.5 in /usr/local/lib/python3.6/dist-packages (from pexpect; sys_platform != \"win32\"->ipython->catalyst) (0.6.0)\n",
"Requirement already satisfied, skipping upgrade: ipython-genutils in /usr/local/lib/python3.6/dist-packages (from traitlets>=4.2->ipython->catalyst) (0.2.0)\n",
"Requirement already satisfied, skipping upgrade: smmap<4,>=3.0.1 in /usr/local/lib/python3.6/dist-packages (from gitdb<5,>=4.0.1->GitPython>=2.1.11->catalyst) (3.0.1)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LbDQO9bEy5ii",
"colab_type": "code",
"outputId": "90e83bc0-fd2f-453d-8610-521026c57bd7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"\"\"\"\n",
"Mount Google Drive\n",
"\"\"\"\n",
"# from google.colab import drive\n",
"# drive.mount('/content/gdrive')\n",
"# !ls 'gdrive/My Drive/Colab Notebooks/'"
],
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'\\nMount Google\\u3000Drive\\n'"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8vxOid9x15rp",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 63
},
"outputId": "3aae6d24-1fd6-474f-e4a1-a3e769c94132"
},
"source": [
"\"\"\"\n",
"Reproducibility\n",
"\"\"\"\n",
"SEED = 123\n",
"\n",
"import os\n",
"os.environ['PYTHONHASHSEED'] = '0'\n",
"\n",
"import random\n",
"random.seed(SEED)\n",
"\n",
"import numpy as np\n",
"np.random.seed(SEED)\n",
"\n",
"import torch\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False\n",
"\n",
"from catalyst.utils import set_global_seed\n",
"set_global_seed(SEED)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<p style=\"color: red;\">\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
"We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
"<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6B3QBlkN6Ppo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a07be55a-9be6-4488-d54e-5ab511ace61d"
},
"source": [
"import math\n",
"\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"import albumentations as albu\n",
"from albumentations.pytorch import ToTensor\n",
"\n",
"from catalyst.dl.runner import SupervisedRunner\n",
"from catalyst.dl.callbacks import AccuracyCallback, EarlyStoppingCallback, InferCallback, CheckpointCallback, MixupCallback\n",
"\n",
"from torchsummary import summary"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LMux-J5gu3I1",
"colab_type": "code",
"colab": {}
},
"source": [
"class C3D(nn.Module):\n",
" def __init__(self, num_classes=101):\n",
" super().__init__()\n",
" self.conv1a = nn.Conv3d(3, 64, 3, stride=1, padding=1, bias=True)\n",
" self.conv2a = nn.Conv3d(64, 128, 3, stride=1, padding=1, bias=True)\n",
" self.conv3a = nn.Conv3d(128, 256, 3, 1, padding=1)\n",
" self.conv3b = nn.Conv3d(256, 256, 3, 1, padding=1)\n",
" self.conv4a = nn.Conv3d(256, 512, 3, 1, padding=1)\n",
" self.conv4b = nn.Conv3d(512, 512, 3, 1, padding=1)\n",
" self.conv5a = nn.Conv3d(512, 512, 3, 1, padding=1)\n",
" self.conv5b = nn.Conv3d(512, 512, 3, 1, padding=1)\n",
" self.fc6 = nn.Linear(512*4*4, 4096)\n",
" self.fc7 = nn.Linear(4096, 4096)\n",
" self.fc8 = nn.Linear(4096, num_classes)\n",
"\n",
" for m in self.modules():\n",
" if isinstance(m, nn.Conv2d):\n",
" n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
" m.weight.data.normal_(0, math.sqrt(2. / n))\n",
" if m.bias is not None:\n",
" m.bias.data.zero_()\n",
" elif isinstance(m, nn.Conv3d):\n",
" n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels\n",
" m.weight.data.normal_(0, math.sqrt(2. / n))\n",
" if m.bias is not None:\n",
" m.bias.data.zero_()\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.conv1a(x))\n",
" x = F.max_pool3d(x, (1, 2, 2), stride=(1, 2, 2))\n",
"\n",
" x = F.relu(self.conv2a(x))\n",
" x = F.max_pool3d(x, 2, 2)\n",
"\n",
" x = F.relu(self.conv3a(x))\n",
" x = F.relu(self.conv3b(x))\n",
" x = F.max_pool3d(x, 2, 2)\n",
"\n",
" x = F.relu(self.conv4a(x))\n",
" x = F.relu(self.conv4b(x))\n",
" x = F.max_pool3d(x, 2, 2)\n",
"\n",
" x = F.relu(self.conv5a(x))\n",
" x = F.relu(self.conv5b(x))\n",
" x = F.max_pool3d(x, 2, 2, padding=(0, 1, 1))\n",
"\n",
" x = x.view(-1, 512*4*4)\n",
" x = F.relu(self.fc6(x))\n",
" x = F.relu(self.fc7(x))\n",
" return F.softmax(self.fc8(x), dim=1)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8q6IO6o61Oor",
"colab_type": "code",
"colab": {}
},
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = C3D(101).to(device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VOmvCqljyk5c",
"colab_type": "code",
"outputId": "3fce92a5-95de-4a6e-b151-050943cd0ca6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 425
}
},
"source": [
"summary(model, (3, 16, 112, 112))"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv3d-1 [-1, 64, 16, 112, 112] 5,248\n",
" Conv3d-2 [-1, 128, 16, 56, 56] 221,312\n",
" Conv3d-3 [-1, 256, 8, 28, 28] 884,992\n",
" Conv3d-4 [-1, 256, 8, 28, 28] 1,769,728\n",
" Conv3d-5 [-1, 512, 4, 14, 14] 3,539,456\n",
" Conv3d-6 [-1, 512, 4, 14, 14] 7,078,400\n",
" Conv3d-7 [-1, 512, 2, 7, 7] 7,078,400\n",
" Conv3d-8 [-1, 512, 2, 7, 7] 7,078,400\n",
" Linear-9 [-1, 4096] 33,558,528\n",
" Linear-10 [-1, 4096] 16,781,312\n",
" Linear-11 [-1, 101] 413,797\n",
"================================================================\n",
"Total params: 78,409,573\n",
"Trainable params: 78,409,573\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 2.30\n",
"Forward/backward pass size (MB): 178.45\n",
"Params size (MB): 299.11\n",
"Estimated Total Size (MB): 479.86\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WxyLOOgg2xtv",
"colab_type": "code",
"colab": {}
},
"source": [
"class RandomDataset(object):\n",
" def __init__(self, num_classes, num_data, transforms=None):\n",
" self.num_data = num_data\n",
" self.X = np.random.rand(num_data, 16, 128, 171, 3)\n",
" self.y = np.random.randint(0, num_classes, (num_data,))\n",
" self.transforms = transforms\n",
"\n",
" def __getitem__(self, idx):\n",
" X = self.X[idx]\n",
" y = self.y[idx]\n",
" # transform\n",
" X = self.transforms(X)\n",
" return X, y\n",
"\n",
" def __len__(self):\n",
" return self.num_data"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vbZ7yZkl8jaK",
"colab_type": "code",
"colab": {}
},
"source": [
"\"\"\"\n",
"Hyperparameters\n",
"\"\"\"\n",
"num_classes = 101\n",
"num_epochs = 3\n",
"batch_size = 8\n",
"\n",
"# optimizer\n",
"lr = 1e-2\n",
"\n",
"# scheduler\n",
"step_size = 5\n",
"gamma = 0.1\n",
"\n",
"# runner\n",
"logdir = \"./logs\""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "s4iuQ2WxBS0Z",
"colab_type": "code",
"colab": {}
},
"source": [
"def apply_tfms_video(video, tfms_albu):\n",
" \"\"\"\n",
" Apply Albumentations to Videos\n",
"\n",
" Args:\n",
" video: numpy array (T, H, W, C)\n",
" tfms_albu: albumentations\n",
"\n",
" Returns:\n",
" tensor: pytorch tensor (C, T, H, W)\n",
" \"\"\"\n",
" tfms_seed = random.randint(0, 99999)\n",
" aug_video = []\n",
" for x in video:\n",
" random.seed(tfms_seed)\n",
" aug_video.append((tfms_albu(image = np.asarray(x)))['image'])\n",
" return torch.stack(aug_video).permute(1, 0, 2, 3)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JKjAIpHe8Ec1",
"colab_type": "code",
"colab": {}
},
"source": [
"train_transforms = albu.Compose([\n",
" albu.RandomCrop(112, 112),\n",
" albu.HorizontalFlip(p=0.5),\n",
" albu.Normalize(),\n",
" ToTensor()\n",
"])\n",
"\n",
"test_transforms = albu.Compose([\n",
" albu.CenterCrop(112, 112),\n",
" albu.Normalize(),\n",
" ToTensor()\n",
"])\n",
"\n",
"train_data = RandomDataset(num_classes, 800, transforms=lambda x: apply_tfms_video(x, train_transforms))\n",
"val_data = RandomDataset(num_classes, 100, transforms=lambda x: apply_tfms_video(x, test_transforms))\n",
"test_data = RandomDataset(num_classes, 100, transforms=lambda x: apply_tfms_video(x, test_transforms))\n",
"train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)\n",
"val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)\n",
"test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)\n",
"loaders = {\n",
" \"train\": train_loader,\n",
" \"valid\": val_loader\n",
"}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3r0Ttsvk6iDm",
"colab_type": "code",
"colab": {}
},
"source": [
"optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)\n",
"# optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
"criterion = nn.CrossEntropyLoss()\n",
"runner = SupervisedRunner()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-0MBdV_U8B6l",
"colab_type": "code",
"outputId": "f65960bc-fd0f-4376-c644-a8613de8b04f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 326
}
},
"source": [
"model.train()\n",
"runner.train(\n",
" model=model,\n",
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
" loaders=loaders,\n",
" callbacks=[AccuracyCallback(num_classes=num_classes)],\n",
" logdir=logdir,\n",
" num_epochs=num_epochs,\n",
" verbose=True\n",
")"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"1/3 * Epoch (train): 100% 100/100 [02:54<00:00, 1.74s/it, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.615]\n",
"1/3 * Epoch (valid): 100% 13/13 [00:06<00:00, 2.02it/s, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.615]\n",
"[2020-03-01 13:12:12,604] \n",
"1/3 * Epoch 1 (train): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=70.9924 | _timers/batch_time=0.1183 | _timers/data_time=0.1162 | _timers/model_time=0.0021 | accuracy01=0.6250 | accuracy03=2.7500 | accuracy05=4.7500 | loss=4.6151\n",
"1/3 * Epoch 1 (valid): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=134.9420 | _timers/batch_time=0.0609 | _timers/data_time=0.0590 | _timers/model_time=0.0019 | accuracy01=0.9615 | accuracy03=1.9231 | accuracy05=2.8846 | loss=4.6151\n",
"2/3 * Epoch (train): 100% 100/100 [02:53<00:00, 1.74s/it, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.615]\n",
"2/3 * Epoch (valid): 100% 13/13 [00:06<00:00, 2.01it/s, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.615]\n",
"[2020-03-01 13:15:49,545] \n",
"2/3 * Epoch 2 (train): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=73.2446 | _timers/batch_time=0.1125 | _timers/data_time=0.1104 | _timers/model_time=0.0020 | accuracy01=0.5000 | accuracy03=2.6250 | accuracy05=4.8750 | loss=4.6151\n",
"2/3 * Epoch 2 (valid): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=137.8087 | _timers/batch_time=0.0601 | _timers/data_time=0.0583 | _timers/model_time=0.0018 | accuracy01=0.9615 | accuracy03=0.9615 | accuracy05=1.9231 | loss=4.6151\n",
"3/3 * Epoch (train): 100% 100/100 [02:55<00:00, 1.76s/it, accuracy01=12.500, accuracy03=12.500, accuracy05=12.500, loss=4.615]\n",
"3/3 * Epoch (valid): 100% 13/13 [00:06<00:00, 1.99it/s, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.615]\n",
"[2020-03-01 13:19:30,670] \n",
"3/3 * Epoch 3 (train): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=60.3787 | _timers/batch_time=0.1342 | _timers/data_time=0.1320 | _timers/model_time=0.0021 | accuracy01=0.7500 | accuracy03=2.5000 | accuracy05=5.5000 | loss=4.6151\n",
"3/3 * Epoch 3 (valid): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=130.9642 | _timers/batch_time=0.0631 | _timers/data_time=0.0612 | _timers/model_time=0.0019 | accuracy01=0.000e+00 | accuracy03=0.9615 | accuracy05=3.8462 | loss=4.6151\n",
"Top best models:\n",
"logs/checkpoints/train.3.pth\t4.6151\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "uIdD8ERw9sNb",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
},
"outputId": "7772b0aa-f3cd-4949-f938-045fff2cb7df"
},
"source": [
"model.eval()\n",
"predictions = runner.predict_loader(\n",
" model=model,\n",
" loader=test_loader,\n",
" resume=f\"{logdir}/checkpoints/best.pth\",\n",
" verbose=True\n",
")"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"=> loading checkpoint ./logs/checkpoints/best.pth\n",
"loaded checkpoint ./logs/checkpoints/best.pth (epoch 3, stage_epoch 3, stage train)\n",
"1/1 * Epoch (infer): 100% 13/13 [00:06<00:00, 2.00it/s]\n",
"Top best models:\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pjZy6Am4K1Zn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "6db4cc59-0a26-40b6-afc3-214ad034bab2"
},
"source": [
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(test_data.y, predictions.argmax(axis=1))"
],
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.01"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xUDfaAg-NBhi",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment