Skip to content

Instantly share code, notes, and snippets.

@Strady123
Last active February 18, 2021 21:04
Show Gist options
  • Save Strady123/949f3d3ac8257c1e8e168a08460d1b54 to your computer and use it in GitHub Desktop.
Save Strady123/949f3d3ac8257c1e8e168a08460d1b54 to your computer and use it in GitHub Desktop.
Emotion Recognition.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Emotion Recognition.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Strady123/949f3d3ac8257c1e8e168a08460d1b54/emotion-recognition.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0cxljsVPVJ7S"
},
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"from torchvision import datasets, transforms, models"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSB
QeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWl
zZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5
ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK",
"ok": true,
"headers": [
[
"content-type",
"application/javascript"
]
],
"status": 200,
"status_text": ""
}
},
"base_uri": "https://localhost:8080/",
"height": 89
},
"id": "2xb2XzyAVJ7c",
"outputId": "83b48d48-9a2e-4ed0-aec2-ae47b2b946c4"
},
"source": [
"from google.colab import files\r\n",
"\r\n",
"# Prompt for the dataset archive via Colab's upload widget.\r\n",
"# The returned dict maps each uploaded file name to its raw contents.\r\n",
"uploaded = files.upload()\r\n",
"\r\n",
"for fn, contents in uploaded.items():\r\n",
"    print('User uploaded file \"{name}\" with length {length} bytes'.format(name=fn, length=len(contents)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-5cffcc47-a517-4d4d-b07e-24dca62500e1\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-5cffcc47-a517-4d4d-b07e-24dca62500e1\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving Emotions.zip to Emotions (1).zip\n",
"User uploaded file \"Emotions.zip\" with length 1449779 bytes\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6_EC8W23oFgV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "69f9f080-c931-4c09-bbe6-7245042c2448"
},
"source": [
"import io\n",
"from zipfile import ZipFile\n",
"file_name = \"Emotions.zip\"\n",
"\n",
"# If a file with the same name already exists, Colab saves a new upload under another\n",
"# name (e.g. \"Emotions (1).zip\"), so the on-disk file called file_name can be a stale\n",
"# copy. Prefer the bytes just returned by files.upload(); fall back to the file on disk\n",
"# when the upload cell was not run in this kernel session.\n",
"uploaded_files = globals().get('uploaded', {})\n",
"archive_source = io.BytesIO(uploaded_files[file_name]) if file_name in uploaded_files else file_name\n",
"\n",
"# Named `archive` rather than `zip` so the built-in zip() is not shadowed for later cells.\n",
"with ZipFile(archive_source, 'r') as archive:\n",
"    archive.extractall()\n",
"    print(\"Done\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Done\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "r9SPc-kQVJ7h"
},
"source": [
"test_dir = '/content/Emotions/CK+48/test'\n",
"train_dir = '/content/Emotions/CK+48/train'\n",
"\n",
"# Train and test images must reach the network at the same size and with the same\n",
"# normalization; otherwise evaluation sees a different input distribution than training.\n",
"IMAGE_SIZE = 100\n",
"# The pretrained AlexNet weights expect inputs normalized with ImageNet statistics.\n",
"normalize = transforms.Normalize([0.485, 0.456, 0.406],\n",
"                                 [0.229, 0.224, 0.225])\n",
"\n",
"# Random augmentation is applied to the training split only.\n",
"train_transforms = transforms.Compose([transforms.RandomRotation(30),\n",
"                                       transforms.RandomResizedCrop(IMAGE_SIZE),\n",
"                                       transforms.RandomHorizontalFlip(),\n",
"                                       transforms.ToTensor(),\n",
"                                       normalize])\n",
"# Deterministic evaluation preprocessing at the training resolution.\n",
"test_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),\n",
"                                      transforms.CenterCrop(IMAGE_SIZE),\n",
"                                      transforms.ToTensor(),\n",
"                                      normalize])\n",
"train_data = datasets.ImageFolder(train_dir, transform=train_transforms)\n",
"trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)\n",
"\n",
"test_data = datasets.ImageFolder(test_dir, transform=test_transforms)\n",
"# Evaluation does not need shuffling; a fixed order keeps test metrics reproducible.\n",
"testloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "V3G-tEcGVJ7j",
"outputId": "73fb8343-f625-4fa9-bc60-675417195038"
},
"source": [
"model = models.alexnet(pretrained=True)\n",
"\n",
"# Freeze the pretrained feature extractor; only the new classifier head is optimized.\n",
"for param in model.parameters():\n",
"    param.requires_grad = False\n",
"\n",
"# One output per class folder found by ImageFolder, so the head always matches the labels.\n",
"num_classes = len(train_data.classes)\n",
"\n",
"# AlexNet's features end with 256 channels, pooled to 6x6 -> 256 * 6 * 6 = 9216 inputs.\n",
"classifier = nn.Sequential(nn.Linear(256 * 6 * 6, 1000),\n",
"                           nn.ReLU(),\n",
"                           nn.Dropout(0.2),\n",
"                           nn.Linear(1000, num_classes),\n",
"                           nn.LogSoftmax(dim=1))\n",
"\n",
"# The head is created after the freeze loop, so its parameters remain trainable.\n",
"model.classifier = classifier\n",
"\n",
"# NLLLoss expects log-probabilities, which the LogSoftmax layer produces.\n",
"criterion = nn.NLLLoss()\n",
"\n",
"optimizer = optim.SGD(model.classifier.parameters(), lr=0.003)\n",
"model"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"AlexNet(\n",
" (features): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
" (1): ReLU(inplace=True)\n",
" (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
" (4): ReLU(inplace=True)\n",
" (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (7): ReLU(inplace=True)\n",
" (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (9): ReLU(inplace=True)\n",
" (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (11): ReLU(inplace=True)\n",
" (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
" (classifier): Sequential(\n",
" (0): Linear(in_features=9216, out_features=1000, bias=True)\n",
" (1): ReLU()\n",
" (2): Dropout(p=0.2, inplace=False)\n",
" (3): Linear(in_features=1000, out_features=5, bias=True)\n",
" (4): LogSoftmax(dim=1)\n",
" )\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZfvJnxANVJ7m",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a9a6fd29-6ea5-4d47-ba5b-432fa4a717e9"
},
"source": [
"epochs = 8\n",
"steps = 0\n",
"running_loss = 0\n",
"print_every = 5\n",
"\n",
"# Use the GPU when one is available. Module.to() moves parameters in place, so the\n",
"# optimizer created in the model cell still refers to the same parameter objects.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model.to(device)\n",
"\n",
"for epoch in range(epochs):\n",
"    for inputs, labels in trainloader:\n",
"        steps += 1\n",
"        inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Call the module (not .forward) so registered hooks run.\n",
"        logps = model(inputs)\n",
"        loss = criterion(logps, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        running_loss += loss.item()\n",
"\n",
"        if steps % print_every == 0:\n",
"            # Evaluate on the full test set. Loss and accuracy are accumulated per sample\n",
"            # rather than averaged per batch, so a smaller final batch is not over-weighted.\n",
"            test_loss = 0\n",
"            correct = 0\n",
"            total = 0\n",
"            model.eval()  # inference behavior for dropout during evaluation\n",
"            with torch.no_grad():\n",
"                for test_inputs, test_labels in testloader:\n",
"                    test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)\n",
"\n",
"                    test_logps = model(test_inputs)\n",
"                    # Assumes criterion uses the default 'mean' reduction; scale back to a sum.\n",
"                    test_loss += criterion(test_logps, test_labels).item() * test_labels.size(0)\n",
"\n",
"                    # Predicted class = index of the highest log-probability.\n",
"                    predictions = test_logps.argmax(dim=1)\n",
"                    correct += (predictions == test_labels).sum().item()\n",
"                    total += test_labels.size(0)\n",
"\n",
"            print(f\"Epoch {epoch+1}/{epochs}.. \"\n",
"                  f\"Train loss: {running_loss/print_every:.3f}.. \"\n",
"                  f\"Test loss: {test_loss/total:.3f}.. \"\n",
"                  f\"Test accuracy: {correct/total:.3f}\")\n",
"            running_loss = 0\n",
"            model.train()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/8.. Train loss: 1.543.. Test loss: 1.578.. Test accuracy: 0.244\n",
"Epoch 1/8.. Train loss: 1.457.. Test loss: 1.573.. Test accuracy: 0.244\n",
"Epoch 1/8.. Train loss: 1.422.. Test loss: 1.523.. Test accuracy: 0.244\n",
"Epoch 2/8.. Train loss: 1.334.. Test loss: 1.515.. Test accuracy: 0.300\n",
"Epoch 2/8.. Train loss: 1.287.. Test loss: 1.541.. Test accuracy: 0.362\n",
"Epoch 2/8.. Train loss: 1.242.. Test loss: 1.470.. Test accuracy: 0.350\n",
"Epoch 2/8.. Train loss: 1.350.. Test loss: 1.475.. Test accuracy: 0.281\n",
"Epoch 3/8.. Train loss: 1.182.. Test loss: 1.436.. Test accuracy: 0.494\n",
"Epoch 3/8.. Train loss: 1.213.. Test loss: 1.458.. Test accuracy: 0.500\n",
"Epoch 3/8.. Train loss: 1.211.. Test loss: 1.507.. Test accuracy: 0.287\n",
"Epoch 3/8.. Train loss: 1.247.. Test loss: 1.484.. Test accuracy: 0.269\n",
"Epoch 4/8.. Train loss: 1.208.. Test loss: 1.409.. Test accuracy: 0.588\n",
"Epoch 4/8.. Train loss: 1.140.. Test loss: 1.401.. Test accuracy: 0.619\n",
"Epoch 4/8.. Train loss: 1.242.. Test loss: 1.421.. Test accuracy: 0.481\n",
"Epoch 4/8.. Train loss: 1.122.. Test loss: 1.404.. Test accuracy: 0.656\n",
"Epoch 5/8.. Train loss: 1.099.. Test loss: 1.431.. Test accuracy: 0.394\n",
"Epoch 5/8.. Train loss: 1.136.. Test loss: 1.377.. Test accuracy: 0.575\n",
"Epoch 5/8.. Train loss: 1.066.. Test loss: 1.399.. Test accuracy: 0.569\n",
"Epoch 5/8.. Train loss: 1.186.. Test loss: 1.351.. Test accuracy: 0.525\n",
"Epoch 6/8.. Train loss: 1.132.. Test loss: 1.362.. Test accuracy: 0.631\n",
"Epoch 6/8.. Train loss: 1.064.. Test loss: 1.361.. Test accuracy: 0.481\n",
"Epoch 6/8.. Train loss: 1.169.. Test loss: 1.359.. Test accuracy: 0.656\n",
"Epoch 7/8.. Train loss: 0.951.. Test loss: 1.360.. Test accuracy: 0.463\n",
"Epoch 7/8.. Train loss: 1.101.. Test loss: 1.367.. Test accuracy: 0.600\n",
"Epoch 7/8.. Train loss: 1.064.. Test loss: 1.352.. Test accuracy: 0.550\n",
"Epoch 7/8.. Train loss: 1.061.. Test loss: 1.353.. Test accuracy: 0.650\n",
"Epoch 8/8.. Train loss: 1.064.. Test loss: 1.360.. Test accuracy: 0.606\n",
"Epoch 8/8.. Train loss: 1.153.. Test loss: 1.343.. Test accuracy: 0.613\n",
"Epoch 8/8.. Train loss: 0.973.. Test loss: 1.355.. Test accuracy: 0.606\n",
"Epoch 8/8.. Train loss: 1.012.. Test loss: 1.311.. Test accuracy: 0.606\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZGEqzQSpVJ7n"
},
"source": [
""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5Vc54xUrVJ7o"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment