-
-
Save Strady123/949f3d3ac8257c1e8e168a08460d1b54 to your computer and use it in GitHub Desktop.
Emotion Recognition.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Emotion Recognition.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Strady123/949f3d3ac8257c1e8e168a08460d1b54/emotion-recognition.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0cxljsVPVJ7S" | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch import optim\n", | |
"import torch.nn.functional as F\n", | |
"from torchvision import datasets, transforms, models" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"resources": { | |
"http://localhost:8080/nbextensions/google.colab/files.js": { | |
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", | |
"ok": true, | |
"headers": [ | |
[ | |
"content-type", | |
"application/javascript" | |
] | |
], | |
"status": 200, | |
"status_text": "" | |
} | |
}, | |
"base_uri": "https://localhost:8080/", | |
"height": 89 | |
}, | |
"id": "2xb2XzyAVJ7c", | |
"outputId": "83b48d48-9a2e-4ed0-aec2-ae47b2b946c4" | |
}, | |
"source": [ | |
"from google.colab import files\r\n", | |
"\r\n", | |
"uploaded = files.upload()\r\n", | |
"\r\n", | |
"for fn in uploaded.keys():\r\n", | |
" print('User uploaded file \"{name}\" with length {length} bytes'.format(name=fn, length=len(uploaded[fn])))" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <input type=\"file\" id=\"files-5cffcc47-a517-4d4d-b07e-24dca62500e1\" name=\"files[]\" multiple disabled\n", | |
" style=\"border:none\" />\n", | |
" <output id=\"result-5cffcc47-a517-4d4d-b07e-24dca62500e1\">\n", | |
" Upload widget is only available when the cell has been executed in the\n", | |
" current browser session. Please rerun this cell to enable.\n", | |
" </output>\n", | |
" <script src=\"/nbextensions/google.colab/files.js\"></script> " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Saving Emotions.zip to Emotions (1).zip\n", | |
"User uploaded file \"Emotions.zip\" with length 1449779 bytes\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6_EC8W23oFgV", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "69f9f080-c931-4c09-bbe6-7245042c2448" | |
}, | |
"source": [ | |
"from zipfile import ZipFile\n", | |
"file_name = \"Emotions.zip\"\n", | |
"\n", | |
"with ZipFile(file_name, 'r') as zip:\n", | |
" zip.extractall()\n", | |
" print(\"Done\")" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Done\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "r9SPc-kQVJ7h" | |
}, | |
"source": [ | |
"test_dir = '/content/Emotions/CK+48/test'\n", | |
"train_dir = '/content/Emotions/CK+48/train'\n", | |
"\n", | |
"train_transforms = transforms.Compose([transforms.RandomRotation(30),\n", | |
" transforms.RandomResizedCrop(100),\n", | |
" transforms.RandomHorizontalFlip(),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize([0.5,0.5,0.5],\n", | |
" [0.5,0.5,0.5])])\n", | |
"test_transforms = transforms.Compose([transforms.Resize(255),\n", | |
" transforms.CenterCrop(224),\n", | |
" transforms.ToTensor()])\n", | |
"train_data = datasets.ImageFolder(train_dir, transform=train_transforms)\n", | |
"trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)\n", | |
"\n", | |
"test_data = datasets.ImageFolder(test_dir, transform=test_transforms)\n", | |
"testloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "V3G-tEcGVJ7j", | |
"outputId": "73fb8343-f625-4fa9-bc60-675417195038" | |
}, | |
"source": [ | |
"model = models.alexnet(pretrained=True)\n", | |
"model\n", | |
"for param in model.parameters():\n", | |
" param.requires_grad = False\n", | |
" \n", | |
"classifier = nn.Sequential(nn.Linear(9216, 1000),\n", | |
" nn.ReLU(),\n", | |
" nn.Dropout(0.2),\n", | |
" nn.Linear(1000, 5),\n", | |
" nn.LogSoftmax(dim=1))\n", | |
"\n", | |
"model.classifier = classifier\n", | |
"\n", | |
"criterion = nn.NLLLoss()\n", | |
"\n", | |
"optimizer = optim.SGD(model.classifier.parameters(), lr=0.003)\n", | |
"model" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"AlexNet(\n", | |
" (features): Sequential(\n", | |
" (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", | |
" (1): ReLU(inplace=True)\n", | |
" (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
" (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", | |
" (4): ReLU(inplace=True)\n", | |
" (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
" (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
" (7): ReLU(inplace=True)\n", | |
" (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
" (9): ReLU(inplace=True)\n", | |
" (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
" (11): ReLU(inplace=True)\n", | |
" (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
" )\n", | |
" (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n", | |
" (classifier): Sequential(\n", | |
" (0): Linear(in_features=9216, out_features=1000, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Dropout(p=0.2, inplace=False)\n", | |
" (3): Linear(in_features=1000, out_features=5, bias=True)\n", | |
" (4): LogSoftmax(dim=1)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 25 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZfvJnxANVJ7m", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "a9a6fd29-6ea5-4d47-ba5b-432fa4a717e9" | |
}, | |
"source": [ | |
"epochs = 8\n", | |
"steps = 0\n", | |
"running_loss = 0\n", | |
"print_every = 5\n", | |
"class_number = 5\n", | |
"for epoch in range(epochs):\n", | |
" for inputs, labels in trainloader:\n", | |
" steps += 1\n", | |
"\n", | |
" \n", | |
" optimizer.zero_grad()\n", | |
" \n", | |
" logps = model.forward(inputs)\n", | |
" loss = criterion(logps, labels)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" running_loss += loss.item()\n", | |
" \n", | |
" if steps % print_every == 0:\n", | |
" test_loss = 0\n", | |
" accuracy = 0\n", | |
" model.eval()\n", | |
" with torch.no_grad():\n", | |
" for inputs, labels in testloader:\n", | |
" \n", | |
" logps = model.forward(inputs)\n", | |
" batch_loss = criterion(logps, labels)\n", | |
" \n", | |
" test_loss += batch_loss.item()\n", | |
" \n", | |
" # Calculate accuracy\n", | |
" ps = torch.exp(logps)\n", | |
" top_p, top_class = ps.topk(1, dim=1)\n", | |
" equals = top_class == labels.view(*top_class.shape)\n", | |
" accuracy += torch.mean(equals.type(torch.FloatTensor)).item()\n", | |
" \n", | |
" print(f\"Epoch {epoch+1}/{epochs}.. \"\n", | |
" f\"Train loss: {running_loss/print_every:.3f}.. \"\n", | |
" f\"Test loss: {test_loss/len(testloader):.3f}.. \"\n", | |
" f\"Test accuracy: {accuracy/len(testloader):.3f}\")\n", | |
" running_loss = 0\n", | |
" model.train()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/8.. Train loss: 1.543.. Test loss: 1.578.. Test accuracy: 0.244\n", | |
"Epoch 1/8.. Train loss: 1.457.. Test loss: 1.573.. Test accuracy: 0.244\n", | |
"Epoch 1/8.. Train loss: 1.422.. Test loss: 1.523.. Test accuracy: 0.244\n", | |
"Epoch 2/8.. Train loss: 1.334.. Test loss: 1.515.. Test accuracy: 0.300\n", | |
"Epoch 2/8.. Train loss: 1.287.. Test loss: 1.541.. Test accuracy: 0.362\n", | |
"Epoch 2/8.. Train loss: 1.242.. Test loss: 1.470.. Test accuracy: 0.350\n", | |
"Epoch 2/8.. Train loss: 1.350.. Test loss: 1.475.. Test accuracy: 0.281\n", | |
"Epoch 3/8.. Train loss: 1.182.. Test loss: 1.436.. Test accuracy: 0.494\n", | |
"Epoch 3/8.. Train loss: 1.213.. Test loss: 1.458.. Test accuracy: 0.500\n", | |
"Epoch 3/8.. Train loss: 1.211.. Test loss: 1.507.. Test accuracy: 0.287\n", | |
"Epoch 3/8.. Train loss: 1.247.. Test loss: 1.484.. Test accuracy: 0.269\n", | |
"Epoch 4/8.. Train loss: 1.208.. Test loss: 1.409.. Test accuracy: 0.588\n", | |
"Epoch 4/8.. Train loss: 1.140.. Test loss: 1.401.. Test accuracy: 0.619\n", | |
"Epoch 4/8.. Train loss: 1.242.. Test loss: 1.421.. Test accuracy: 0.481\n", | |
"Epoch 4/8.. Train loss: 1.122.. Test loss: 1.404.. Test accuracy: 0.656\n", | |
"Epoch 5/8.. Train loss: 1.099.. Test loss: 1.431.. Test accuracy: 0.394\n", | |
"Epoch 5/8.. Train loss: 1.136.. Test loss: 1.377.. Test accuracy: 0.575\n", | |
"Epoch 5/8.. Train loss: 1.066.. Test loss: 1.399.. Test accuracy: 0.569\n", | |
"Epoch 5/8.. Train loss: 1.186.. Test loss: 1.351.. Test accuracy: 0.525\n", | |
"Epoch 6/8.. Train loss: 1.132.. Test loss: 1.362.. Test accuracy: 0.631\n", | |
"Epoch 6/8.. Train loss: 1.064.. Test loss: 1.361.. Test accuracy: 0.481\n", | |
"Epoch 6/8.. Train loss: 1.169.. Test loss: 1.359.. Test accuracy: 0.656\n", | |
"Epoch 7/8.. Train loss: 0.951.. Test loss: 1.360.. Test accuracy: 0.463\n", | |
"Epoch 7/8.. Train loss: 1.101.. Test loss: 1.367.. Test accuracy: 0.600\n", | |
"Epoch 7/8.. Train loss: 1.064.. Test loss: 1.352.. Test accuracy: 0.550\n", | |
"Epoch 7/8.. Train loss: 1.061.. Test loss: 1.353.. Test accuracy: 0.650\n", | |
"Epoch 8/8.. Train loss: 1.064.. Test loss: 1.360.. Test accuracy: 0.606\n", | |
"Epoch 8/8.. Train loss: 1.153.. Test loss: 1.343.. Test accuracy: 0.613\n", | |
"Epoch 8/8.. Train loss: 0.973.. Test loss: 1.355.. Test accuracy: 0.606\n", | |
"Epoch 8/8.. Train loss: 1.012.. Test loss: 1.311.. Test accuracy: 0.606\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZGEqzQSpVJ7n" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5Vc54xUrVJ7o" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment