Skip to content

Instantly share code, notes, and snippets.

@okwrtdsh
Created March 1, 2020 13:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save okwrtdsh/949a085cff84daa7e7d2e8baedfc5757 to your computer and use it in GitHub Desktop.
train_I3D.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "train_I3D.ipynb",
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"authorship_tag": "ABX9TyPkYZJvTXBmWBikhmZogUX6",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/okwrtdsh/949a085cff84daa7e7d2e8baedfc5757/train_i3d.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "REzmuopT1Hrx",
"colab_type": "code",
"outputId": "ae0cb982-33d1-4b44-c312-ea3d9d7d9ea1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 938
}
},
"source": [
"# `!pip` (not `%pip`): this runtime's IPython is 5.5.0, which predates the %pip magic.\n",
"# Pinned to the versions this notebook was run with instead of upgrading to latest,\n",
"# because newer releases may move the catalyst imports used below.\n",
"!pip install torchsummary==1.5.1 catalyst==20.2.4"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already up-to-date: torchsummary in /usr/local/lib/python3.6/dist-packages (1.5.1)\n",
"Requirement already up-to-date: catalyst in /usr/local/lib/python3.6/dist-packages (20.2.4)\n",
"Requirement already satisfied, skipping upgrade: GitPython>=2.1.11 in /usr/local/lib/python3.6/dist-packages (from catalyst) (3.1.0)\n",
"Requirement already satisfied, skipping upgrade: imageio in /usr/local/lib/python3.6/dist-packages (from catalyst) (2.4.1)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.16.4 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.17.5)\n",
"Requirement already satisfied, skipping upgrade: safitty>=1.2.3 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.3)\n",
"Requirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.6/dist-packages (from catalyst) (20.1)\n",
"Requirement already satisfied, skipping upgrade: seaborn in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.10.0)\n",
"Requirement already satisfied, skipping upgrade: ipython in /usr/local/lib/python3.6/dist-packages (from catalyst) (5.5.0)\n",
"Requirement already satisfied, skipping upgrade: scikit-learn>=0.20 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.22.1)\n",
"Requirement already satisfied, skipping upgrade: tensorboardX in /usr/local/lib/python3.6/dist-packages (from catalyst) (2.0)\n",
"Requirement already satisfied, skipping upgrade: PyYAML in /usr/local/lib/python3.6/dist-packages (from catalyst) (3.13)\n",
"Requirement already satisfied, skipping upgrade: scikit-image>=0.14.2 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.16.2)\n",
"Requirement already satisfied, skipping upgrade: tensorboard>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.15.0)\n",
"Requirement already satisfied, skipping upgrade: Pillow<7 in /usr/local/lib/python3.6/dist-packages (from catalyst) (6.2.2)\n",
"Requirement already satisfied, skipping upgrade: torch>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (1.4.0)\n",
"Requirement already satisfied, skipping upgrade: pandas>=0.22 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.25.3)\n",
"Requirement already satisfied, skipping upgrade: opencv-python in /usr/local/lib/python3.6/dist-packages (from catalyst) (4.1.2.30)\n",
"Requirement already satisfied, skipping upgrade: matplotlib in /usr/local/lib/python3.6/dist-packages (from catalyst) (3.1.3)\n",
"Requirement already satisfied, skipping upgrade: tqdm>=4.33.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (4.43.0)\n",
"Requirement already satisfied, skipping upgrade: torchvision>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from catalyst) (0.5.0)\n",
"Requirement already satisfied, skipping upgrade: plotly>=4.1.0 in /usr/local/lib/python3.6/dist-packages (from catalyst) (4.4.1)\n",
"Requirement already satisfied, skipping upgrade: crc32c>=1.7 in /usr/local/lib/python3.6/dist-packages (from catalyst) (2.0)\n",
"Requirement already satisfied, skipping upgrade: gitdb<5,>=4.0.1 in /usr/local/lib/python3.6/dist-packages (from GitPython>=2.1.11->catalyst) (4.0.2)\n",
"Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->catalyst) (2.4.6)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from packaging->catalyst) (1.12.0)\n",
"Requirement already satisfied, skipping upgrade: scipy>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from seaborn->catalyst) (1.4.1)\n",
"Requirement already satisfied, skipping upgrade: pickleshare in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (0.7.5)\n",
"Requirement already satisfied, skipping upgrade: decorator in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (4.4.1)\n",
"Requirement already satisfied, skipping upgrade: traitlets>=4.2 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (4.3.3)\n",
"Requirement already satisfied, skipping upgrade: pexpect; sys_platform != \"win32\" in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (4.8.0)\n",
"Requirement already satisfied, skipping upgrade: simplegeneric>0.8 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (0.8.1)\n",
"Requirement already satisfied, skipping upgrade: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (1.0.18)\n",
"Requirement already satisfied, skipping upgrade: pygments in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (2.1.3)\n",
"Requirement already satisfied, skipping upgrade: setuptools>=18.5 in /usr/local/lib/python3.6/dist-packages (from ipython->catalyst) (45.1.0)\n",
"Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.20->catalyst) (0.14.1)\n",
"Requirement already satisfied, skipping upgrade: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardX->catalyst) (3.10.0)\n",
"Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.14.2->catalyst) (1.1.1)\n",
"Requirement already satisfied, skipping upgrade: networkx>=2.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image>=0.14.2->catalyst) (2.4)\n",
"Requirement already satisfied, skipping upgrade: grpcio>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (1.27.1)\n",
"Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (0.9.0)\n",
"Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (3.2.1)\n",
"Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (1.0.0)\n",
"Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14.0->catalyst) (0.34.2)\n",
"Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22->catalyst) (2018.9)\n",
"Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22->catalyst) (2.6.1)\n",
"Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->catalyst) (1.1.0)\n",
"Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->catalyst) (0.10.0)\n",
"Requirement already satisfied, skipping upgrade: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly>=4.1.0->catalyst) (1.3.3)\n",
"Requirement already satisfied, skipping upgrade: smmap<4,>=3.0.1 in /usr/local/lib/python3.6/dist-packages (from gitdb<5,>=4.0.1->GitPython>=2.1.11->catalyst) (3.0.1)\n",
"Requirement already satisfied, skipping upgrade: ipython-genutils in /usr/local/lib/python3.6/dist-packages (from traitlets>=4.2->ipython->catalyst) (0.2.0)\n",
"Requirement already satisfied, skipping upgrade: ptyprocess>=0.5 in /usr/local/lib/python3.6/dist-packages (from pexpect; sys_platform != \"win32\"->ipython->catalyst) (0.6.0)\n",
"Requirement already satisfied, skipping upgrade: wcwidth in /usr/local/lib/python3.6/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython->catalyst) (0.1.8)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LbDQO9bEy5ii",
"colab_type": "code",
"outputId": "399ed0ae-504a-4669-c620-924a78ff0f24",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Optional: mount Google Drive (Colab only). Disabled by default so the notebook\n",
"# also runs top-to-bottom outside Colab; set MOUNT_DRIVE = True to enable it.\n",
"MOUNT_DRIVE = False\n",
"\n",
"if MOUNT_DRIVE:\n",
"    import os\n",
"    from google.colab import drive\n",
"    drive.mount('/content/gdrive')\n",
"    print(os.listdir('gdrive/My Drive/Colab Notebooks/'))"
],
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'\\nMount Google\\u3000Drive\\n'"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8vxOid9x15rp",
"colab_type": "code",
"outputId": "4ddac335-3879-42d3-a4d6-200debabe099",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 63
}
},
"source": [
"# Reproducibility: seed every RNG used in this notebook before any model is built.\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from catalyst.utils import set_global_seed\n",
"\n",
"SEED = 123\n",
"\n",
"# Python reads PYTHONHASHSEED only at interpreter start-up, so this fixes str\n",
"# hashing for new interpreters started later, not for the running kernel.\n",
"os.environ['PYTHONHASHSEED'] = '0'\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Give up cuDNN autotuning in exchange for repeatable convolution results.\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False\n",
"\n",
"# catalyst's own seeding helper; kept as the last statement of the cell.\n",
"set_global_seed(SEED)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<p style=\"color: red;\">\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
"We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
"<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6B3QBlkN6Ppo",
"colab_type": "code",
"outputId": "d61e2e3f-ecad-46a4-fc6c-a5f2fe238a61",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"import math\n",
"\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"import albumentations as albu\n",
"from albumentations.pytorch import ToTensor\n",
"\n",
"from catalyst.dl.runner import SupervisedRunner\n",
"from catalyst.dl.callbacks import AccuracyCallback, EarlyStoppingCallback, InferCallback, CheckpointCallback, MixupCallback\n",
"\n",
"from torchsummary import summary"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LMux-J5gu3I1",
"colab_type": "code",
"colab": {}
},
"source": [
"# https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py\n",
"\n",
"\n",
"class MaxPool3dSamePadding(nn.MaxPool3d):\n",
"    \"\"\"nn.MaxPool3d with TensorFlow-style 'SAME' padding computed from the input.\n",
"\n",
"    Construct it with ``padding=0`` (and the default ``ceil_mode``): ``forward``\n",
"    zero-pads each of (t, h, w) with ``F.pad`` so every output size equals\n",
"    ``ceil(input_size / stride)``.\n",
"    \"\"\"\n",
"\n",
"    def compute_pad(self, dim, s):\n",
"        \"\"\"Total padding needed along axis ``dim`` (0=t, 1=h, 2=w) for input size ``s``.\"\"\"\n",
"        if s % self.stride[dim] == 0:\n",
"            return max(self.kernel_size[dim] - self.stride[dim], 0)\n",
"        return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)\n",
"\n",
"    def forward(self, x):\n",
"        # x is (batch, channel, t, h, w); compute 'same' padding per axis.\n",
"        _, _, t, h, w = x.size()\n",
"        pad = []\n",
"        # F.pad wants (front, back) pairs starting from the LAST axis: w, then h, then t.\n",
"        for dim, size in ((2, w), (1, h), (0, t)):\n",
"            total = self.compute_pad(dim, size)\n",
"            front = total // 2  # an odd leftover goes to the back, as in TF\n",
"            pad.extend((front, total - front))\n",
"        x = F.pad(x, tuple(pad))\n",
"        return super(MaxPool3dSamePadding, self).forward(x)\n",
"\n",
"\n",
"class Unit3D(nn.Module):\n",
"    \"\"\"Conv3d -> optional BatchNorm3d -> optional activation, with 'SAME' padding.\n",
"\n",
"    Padding is computed from the input size in ``forward`` (TensorFlow 'SAME'\n",
"    semantics), so the ``padding`` constructor argument is stored but never used.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels,\n",
"                 output_channels,\n",
"                 kernel_shape=(1, 1, 1),\n",
"                 stride=(1, 1, 1),\n",
"                 padding=0,\n",
"                 activation_fn=F.relu,\n",
"                 use_batch_norm=True,\n",
"                 use_bias=False,\n",
"                 name='unit_3d'):\n",
"        \"\"\"Initializes Unit3D module.\n",
"\n",
"        Args:\n",
"            in_channels: number of input channels.\n",
"            output_channels: number of output channels of the convolution.\n",
"            kernel_shape: (t, h, w) kernel size.\n",
"            stride: (t, h, w) stride.\n",
"            padding: ignored; padding is computed dynamically in forward().\n",
"            activation_fn: callable applied last, or None for no activation.\n",
"            use_batch_norm: add BatchNorm3d(eps=0.001, momentum=0.01) after the conv.\n",
"            use_bias: give the convolution a bias term.\n",
"            name: label stored on the module.\n",
"        \"\"\"\n",
"        super(Unit3D, self).__init__()\n",
"\n",
"        self._output_channels = output_channels\n",
"        self._kernel_shape = kernel_shape\n",
"        self._stride = stride\n",
"        self._use_batch_norm = use_batch_norm\n",
"        self._activation_fn = activation_fn\n",
"        self._use_bias = use_bias\n",
"        self.name = name\n",
"        self.padding = padding\n",
"\n",
"        # padding is always 0 here: forward() pads dynamically based on the input size.\n",
"        self.conv3d = nn.Conv3d(in_channels=in_channels,\n",
"                                out_channels=self._output_channels,\n",
"                                kernel_size=self._kernel_shape,\n",
"                                stride=self._stride,\n",
"                                padding=0,\n",
"                                bias=self._use_bias)\n",
"\n",
"        if self._use_batch_norm:\n",
"            self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)\n",
"\n",
"    def compute_pad(self, dim, s):\n",
"        \"\"\"Total padding needed along axis ``dim`` (0=t, 1=h, 2=w) for input size ``s``.\"\"\"\n",
"        if s % self._stride[dim] == 0:\n",
"            return max(self._kernel_shape[dim] - self._stride[dim], 0)\n",
"        return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)\n",
"\n",
"    def forward(self, x):\n",
"        # x is (batch, channel, t, h, w); compute 'same' padding per axis.\n",
"        _, _, t, h, w = x.size()\n",
"        pad = []\n",
"        # F.pad wants (front, back) pairs starting from the LAST axis: w, then h, then t.\n",
"        for dim, size in ((2, w), (1, h), (0, t)):\n",
"            total = self.compute_pad(dim, size)\n",
"            front = total // 2  # an odd leftover goes to the back, as in TF\n",
"            pad.extend((front, total - front))\n",
"        x = F.pad(x, tuple(pad))\n",
"\n",
"        x = self.conv3d(x)\n",
"        if self._use_batch_norm:\n",
"            x = self.bn(x)\n",
"        if self._activation_fn is not None:\n",
"            x = self._activation_fn(x)\n",
"        return x\n",
"\n",
"\n",
"class InceptionModule(nn.Module):\n",
"    \"\"\"Inception block: four parallel branches concatenated on the channel axis.\n",
"\n",
"    `out_channels` lists six widths, in this order:\n",
"      [0] branch 0 1x1 conv,\n",
"      [1] branch 1 1x1 reduction, [2] branch 1 3x3x3 conv,\n",
"      [3] branch 2 1x1 reduction, [4] branch 2 3x3x3 conv,\n",
"      [5] branch 3 1x1 conv applied after a stride-1 3x3x3 max-pool.\n",
"    The output has out_channels[0] + [2] + [4] + [5] channels.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, out_channels, name):\n",
"        super(InceptionModule, self).__init__()\n",
"        width_0 = out_channels[0]\n",
"        reduce_1, width_1 = out_channels[1], out_channels[2]\n",
"        reduce_2, width_2 = out_channels[3], out_channels[4]\n",
"        width_3 = out_channels[5]\n",
"\n",
"        # Attribute names and creation order are kept: they define the state_dict\n",
"        # keys and the order in which parameters are initialised.\n",
"        self.b0 = Unit3D(in_channels=in_channels, output_channels=width_0, kernel_shape=[1, 1, 1], padding=0,\n",
"                         name=name+'/Branch_0/Conv3d_0a_1x1')\n",
"        self.b1a = Unit3D(in_channels=in_channels, output_channels=reduce_1, kernel_shape=[1, 1, 1], padding=0,\n",
"                          name=name+'/Branch_1/Conv3d_0a_1x1')\n",
"        self.b1b = Unit3D(in_channels=reduce_1, output_channels=width_1, kernel_shape=[3, 3, 3],\n",
"                          name=name+'/Branch_1/Conv3d_0b_3x3')\n",
"        self.b2a = Unit3D(in_channels=in_channels, output_channels=reduce_2, kernel_shape=[1, 1, 1], padding=0,\n",
"                          name=name+'/Branch_2/Conv3d_0a_1x1')\n",
"        self.b2b = Unit3D(in_channels=reduce_2, output_channels=width_2, kernel_shape=[3, 3, 3],\n",
"                          name=name+'/Branch_2/Conv3d_0b_3x3')\n",
"        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],\n",
"                                        stride=(1, 1, 1), padding=0)\n",
"        self.b3b = Unit3D(in_channels=in_channels, output_channels=width_3, kernel_shape=[1, 1, 1], padding=0,\n",
"                          name=name+'/Branch_3/Conv3d_0b_1x1')\n",
"        self.name = name\n",
"\n",
"    def forward(self, x):\n",
"        branches = (\n",
"            self.b0(x),\n",
"            self.b1b(self.b1a(x)),\n",
"            self.b2b(self.b2a(x)),\n",
"            self.b3b(self.b3a(x)),\n",
"        )\n",
"        return torch.cat(branches, dim=1)\n",
"\n",
"\n",
"class InceptionI3d(nn.Module):\n",
"    \"\"\"Inception-v1 I3D architecture.\n",
"\n",
"    The model is introduced in:\n",
"        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset\n",
"        Joao Carreira, Andrew Zisserman\n",
"        https://arxiv.org/pdf/1705.07750v1.pdf.\n",
"\n",
"    See also the Inception architecture, introduced in:\n",
"        Going deeper with convolutions\n",
"        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,\n",
"        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.\n",
"        http://arxiv.org/pdf/1409.4842v1.pdf.\n",
"    \"\"\"\n",
"\n",
"    # Endpoints of the model in order. Construction stops after the designated\n",
"    # `final_endpoint`; 'Logits' and 'Predictions' both build the whole network\n",
"    # including the classification head.\n",
"    VALID_ENDPOINTS = (\n",
"        'Conv3d_1a_7x7',\n",
"        'MaxPool3d_2a_3x3',\n",
"        'Conv3d_2b_1x1',\n",
"        'Conv3d_2c_3x3',\n",
"        'MaxPool3d_3a_3x3',\n",
"        'Mixed_3b',\n",
"        'Mixed_3c',\n",
"        'MaxPool3d_4a_3x3',\n",
"        'Mixed_4b',\n",
"        'Mixed_4c',\n",
"        'Mixed_4d',\n",
"        'Mixed_4e',\n",
"        'Mixed_4f',\n",
"        'MaxPool3d_5a_2x2',\n",
"        'Mixed_5b',\n",
"        'Mixed_5c',\n",
"        'Logits',\n",
"        'Predictions',\n",
"    )\n",
"\n",
"    def __init__(self, num_classes=400, spatial_squeeze=True,\n",
"                 final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):\n",
"        \"\"\"Initializes I3D model instance.\n",
"\n",
"        Args:\n",
"            num_classes: The number of outputs in the logit layer (default 400,\n",
"                which matches the Kinetics dataset).\n",
"            spatial_squeeze: Whether to squeeze the spatial dimensions for the\n",
"                logits before returning (default True).\n",
"            final_endpoint: The last endpoint to build; must be one of\n",
"                InceptionI3d.VALID_ENDPOINTS (default 'Logits'). For an endpoint\n",
"                before 'Logits' the classification head is not built and\n",
"                forward() returns that endpoint's activation.\n",
"            name: A string prefix for the sub-modules' `name` labels.\n",
"            in_channels: Number of input channels of the first convolution.\n",
"            dropout_keep_prob: Probability of KEEPING an activation before the\n",
"                logits layer (TensorFlow convention); nn.Dropout receives the\n",
"                drop probability 1 - dropout_keep_prob.\n",
"\n",
"        Raises:\n",
"            ValueError: if `final_endpoint` is not recognized.\n",
"        \"\"\"\n",
"        if final_endpoint not in self.VALID_ENDPOINTS:\n",
"            raise ValueError('Unknown final endpoint %s' % final_endpoint)\n",
"\n",
"        super(InceptionI3d, self).__init__()\n",
"        self._num_classes = num_classes\n",
"        self._spatial_squeeze = spatial_squeeze\n",
"        self._final_endpoint = final_endpoint\n",
"        self.logits = None\n",
"\n",
"        # Plain dict: these modules are only registered when build() runs below.\n",
"        self.end_points = {}\n",
"        stopped_early = self._build_end_points(name, in_channels)\n",
"\n",
"        if not stopped_early:\n",
"            # Kernel [2, 7, 7] matches the Mixed_5c output for 16x224x224 clips\n",
"            # (1024 x 2 x 7 x 7); other clip sizes change the pooled shape.\n",
"            self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],\n",
"                                         stride=(1, 1, 1))\n",
"            # nn.Dropout takes the probability of ZEROING an element, so convert\n",
"            # the keep probability.\n",
"            self.dropout = nn.Dropout(1 - dropout_keep_prob)\n",
"            self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,\n",
"                                 kernel_shape=[1, 1, 1],\n",
"                                 padding=0,\n",
"                                 activation_fn=None,\n",
"                                 use_batch_norm=False,\n",
"                                 use_bias=True,\n",
"                                 name='logits')\n",
"\n",
"        # Must also run for truncated models; otherwise their layers are never\n",
"        # registered as sub-modules (no parameters, not moved by .to(device)).\n",
"        self.build()\n",
"\n",
"    def _build_end_points(self, name, in_channels):\n",
"        \"\"\"Create the end-point modules in order, stopping after `final_endpoint`.\n",
"\n",
"        Returns True if construction stopped before the classification head\n",
"        (`final_endpoint` precedes 'Logits'), otherwise False.\n",
"        \"\"\"\n",
"        end_point = 'Conv3d_1a_7x7'\n",
"        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],\n",
"                                            stride=(2, 2, 2), padding=(3,3,3), name=name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'MaxPool3d_2a_3x3'\n",
"        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),\n",
"                                                          padding=0)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Conv3d_2b_1x1'\n",
"        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,\n",
"                                            name=name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Conv3d_2c_3x3'\n",
"        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,\n",
"                                            name=name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'MaxPool3d_3a_3x3'\n",
"        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),\n",
"                                                          padding=0)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_3b'\n",
"        self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_3c'\n",
"        self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'MaxPool3d_4a_3x3'\n",
"        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),\n",
"                                                          padding=0)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_4b'\n",
"        self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_4c'\n",
"        self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_4d'\n",
"        self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_4e'\n",
"        self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_4f'\n",
"        self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'MaxPool3d_5a_2x2'\n",
"        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),\n",
"                                                          padding=0)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_5b'\n",
"        self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        end_point = 'Mixed_5c'\n",
"        self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)\n",
"        if self._final_endpoint == end_point:\n",
"            return True\n",
"\n",
"        return False\n",
"\n",
"    def replace_logits(self, num_classes):\n",
"        \"\"\"Replace the logits layer with a freshly initialised one of `num_classes` outputs.\"\"\"\n",
"        self._num_classes = num_classes\n",
"        self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,\n",
"                             kernel_shape=[1, 1, 1],\n",
"                             padding=0,\n",
"                             activation_fn=None,\n",
"                             use_batch_norm=False,\n",
"                             use_bias=True,\n",
"                             name='logits')\n",
"\n",
"    def build(self):\n",
"        \"\"\"Register every end-point module under its end-point name.\"\"\"\n",
"        for k in self.end_points.keys():\n",
"            self.add_module(k, self.end_points[k])\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run the network on `x` of shape (batch, channels, t, h, w).\n",
"\n",
"        Returns (batch, num_classes) logits for 16x224x224 clips when the head\n",
"        is built, or the activation of `final_endpoint` for truncated models.\n",
"        \"\"\"\n",
"        for end_point in self.VALID_ENDPOINTS:\n",
"            if end_point in self.end_points:\n",
"                x = self._modules[end_point](x)  # use _modules to work with dataparallel\n",
"\n",
"        if self.logits is None:\n",
"            return x\n",
"\n",
"        x = self.logits(self.dropout(self.avg_pool(x)))\n",
"        if self._spatial_squeeze:\n",
"            # (batch, classes, time, 1, 1) -> (batch, classes, time)\n",
"            x = x.squeeze(3).squeeze(3)\n",
"        # Flatten everything after the batch axis; with time == 1 this is (batch, classes).\n",
"        return x.view(x.size(0), -1)\n",
"\n",
"    def extract_features(self, x):\n",
"        \"\"\"Return the average-pooled output of the last built end point.\n",
"\n",
"        Uses avg_pool, so it requires the head (final_endpoint 'Logits' or 'Predictions').\n",
"        \"\"\"\n",
"        for end_point in self.VALID_ENDPOINTS:\n",
"            if end_point in self.end_points:\n",
"                x = self._modules[end_point](x)\n",
"        return self.avg_pool(x)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8q6IO6o61Oor",
"colab_type": "code",
"colab": {}
},
"source": [
"# Use the GPU when the runtime has one, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# 101 output classes (presumably UCF-101 -- TODO confirm against the data cells).\n",
"model = InceptionI3d(num_classes=101).to(device)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VOmvCqljyk5c",
"colab_type": "code",
"outputId": "489f5667-8b27-48f2-9af2-53a469b85092",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# Input is (channels, frames, height, width); torchsummary adds the batch axis.\n",
"# torchsummary defaults to device='cuda', so pass the real device for CPU runtimes.\n",
"summary(model, (3, 16, 224, 224), device=device.type)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv3d-1 [-1, 64, 8, 112, 112] 65,856\n",
" BatchNorm3d-2 [-1, 64, 8, 112, 112] 128\n",
" Unit3D-3 [-1, 64, 8, 112, 112] 0\n",
"MaxPool3dSamePadding-4 [-1, 64, 8, 56, 56] 0\n",
" Conv3d-5 [-1, 64, 8, 56, 56] 4,096\n",
" BatchNorm3d-6 [-1, 64, 8, 56, 56] 128\n",
" Unit3D-7 [-1, 64, 8, 56, 56] 0\n",
" Conv3d-8 [-1, 192, 8, 56, 56] 331,776\n",
" BatchNorm3d-9 [-1, 192, 8, 56, 56] 384\n",
" Unit3D-10 [-1, 192, 8, 56, 56] 0\n",
"MaxPool3dSamePadding-11 [-1, 192, 8, 28, 28] 0\n",
" Conv3d-12 [-1, 64, 8, 28, 28] 12,288\n",
" BatchNorm3d-13 [-1, 64, 8, 28, 28] 128\n",
" Unit3D-14 [-1, 64, 8, 28, 28] 0\n",
" Conv3d-15 [-1, 96, 8, 28, 28] 18,432\n",
" BatchNorm3d-16 [-1, 96, 8, 28, 28] 192\n",
" Unit3D-17 [-1, 96, 8, 28, 28] 0\n",
" Conv3d-18 [-1, 128, 8, 28, 28] 331,776\n",
" BatchNorm3d-19 [-1, 128, 8, 28, 28] 256\n",
" Unit3D-20 [-1, 128, 8, 28, 28] 0\n",
" Conv3d-21 [-1, 16, 8, 28, 28] 3,072\n",
" BatchNorm3d-22 [-1, 16, 8, 28, 28] 32\n",
" Unit3D-23 [-1, 16, 8, 28, 28] 0\n",
" Conv3d-24 [-1, 32, 8, 28, 28] 13,824\n",
" BatchNorm3d-25 [-1, 32, 8, 28, 28] 64\n",
" Unit3D-26 [-1, 32, 8, 28, 28] 0\n",
"MaxPool3dSamePadding-27 [-1, 192, 8, 28, 28] 0\n",
" Conv3d-28 [-1, 32, 8, 28, 28] 6,144\n",
" BatchNorm3d-29 [-1, 32, 8, 28, 28] 64\n",
" Unit3D-30 [-1, 32, 8, 28, 28] 0\n",
" InceptionModule-31 [-1, 256, 8, 28, 28] 0\n",
" Conv3d-32 [-1, 128, 8, 28, 28] 32,768\n",
" BatchNorm3d-33 [-1, 128, 8, 28, 28] 256\n",
" Unit3D-34 [-1, 128, 8, 28, 28] 0\n",
" Conv3d-35 [-1, 128, 8, 28, 28] 32,768\n",
" BatchNorm3d-36 [-1, 128, 8, 28, 28] 256\n",
" Unit3D-37 [-1, 128, 8, 28, 28] 0\n",
" Conv3d-38 [-1, 192, 8, 28, 28] 663,552\n",
" BatchNorm3d-39 [-1, 192, 8, 28, 28] 384\n",
" Unit3D-40 [-1, 192, 8, 28, 28] 0\n",
" Conv3d-41 [-1, 32, 8, 28, 28] 8,192\n",
" BatchNorm3d-42 [-1, 32, 8, 28, 28] 64\n",
" Unit3D-43 [-1, 32, 8, 28, 28] 0\n",
" Conv3d-44 [-1, 96, 8, 28, 28] 82,944\n",
" BatchNorm3d-45 [-1, 96, 8, 28, 28] 192\n",
" Unit3D-46 [-1, 96, 8, 28, 28] 0\n",
"MaxPool3dSamePadding-47 [-1, 256, 8, 28, 28] 0\n",
" Conv3d-48 [-1, 64, 8, 28, 28] 16,384\n",
" BatchNorm3d-49 [-1, 64, 8, 28, 28] 128\n",
" Unit3D-50 [-1, 64, 8, 28, 28] 0\n",
" InceptionModule-51 [-1, 480, 8, 28, 28] 0\n",
"MaxPool3dSamePadding-52 [-1, 480, 4, 14, 14] 0\n",
" Conv3d-53 [-1, 192, 4, 14, 14] 92,160\n",
" BatchNorm3d-54 [-1, 192, 4, 14, 14] 384\n",
" Unit3D-55 [-1, 192, 4, 14, 14] 0\n",
" Conv3d-56 [-1, 96, 4, 14, 14] 46,080\n",
" BatchNorm3d-57 [-1, 96, 4, 14, 14] 192\n",
" Unit3D-58 [-1, 96, 4, 14, 14] 0\n",
" Conv3d-59 [-1, 208, 4, 14, 14] 539,136\n",
" BatchNorm3d-60 [-1, 208, 4, 14, 14] 416\n",
" Unit3D-61 [-1, 208, 4, 14, 14] 0\n",
" Conv3d-62 [-1, 16, 4, 14, 14] 7,680\n",
" BatchNorm3d-63 [-1, 16, 4, 14, 14] 32\n",
" Unit3D-64 [-1, 16, 4, 14, 14] 0\n",
" Conv3d-65 [-1, 48, 4, 14, 14] 20,736\n",
" BatchNorm3d-66 [-1, 48, 4, 14, 14] 96\n",
" Unit3D-67 [-1, 48, 4, 14, 14] 0\n",
"MaxPool3dSamePadding-68 [-1, 480, 4, 14, 14] 0\n",
" Conv3d-69 [-1, 64, 4, 14, 14] 30,720\n",
" BatchNorm3d-70 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-71 [-1, 64, 4, 14, 14] 0\n",
" InceptionModule-72 [-1, 512, 4, 14, 14] 0\n",
" Conv3d-73 [-1, 160, 4, 14, 14] 81,920\n",
" BatchNorm3d-74 [-1, 160, 4, 14, 14] 320\n",
" Unit3D-75 [-1, 160, 4, 14, 14] 0\n",
" Conv3d-76 [-1, 112, 4, 14, 14] 57,344\n",
" BatchNorm3d-77 [-1, 112, 4, 14, 14] 224\n",
" Unit3D-78 [-1, 112, 4, 14, 14] 0\n",
" Conv3d-79 [-1, 224, 4, 14, 14] 677,376\n",
" BatchNorm3d-80 [-1, 224, 4, 14, 14] 448\n",
" Unit3D-81 [-1, 224, 4, 14, 14] 0\n",
" Conv3d-82 [-1, 24, 4, 14, 14] 12,288\n",
" BatchNorm3d-83 [-1, 24, 4, 14, 14] 48\n",
" Unit3D-84 [-1, 24, 4, 14, 14] 0\n",
" Conv3d-85 [-1, 64, 4, 14, 14] 41,472\n",
" BatchNorm3d-86 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-87 [-1, 64, 4, 14, 14] 0\n",
"MaxPool3dSamePadding-88 [-1, 512, 4, 14, 14] 0\n",
" Conv3d-89 [-1, 64, 4, 14, 14] 32,768\n",
" BatchNorm3d-90 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-91 [-1, 64, 4, 14, 14] 0\n",
" InceptionModule-92 [-1, 512, 4, 14, 14] 0\n",
" Conv3d-93 [-1, 128, 4, 14, 14] 65,536\n",
" BatchNorm3d-94 [-1, 128, 4, 14, 14] 256\n",
" Unit3D-95 [-1, 128, 4, 14, 14] 0\n",
" Conv3d-96 [-1, 128, 4, 14, 14] 65,536\n",
" BatchNorm3d-97 [-1, 128, 4, 14, 14] 256\n",
" Unit3D-98 [-1, 128, 4, 14, 14] 0\n",
" Conv3d-99 [-1, 256, 4, 14, 14] 884,736\n",
" BatchNorm3d-100 [-1, 256, 4, 14, 14] 512\n",
" Unit3D-101 [-1, 256, 4, 14, 14] 0\n",
" Conv3d-102 [-1, 24, 4, 14, 14] 12,288\n",
" BatchNorm3d-103 [-1, 24, 4, 14, 14] 48\n",
" Unit3D-104 [-1, 24, 4, 14, 14] 0\n",
" Conv3d-105 [-1, 64, 4, 14, 14] 41,472\n",
" BatchNorm3d-106 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-107 [-1, 64, 4, 14, 14] 0\n",
"MaxPool3dSamePadding-108 [-1, 512, 4, 14, 14] 0\n",
" Conv3d-109 [-1, 64, 4, 14, 14] 32,768\n",
" BatchNorm3d-110 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-111 [-1, 64, 4, 14, 14] 0\n",
" InceptionModule-112 [-1, 512, 4, 14, 14] 0\n",
" Conv3d-113 [-1, 112, 4, 14, 14] 57,344\n",
" BatchNorm3d-114 [-1, 112, 4, 14, 14] 224\n",
" Unit3D-115 [-1, 112, 4, 14, 14] 0\n",
" Conv3d-116 [-1, 144, 4, 14, 14] 73,728\n",
" BatchNorm3d-117 [-1, 144, 4, 14, 14] 288\n",
" Unit3D-118 [-1, 144, 4, 14, 14] 0\n",
" Conv3d-119 [-1, 288, 4, 14, 14] 1,119,744\n",
" BatchNorm3d-120 [-1, 288, 4, 14, 14] 576\n",
" Unit3D-121 [-1, 288, 4, 14, 14] 0\n",
" Conv3d-122 [-1, 32, 4, 14, 14] 16,384\n",
" BatchNorm3d-123 [-1, 32, 4, 14, 14] 64\n",
" Unit3D-124 [-1, 32, 4, 14, 14] 0\n",
" Conv3d-125 [-1, 64, 4, 14, 14] 55,296\n",
" BatchNorm3d-126 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-127 [-1, 64, 4, 14, 14] 0\n",
"MaxPool3dSamePadding-128 [-1, 512, 4, 14, 14] 0\n",
" Conv3d-129 [-1, 64, 4, 14, 14] 32,768\n",
" BatchNorm3d-130 [-1, 64, 4, 14, 14] 128\n",
" Unit3D-131 [-1, 64, 4, 14, 14] 0\n",
" InceptionModule-132 [-1, 528, 4, 14, 14] 0\n",
" Conv3d-133 [-1, 256, 4, 14, 14] 135,168\n",
" BatchNorm3d-134 [-1, 256, 4, 14, 14] 512\n",
" Unit3D-135 [-1, 256, 4, 14, 14] 0\n",
" Conv3d-136 [-1, 160, 4, 14, 14] 84,480\n",
" BatchNorm3d-137 [-1, 160, 4, 14, 14] 320\n",
" Unit3D-138 [-1, 160, 4, 14, 14] 0\n",
" Conv3d-139 [-1, 320, 4, 14, 14] 1,382,400\n",
" BatchNorm3d-140 [-1, 320, 4, 14, 14] 640\n",
" Unit3D-141 [-1, 320, 4, 14, 14] 0\n",
" Conv3d-142 [-1, 32, 4, 14, 14] 16,896\n",
" BatchNorm3d-143 [-1, 32, 4, 14, 14] 64\n",
" Unit3D-144 [-1, 32, 4, 14, 14] 0\n",
" Conv3d-145 [-1, 128, 4, 14, 14] 110,592\n",
" BatchNorm3d-146 [-1, 128, 4, 14, 14] 256\n",
" Unit3D-147 [-1, 128, 4, 14, 14] 0\n",
"MaxPool3dSamePadding-148 [-1, 528, 4, 14, 14] 0\n",
" Conv3d-149 [-1, 128, 4, 14, 14] 67,584\n",
" BatchNorm3d-150 [-1, 128, 4, 14, 14] 256\n",
" Unit3D-151 [-1, 128, 4, 14, 14] 0\n",
" InceptionModule-152 [-1, 832, 4, 14, 14] 0\n",
"MaxPool3dSamePadding-153 [-1, 832, 2, 7, 7] 0\n",
" Conv3d-154 [-1, 256, 2, 7, 7] 212,992\n",
" BatchNorm3d-155 [-1, 256, 2, 7, 7] 512\n",
" Unit3D-156 [-1, 256, 2, 7, 7] 0\n",
" Conv3d-157 [-1, 160, 2, 7, 7] 133,120\n",
" BatchNorm3d-158 [-1, 160, 2, 7, 7] 320\n",
" Unit3D-159 [-1, 160, 2, 7, 7] 0\n",
" Conv3d-160 [-1, 320, 2, 7, 7] 1,382,400\n",
" BatchNorm3d-161 [-1, 320, 2, 7, 7] 640\n",
" Unit3D-162 [-1, 320, 2, 7, 7] 0\n",
" Conv3d-163 [-1, 32, 2, 7, 7] 26,624\n",
" BatchNorm3d-164 [-1, 32, 2, 7, 7] 64\n",
" Unit3D-165 [-1, 32, 2, 7, 7] 0\n",
" Conv3d-166 [-1, 128, 2, 7, 7] 110,592\n",
" BatchNorm3d-167 [-1, 128, 2, 7, 7] 256\n",
" Unit3D-168 [-1, 128, 2, 7, 7] 0\n",
"MaxPool3dSamePadding-169 [-1, 832, 2, 7, 7] 0\n",
" Conv3d-170 [-1, 128, 2, 7, 7] 106,496\n",
" BatchNorm3d-171 [-1, 128, 2, 7, 7] 256\n",
" Unit3D-172 [-1, 128, 2, 7, 7] 0\n",
" InceptionModule-173 [-1, 832, 2, 7, 7] 0\n",
" Conv3d-174 [-1, 384, 2, 7, 7] 319,488\n",
" BatchNorm3d-175 [-1, 384, 2, 7, 7] 768\n",
" Unit3D-176 [-1, 384, 2, 7, 7] 0\n",
" Conv3d-177 [-1, 192, 2, 7, 7] 159,744\n",
" BatchNorm3d-178 [-1, 192, 2, 7, 7] 384\n",
" Unit3D-179 [-1, 192, 2, 7, 7] 0\n",
" Conv3d-180 [-1, 384, 2, 7, 7] 1,990,656\n",
" BatchNorm3d-181 [-1, 384, 2, 7, 7] 768\n",
" Unit3D-182 [-1, 384, 2, 7, 7] 0\n",
" Conv3d-183 [-1, 48, 2, 7, 7] 39,936\n",
" BatchNorm3d-184 [-1, 48, 2, 7, 7] 96\n",
" Unit3D-185 [-1, 48, 2, 7, 7] 0\n",
" Conv3d-186 [-1, 128, 2, 7, 7] 165,888\n",
" BatchNorm3d-187 [-1, 128, 2, 7, 7] 256\n",
" Unit3D-188 [-1, 128, 2, 7, 7] 0\n",
"MaxPool3dSamePadding-189 [-1, 832, 2, 7, 7] 0\n",
" Conv3d-190 [-1, 128, 2, 7, 7] 106,496\n",
" BatchNorm3d-191 [-1, 128, 2, 7, 7] 256\n",
" Unit3D-192 [-1, 128, 2, 7, 7] 0\n",
" InceptionModule-193 [-1, 1024, 2, 7, 7] 0\n",
" AvgPool3d-194 [-1, 1024, 1, 1, 1] 0\n",
" Dropout-195 [-1, 1024, 1, 1, 1] 0\n",
" Conv3d-196 [-1, 101, 1, 1, 1] 103,525\n",
" Unit3D-197 [-1, 101, 1, 1, 1] 0\n",
"================================================================\n",
"Total params: 12,390,789\n",
"Trainable params: 12,390,789\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 9.19\n",
"Forward/backward pass size (MB): 626.36\n",
"Params size (MB): 47.27\n",
"Estimated Total Size (MB): 682.81\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WxyLOOgg2xtv",
"colab_type": "code",
"colab": {}
},
"source": [
"class RandomDataset(object):\n",
"    \"\"\"Synthetic video-classification dataset filled with uniform noise.\n",
"\n",
"    Args:\n",
"        num_classes: labels are drawn uniformly from [0, num_classes).\n",
"        num_data: number of clips to generate.\n",
"        transforms: optional callable applied to each clip on access;\n",
"            when None the raw numpy clip is returned unchanged.\n",
"\n",
"    Each clip is a float64 array of shape (16, 256, 256, 3), i.e. (T, H, W, C).\n",
"    All clips are generated eagerly in __init__, so memory grows linearly with\n",
"    num_data (roughly 25 MB per clip).\n",
"    \"\"\"\n",
"    def __init__(self, num_classes, num_data, transforms=None):\n",
"        self.num_data = num_data\n",
"        self.X = np.random.rand(num_data, 16, 256, 256, 3)\n",
"        self.y = np.random.randint(0, num_classes, (num_data,))\n",
"        self.transforms = transforms\n",
"\n",
"    def __getitem__(self, idx):\n",
"        X = self.X[idx]\n",
"        y = self.y[idx]\n",
"        # transforms defaults to None; calling it unconditionally raised TypeError.\n",
"        if self.transforms is not None:\n",
"            X = self.transforms(X)\n",
"        return X, y\n",
"\n",
"    def __len__(self):\n",
"        return self.num_data"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vbZ7yZkl8jaK",
"colab_type": "code",
"colab": {}
},
"source": [
"# ---- Hyperparameters ----\n",
"\n",
"# data / training loop\n",
"num_classes = 101  # size of the label space\n",
"num_epochs = 3\n",
"batch_size = 8\n",
"\n",
"# SGD learning rate\n",
"lr = 1e-2\n",
"\n",
"# StepLR: lr is multiplied by `gamma` every `step_size` scheduler steps.\n",
"# NOTE: step_size > num_epochs, so with one step per epoch the lr never decays in this run.\n",
"step_size = 5\n",
"gamma = 0.1\n",
"\n",
"# catalyst runner output directory (checkpoints are written under it)\n",
"logdir = \"./logs\""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "s4iuQ2WxBS0Z",
"colab_type": "code",
"colab": {}
},
"source": [
"def apply_tfms_video(video, tfms_albu):\n",
"    \"\"\"\n",
"    Apply the same Albumentations pipeline, with identical random\n",
"    parameters, to every frame of a video.\n",
"\n",
"    Args:\n",
"        video: array-like of frames, (T, H, W, C)\n",
"        tfms_albu: albumentations transform whose output 'image' is a\n",
"            (C, H, W) torch tensor (e.g. a Compose ending in ToTensor).\n",
"            Assumes its randomness comes from Python's `random` module.\n",
"\n",
"    Returns:\n",
"        tensor: pytorch tensor (C, T, H, W)\n",
"\n",
"    Raises:\n",
"        ValueError: if the video has no frames.\n",
"    \"\"\"\n",
"    if len(video) == 0:\n",
"        raise ValueError(\"apply_tfms_video: video has no frames\")\n",
"    # Snapshot the global `random` state once and restore it before each frame,\n",
"    # so every frame draws the same crop/flip parameters. Unlike reseeding with\n",
"    # random.seed(random.randint(0, 99999)), this leaves the global RNG stream\n",
"    # intact instead of collapsing it onto a small set of seeds between calls.\n",
"    rng_state = random.getstate()\n",
"    aug_frames = []\n",
"    for frame in video:\n",
"        random.setstate(rng_state)\n",
"        aug_frames.append(tfms_albu(image=np.asarray(frame))['image'])\n",
"    # stack -> (T, C, H, W); permute -> (C, T, H, W)\n",
"    return torch.stack(aug_frames).permute(1, 0, 2, 3)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JKjAIpHe8Ec1",
"colab_type": "code",
"colab": {}
},
"source": [
"# Per-frame augmentation pipelines; apply_tfms_video applies them identically to every frame of a clip.\n",
"train_transforms = albu.Compose([\n",
"    albu.RandomCrop(224, 224),\n",
"    albu.HorizontalFlip(p=0.5),\n",
"    albu.Normalize(),\n",
"    ToTensor(),\n",
"])\n",
"test_transforms = albu.Compose([\n",
"    albu.CenterCrop(224, 224),\n",
"    albu.Normalize(),\n",
"    ToTensor(),\n",
"])\n",
"\n",
"train_data = RandomDataset(num_classes, 400, transforms=lambda clip: apply_tfms_video(clip, train_transforms))\n",
"val_data = RandomDataset(num_classes, 50, transforms=lambda clip: apply_tfms_video(clip, test_transforms))\n",
"test_data = RandomDataset(num_classes, 50, transforms=lambda clip: apply_tfms_video(clip, test_transforms))\n",
"\n",
"def _make_loader(dataset, shuffle):\n",
"    # num_workers=0: batches are loaded in the main process.\n",
"    return torch.utils.data.DataLoader(\n",
"        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=True\n",
"    )\n",
"\n",
"train_loader = _make_loader(train_data, shuffle=True)\n",
"val_loader = _make_loader(val_data, shuffle=False)\n",
"test_loader = _make_loader(test_data, shuffle=False)\n",
"\n",
"# Only train/valid go to the runner; test_loader is kept separately for inference.\n",
"loaders = {\"train\": train_loader, \"valid\": val_loader}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3r0Ttsvk6iDm",
"colab_type": "code",
"colab": {}
},
"source": [
"# Nesterov-momentum SGD, decayed by StepLR every `step_size` scheduler steps.\n",
"optimizer = torch.optim.SGD(\n",
"    model.parameters(),\n",
"    lr=lr,\n",
"    momentum=0.9,\n",
"    nesterov=True,\n",
")\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
"criterion = nn.CrossEntropyLoss()  # takes raw class scores and integer labels\n",
"runner = SupervisedRunner()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-0MBdV_U8B6l",
"colab_type": "code",
"outputId": "033f0db1-74b7-4d45-a425-37c8d508aa3f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 326
}
},
"source": [
"# Put the model in training mode, then let the catalyst runner drive the epochs.\n",
"model.train()\n",
"train_kwargs = dict(\n",
"    model=model,\n",
"    criterion=criterion,\n",
"    optimizer=optimizer,\n",
"    scheduler=scheduler,\n",
"    loaders=loaders,\n",
"    callbacks=[AccuracyCallback(num_classes=num_classes)],\n",
"    logdir=logdir,\n",
"    num_epochs=num_epochs,\n",
"    verbose=True,\n",
")\n",
"runner.train(**train_kwargs)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"1/3 * Epoch (train): 100% 50/50 [00:26<00:00, 1.90it/s, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=12.500, loss=4.472]\n",
"1/3 * Epoch (valid): 100% 7/7 [00:01<00:00, 3.97it/s, accuracy01=0.000e+00, accuracy03=50.000, accuracy05=50.000, loss=4.565]\n",
"[2020-03-01 13:37:30,602] \n",
"1/3 * Epoch 1 (train): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=37.1111 | _timers/batch_time=0.2168 | _timers/data_time=0.1989 | _timers/model_time=0.0179 | accuracy01=0.7500 | accuracy03=3.5000 | accuracy05=6.0000 | loss=4.8286\n",
"1/3 * Epoch 1 (valid): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=58.6281 | _timers/batch_time=0.1636 | _timers/data_time=0.1517 | _timers/model_time=0.0118 | accuracy01=1.7857 | accuracy03=8.9286 | accuracy05=8.9286 | loss=4.6149\n",
"2/3 * Epoch (train): 100% 50/50 [00:26<00:00, 1.92it/s, accuracy01=12.500, accuracy03=12.500, accuracy05=12.500, loss=4.839]\n",
"2/3 * Epoch (valid): 100% 7/7 [00:01<00:00, 3.94it/s, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.655]\n",
"[2020-03-01 13:37:58,869] \n",
"2/3 * Epoch 2 (train): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=37.7737 | _timers/batch_time=0.2126 | _timers/data_time=0.1951 | _timers/model_time=0.0174 | accuracy01=1.2500 | accuracy03=3.7500 | accuracy05=7.7500 | loss=4.7040\n",
"2/3 * Epoch 2 (valid): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=58.1625 | _timers/batch_time=0.1661 | _timers/data_time=0.1537 | _timers/model_time=0.0123 | accuracy01=0.000e+00 | accuracy03=5.3571 | accuracy05=7.1429 | loss=4.6167\n",
"3/3 * Epoch (train): 100% 50/50 [00:26<00:00, 1.89it/s, accuracy01=0.000e+00, accuracy03=0.000e+00, accuracy05=0.000e+00, loss=4.933]\n",
"3/3 * Epoch (valid): 100% 7/7 [00:01<00:00, 3.97it/s, accuracy01=0.000e+00, accuracy03=50.000, accuracy05=50.000, loss=4.542]\n",
"[2020-03-01 13:38:27,844] \n",
"3/3 * Epoch 3 (train): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=36.2762 | _timers/batch_time=0.2211 | _timers/data_time=0.2034 | _timers/model_time=0.0176 | accuracy01=0.7500 | accuracy03=3.5000 | accuracy05=6.2500 | loss=4.6603\n",
"3/3 * Epoch 3 (valid): _base/lr=0.0100 | _base/momentum=0.9000 | _timers/_fps=59.2029 | _timers/batch_time=0.1650 | _timers/data_time=0.1521 | _timers/model_time=0.0128 | accuracy01=0.000e+00 | accuracy03=10.7143 | accuracy05=12.5000 | loss=4.5996\n",
"Top best models:\n",
"logs/checkpoints/train.3.pth\t4.5996\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "uIdD8ERw9sNb",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
},
"outputId": "860782df-8200-4f16-d209-246163ea1e2a"
},
"source": [
"# Inference on the held-out loader using the best checkpoint saved during training.\n",
"best_checkpoint = f\"{logdir}/checkpoints/best.pth\"\n",
"model.eval()\n",
"predictions = runner.predict_loader(\n",
"    model=model, loader=test_loader, resume=best_checkpoint, verbose=True\n",
")"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"=> loading checkpoint ./logs/checkpoints/best.pth\n",
"loaded checkpoint ./logs/checkpoints/best.pth (epoch 3, stage_epoch 3, stage train)\n",
"1/1 * Epoch (infer): 100% 7/7 [00:01<00:00, 4.06it/s]\n",
"Top best models:\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pjZy6Am4K1Zn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "58400b41-a765-49e4-c555-af11a13dd9fe"
},
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"# predictions presumably holds one row of class scores per test clip; test_loader is\n",
"# not shuffled, so rows line up with test_data.y. Labels are random, so expect ~chance.\n",
"predicted_labels = predictions.argmax(axis=1)\n",
"accuracy_score(test_data.y, predicted_labels)"
],
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xUDfaAg-NBhi",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment