{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/l0g1c-80m8/b420fbfe647d9683cc80fd6f30e87115/optimization_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "zCTtcXL-cw56"
},
"outputs": [],
"source": [
"# For tips on running notebooks in Google Colab, see\n",
"# https://pytorch.org/tutorials/beginner/colab\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qq4jbIyHcw57"
},
"source": [
"[Learn the Basics](intro.html) \\|\\|\n",
"[Quickstart](quickstart_tutorial.html) \\|\\|\n",
"[Tensors](tensorqs_tutorial.html) \\|\\| [Datasets &\n",
"DataLoaders](data_tutorial.html) \\|\\|\n",
"[Transforms](transforms_tutorial.html) \\|\\| [Build\n",
"Model](buildmodel_tutorial.html) \\|\\|\n",
"[Autograd](autogradqs_tutorial.html) \\|\\| **Optimization** \\|\\| [Save &\n",
"Load Model](saveloadrun_tutorial.html)\n",
"\n",
"Optimizing Model Parameters\n",
"===========================\n",
"\n",
"Now that we have a model and data it\\'s time to train, validate and test\n",
"our model by optimizing its parameters on our data. Training a model is\n",
"an iterative process; in each iteration the model makes a guess about\n",
"the output, calculates the error in its guess (*loss*), collects the\n",
"derivatives of the error with respect to its parameters (as we saw in\n",
"the [previous section](autograd_tutorial.html)), and **optimizes** these\n",
"parameters using gradient descent. For a more detailed walkthrough of\n",
"this process, check out this video on [backpropagation from\n",
"3Blue1Brown](https://www.youtube.com/watch?v=tIeHLnjs5U8).\n",
"\n",
"Prerequisite Code\n",
"-----------------\n",
"\n",
"We load the code from the previous sections on [Datasets &\n",
"DataLoaders](data_tutorial.html) and [Build\n",
"Model](buildmodel_tutorial.html).\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "dnkn_JLrcw57"
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import ToTensor\n",
"\n",
"training_data = datasets.FashionMNIST(\n",
" root=\"data\",\n",
" train=True,\n",
" download=True,\n",
" transform=ToTensor()\n",
")\n",
"\n",
"test_data = datasets.FashionMNIST(\n",
" root=\"data\",\n",
" train=False,\n",
" download=True,\n",
" transform=ToTensor()\n",
")\n",
"\n",
"train_dataloader = DataLoader(training_data, batch_size=64)\n",
"test_dataloader = DataLoader(test_data, batch_size=64)\n",
"\n",
"class NeuralNetwork(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.flatten = nn.Flatten()\n",
" self.linear_relu_stack = nn.Sequential(\n",
" nn.Linear(28*28, 512),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 512),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 10),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.flatten(x)\n",
" logits = self.linear_relu_stack(x)\n",
" return logits\n",
"\n",
"model = NeuralNetwork()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mI-3Rq8ccw58"
},
"source": [
"Hyperparameters\n",
"===============\n",
"\n",
"Hyperparameters are adjustable parameters that let you control the model\n",
"optimization process. Different hyperparameter values can impact model\n",
"training and convergence rates ([read\n",
"more](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html)\n",
"about hyperparameter tuning)\n",
"\n",
"We define the following hyperparameters for training:\n",
"\n",
": - **Number of Epochs** - the number times to iterate over the\n",
" dataset\n",
" - **Batch Size** - the number of data samples propagated through\n",
" the network before the parameters are updated\n",
" - **Learning Rate** - how much to update models parameters at each\n",
" batch/epoch. Smaller values yield slow learning speed, while\n",
" large values may result in unpredictable behavior during\n",
" training.\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "puF_Y5Gicw58"
},
"outputs": [],
"source": [
"learning_rate = 1e-3\n",
"batch_size = 64\n",
"epochs = 5"
]
},
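{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the learning rate does, here is a minimal sketch of a\n",
"single gradient-descent step on a toy scalar parameter (illustrative\n",
"only; the toy loss and parameter are assumptions for this example):\n",
"\n",
"```python\n",
"w = torch.tensor([2.0], requires_grad=True)  # toy parameter\n",
"loss = (w - 1.0) ** 2                        # toy loss, minimized at w = 1\n",
"loss.backward()                              # dloss/dw = 2 * (w - 1) = 2.0\n",
"with torch.no_grad():\n",
"    w -= learning_rate * w.grad              # one gradient-descent step\n",
"print(w)  # tensor([1.9980], ...): nudged toward 1.0 by lr * grad\n",
"```\n"
]
},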
{
"cell_type": "markdown",
"metadata": {
"id": "lYBJALlOcw58"
},
"source": [
"Optimization Loop\n",
"=================\n",
"\n",
"Once we set our hyperparameters, we can then train and optimize our\n",
"model with an optimization loop. Each iteration of the optimization loop\n",
"is called an **epoch**.\n",
"\n",
"Each epoch consists of two main parts:\n",
"\n",
": - **The Train Loop** - iterate over the training dataset and try\n",
" to converge to optimal parameters.\n",
" - **The Validation/Test Loop** - iterate over the test dataset to\n",
" check if model performance is improving.\n",
"\n",
"Let\\'s briefly familiarize ourselves with some of the concepts used in\n",
"the training loop. Jump ahead to see the\n",
"`full-impl-label`{.interpreted-text role=\"ref\"} of the optimization\n",
"loop.\n",
"\n",
"Loss Function\n",
"-------------\n",
"\n",
"When presented with some training data, our untrained network is likely\n",
"not to give the correct answer. **Loss function** measures the degree of\n",
"dissimilarity of obtained result to the target value, and it is the loss\n",
"function that we want to minimize during training. To calculate the loss\n",
"we make a prediction using the inputs of our given data sample and\n",
"compare it against the true data label value.\n",
"\n",
"Common loss functions include\n",
"[nn.MSELoss](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss)\n",
"(Mean Square Error) for regression tasks, and\n",
"[nn.NLLLoss](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss)\n",
"(Negative Log Likelihood) for classification.\n",
"[nn.CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss)\n",
"combines `nn.LogSoftmax` and `nn.NLLLoss`.\n",
"\n",
"We pass our model\\'s output logits to `nn.CrossEntropyLoss`, which will\n",
"normalize the logits and compute the prediction error.\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "2jtr0ZjQcw59"
},
"outputs": [],
"source": [
"# Initialize the loss function\n",
"loss_fn = nn.CrossEntropyLoss()"
]
},
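{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check that `nn.CrossEntropyLoss` indeed combines\n",
"`nn.LogSoftmax` and `nn.NLLLoss`, a minimal sketch on random tensors:\n",
"\n",
"```python\n",
"logits = torch.randn(3, 10)        # dummy batch: 3 samples, 10 classes\n",
"targets = torch.tensor([1, 5, 9])  # dummy class labels\n",
"ce = nn.CrossEntropyLoss()(logits, targets)\n",
"nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)\n",
"print(ce, nll)  # the two values match up to floating-point error\n",
"```\n"
]
},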
{
"cell_type": "markdown",
"metadata": {
"id": "_EyZGrGJcw59"
},
"source": [
"Optimizer\n",
"=========\n",
"\n",
"Optimization is the process of adjusting model parameters to reduce\n",
"model error in each training step. **Optimization algorithms** define\n",
"how this process is performed (in this example we use Stochastic\n",
"Gradient Descent). All optimization logic is encapsulated in the\n",
"`optimizer` object. Here, we use the SGD optimizer; additionally, there\n",
"are many [different\n",
"optimizers](https://pytorch.org/docs/stable/optim.html) available in\n",
"PyTorch such as ADAM and RMSProp, that work better for different kinds\n",
"of models and data.\n",
"\n",
"We initialize the optimizer by registering the model\\'s parameters that\n",
"need to be trained, and passing in the learning rate hyperparameter.\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "Mjq9-ouvcw59"
},
"outputs": [],
"source": [
"optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
]
},
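{
"cell_type": "markdown",
"metadata": {},
"source": [
"Swapping in a different optimizer only changes this one line. For\n",
"illustration, a sketch using two of the alternatives mentioned above\n",
"(their default hyperparameters are assumptions that may need tuning):\n",
"\n",
"```python\n",
"# Drop-in alternatives to SGD (illustrative sketch):\n",
"# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"# optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)\n",
"```\n"
]
},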
{
"cell_type": "markdown",
"metadata": {
"id": "HTqeQvwVcw5-"
},
"source": [
"Inside the training loop, optimization happens in three steps:\n",
"\n",
": - Call `optimizer.zero_grad()` to reset the gradients of model\n",
" parameters. Gradients by default add up; to prevent\n",
" double-counting, we explicitly zero them at each iteration.\n",
" - Backpropagate the prediction loss with a call to\n",
" `loss.backward()`. PyTorch deposits the gradients of the loss\n",
" w.r.t. each parameter.\n",
" - Once we have our gradients, we call `optimizer.step()` to adjust\n",
" the parameters by the gradients collected in the backward pass.\n"
]
},
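{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make these steps concrete, here is a single training iteration on\n",
"one batch, a minimal sketch using the objects defined above:\n",
"\n",
"```python\n",
"X, y = next(iter(train_dataloader))  # grab one batch of samples and labels\n",
"\n",
"optimizer.zero_grad()        # 1. reset gradients accumulated so far\n",
"loss = loss_fn(model(X), y)  # forward pass and loss computation\n",
"loss.backward()              # 2. compute gradients of the loss w.r.t. parameters\n",
"optimizer.step()             # 3. adjust parameters using the collected gradients\n",
"```\n"
]
},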
{
"cell_type": "markdown",
"metadata": {
"id": "h19ouiR7cw5-"
},
"source": [
"Full Implementation {#full-impl-label}\n",
"===================\n",
"\n",
"We define `train_loop` that loops over our optimization code, and\n",
"`test_loop` that evaluates the model\\'s performance against our test\n",
"data.\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "91eTnjEmcw5-"
},
"outputs": [],
"source": [
"def train_loop(dataloader, model, loss_fn, optimizer):\n",
" size = len(dataloader.dataset)\n",
" # Set the model to training mode - important for batch normalization and dropout layers\n",
" # Unnecessary in this situation but added for best practices\n",
" model.train()\n",
" for batch, (X, y) in enumerate(dataloader):\n",
" # Compute prediction and loss\n",
" pred = model(X)\n",
" loss = loss_fn(pred, y)\n",
"\n",
" # Backpropagation\n",
" loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" if batch % 100 == 0:\n",
" loss, current = loss.item(), batch * batch_size + len(X)\n",
" print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")\n",
"\n",
"\n",
"def test_loop(dataloader, model, loss_fn):\n",
" # Set the model to evaluation mode - important for batch normalization and dropout layers\n",
" # Unnecessary in this situation but added for best practices\n",
" model.eval()\n",
" size = len(dataloader.dataset)\n",
" num_batches = len(dataloader)\n",
" test_loss, correct = 0, 0\n",
"\n",
" # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode\n",
" # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True\n",
" with torch.no_grad():\n",
" for X, y in dataloader:\n",
" pred = model(X)\n",
" test_loss += loss_fn(pred, y).item()\n",
" correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
"\n",
" test_loss /= num_batches\n",
" correct /= size\n",
" print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
]
},
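{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy bookkeeping in `test_loop` compares the index of the\n",
"largest logit against the label. A minimal sketch on dummy values:\n",
"\n",
"```python\n",
"pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # dummy logits: 2 samples, 2 classes\n",
"y = torch.tensor([1, 0])                       # dummy labels\n",
"correct = (pred.argmax(1) == y).type(torch.float).sum().item()\n",
"print(correct)  # 2.0 - both predictions match their labels\n",
"```\n"
]
},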
{
"cell_type": "markdown",
"metadata": {
"id": "E7ui2Uykcw5-"
},
"source": [
"We initialize the loss function and optimizer, and pass it to\n",
"`train_loop` and `test_loop`. Feel free to increase the number of epochs\n",
"to track the model\\'s improving performance.\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J58D8rx8cw5-",
"outputId": "ddf8531b-3dea-4293-98b2-7ae7e736ddbb"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1\n",
"-------------------------------\n",
"loss: 2.290310 [ 64/60000]\n",
"loss: 2.275503 [ 6464/60000]\n",
"loss: 2.263773 [12864/60000]\n",
"loss: 2.269438 [19264/60000]\n",
"loss: 2.220550 [25664/60000]\n",
"loss: 2.210333 [32064/60000]\n",
"loss: 2.211369 [38464/60000]\n",
"loss: 2.176098 [44864/60000]\n",
"loss: 2.172240 [51264/60000]\n",
"loss: 2.146804 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 37.0%, Avg loss: 2.133305 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
"loss: 2.137579 [ 64/60000]\n",
"loss: 2.127795 [ 6464/60000]\n",
"loss: 2.072240 [12864/60000]\n",
"loss: 2.100680 [19264/60000]\n",
"loss: 2.013661 [25664/60000]\n",
"loss: 1.973200 [32064/60000]\n",
"loss: 1.996381 [38464/60000]\n",
"loss: 1.911712 [44864/60000]\n",
"loss: 1.911871 [51264/60000]\n",
"loss: 1.847912 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 57.1%, Avg loss: 1.840456 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
"loss: 1.871025 [ 64/60000]\n",
"loss: 1.842100 [ 6464/60000]\n",
"loss: 1.721378 [12864/60000]\n",
"loss: 1.773240 [19264/60000]\n",
"loss: 1.640321 [25664/60000]\n",
"loss: 1.608667 [32064/60000]\n",
"loss: 1.631004 [38464/60000]\n",
"loss: 1.532383 [44864/60000]\n",
"loss: 1.553534 [51264/60000]\n",
"loss: 1.454071 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 62.4%, Avg loss: 1.473399 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
"loss: 1.540690 [ 64/60000]\n",
"loss: 1.509801 [ 6464/60000]\n",
"loss: 1.358088 [12864/60000]\n",
"loss: 1.435713 [19264/60000]\n",
"loss: 1.308134 [25664/60000]\n",
"loss: 1.311491 [32064/60000]\n",
"loss: 1.327095 [38464/60000]\n",
"loss: 1.254121 [44864/60000]\n",
"loss: 1.286323 [51264/60000]\n",
"loss: 1.191111 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 64.0%, Avg loss: 1.220317 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
"loss: 1.296934 [ 64/60000]\n",
"loss: 1.283077 [ 6464/60000]\n",
"loss: 1.115692 [12864/60000]\n",
"loss: 1.222794 [19264/60000]\n",
"loss: 1.096485 [25664/60000]\n",
"loss: 1.121398 [32064/60000]\n",
"loss: 1.144145 [38464/60000]\n",
"loss: 1.084736 [44864/60000]\n",
"loss: 1.120560 [51264/60000]\n",
"loss: 1.040519 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 65.1%, Avg loss: 1.065340 \n",
"\n",
"Epoch 6\n",
"-------------------------------\n",
"loss: 1.136324 [ 64/60000]\n",
"loss: 1.141430 [ 6464/60000]\n",
"loss: 0.957609 [12864/60000]\n",
"loss: 1.089615 [19264/60000]\n",
"loss: 0.968067 [25664/60000]\n",
"loss: 0.995774 [32064/60000]\n",
"loss: 1.032992 [38464/60000]\n",
"loss: 0.978843 [44864/60000]\n",
"loss: 1.011811 [51264/60000]\n",
"loss: 0.946773 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 66.2%, Avg loss: 0.965635 \n",
"\n",
"Epoch 7\n",
"-------------------------------\n",
"loss: 1.024377 [ 64/60000]\n",
"loss: 1.049673 [ 6464/60000]\n",
"loss: 0.849768 [12864/60000]\n",
"loss: 1.000587 [19264/60000]\n",
"loss: 0.887710 [25664/60000]\n",
"loss: 0.908204 [32064/60000]\n",
"loss: 0.960904 [38464/60000]\n",
"loss: 0.910942 [44864/60000]\n",
"loss: 0.936227 [51264/60000]\n",
"loss: 0.883894 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 67.4%, Avg loss: 0.897660 \n",
"\n",
"Epoch 8\n",
"-------------------------------\n",
"loss: 0.942047 [ 64/60000]\n",
"loss: 0.985817 [ 6464/60000]\n",
"loss: 0.772774 [12864/60000]\n",
"loss: 0.937112 [19264/60000]\n",
"loss: 0.834187 [25664/60000]\n",
"loss: 0.844302 [32064/60000]\n",
"loss: 0.910575 [38464/60000]\n",
"loss: 0.865833 [44864/60000]\n",
"loss: 0.881442 [51264/60000]\n",
"loss: 0.838021 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 68.6%, Avg loss: 0.848459 \n",
"\n",
"Epoch 9\n",
"-------------------------------\n",
"loss: 0.878418 [ 64/60000]\n",
"loss: 0.937823 [ 6464/60000]\n",
"loss: 0.715006 [12864/60000]\n",
"loss: 0.889733 [19264/60000]\n",
"loss: 0.795827 [25664/60000]\n",
"loss: 0.795978 [32064/60000]\n",
"loss: 0.872392 [38464/60000]\n",
"loss: 0.834642 [44864/60000]\n",
"loss: 0.840289 [51264/60000]\n",
"loss: 0.802433 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 69.9%, Avg loss: 0.810912 \n",
"\n",
"Epoch 10\n",
"-------------------------------\n",
"loss: 0.826868 [ 64/60000]\n",
"loss: 0.899161 [ 6464/60000]\n",
"loss: 0.669611 [12864/60000]\n",
"loss: 0.852966 [19264/60000]\n",
"loss: 0.766368 [25664/60000]\n",
"loss: 0.758327 [32064/60000]\n",
"loss: 0.841313 [38464/60000]\n",
"loss: 0.811713 [44864/60000]\n",
"loss: 0.808279 [51264/60000]\n",
"loss: 0.773592 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 71.1%, Avg loss: 0.780833 \n",
"\n",
"Epoch 11\n",
"-------------------------------\n",
"loss: 0.783718 [ 64/60000]\n",
"loss: 0.866302 [ 6464/60000]\n",
"loss: 0.632861 [12864/60000]\n",
"loss: 0.823387 [19264/60000]\n",
"loss: 0.742387 [25664/60000]\n",
"loss: 0.728453 [32064/60000]\n",
"loss: 0.814751 [38464/60000]\n",
"loss: 0.793659 [44864/60000]\n",
"loss: 0.782221 [51264/60000]\n",
"loss: 0.749092 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 72.2%, Avg loss: 0.755679 \n",
"\n",
"Epoch 12\n",
"-------------------------------\n",
"loss: 0.746509 [ 64/60000]\n",
"loss: 0.837352 [ 6464/60000]\n",
"loss: 0.602107 [12864/60000]\n",
"loss: 0.799123 [19264/60000]\n",
"loss: 0.721948 [25664/60000]\n",
"loss: 0.704099 [32064/60000]\n",
"loss: 0.791141 [38464/60000]\n",
"loss: 0.778566 [44864/60000]\n",
"loss: 0.760334 [51264/60000]\n",
"loss: 0.727574 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 73.3%, Avg loss: 0.733872 \n",
"\n",
"Epoch 13\n",
"-------------------------------\n",
"loss: 0.713826 [ 64/60000]\n",
"loss: 0.811272 [ 6464/60000]\n",
"loss: 0.575667 [12864/60000]\n",
"loss: 0.778658 [19264/60000]\n",
"loss: 0.704210 [25664/60000]\n",
"loss: 0.683656 [32064/60000]\n",
"loss: 0.769651 [38464/60000]\n",
"loss: 0.765350 [44864/60000]\n",
"loss: 0.741397 [51264/60000]\n",
"loss: 0.708255 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 74.0%, Avg loss: 0.714431 \n",
"\n",
"Epoch 14\n",
"-------------------------------\n",
"loss: 0.684894 [ 64/60000]\n",
"loss: 0.787469 [ 6464/60000]\n",
"loss: 0.552428 [12864/60000]\n",
"loss: 0.760903 [19264/60000]\n",
"loss: 0.688572 [25664/60000]\n",
"loss: 0.666294 [32064/60000]\n",
"loss: 0.749718 [38464/60000]\n",
"loss: 0.753447 [44864/60000]\n",
"loss: 0.724810 [51264/60000]\n",
"loss: 0.690653 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 74.9%, Avg loss: 0.696807 \n",
"\n",
"Epoch 15\n",
"-------------------------------\n",
"loss: 0.658878 [ 64/60000]\n",
"loss: 0.765472 [ 6464/60000]\n",
"loss: 0.531911 [12864/60000]\n",
"loss: 0.745321 [19264/60000]\n",
"loss: 0.674703 [25664/60000]\n",
"loss: 0.651341 [32064/60000]\n",
"loss: 0.730811 [38464/60000]\n",
"loss: 0.742618 [44864/60000]\n",
"loss: 0.710211 [51264/60000]\n",
"loss: 0.674493 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 75.7%, Avg loss: 0.680665 \n",
"\n",
"Epoch 16\n",
"-------------------------------\n",
"loss: 0.635415 [ 64/60000]\n",
"loss: 0.745032 [ 6464/60000]\n",
"loss: 0.513653 [12864/60000]\n",
"loss: 0.731463 [19264/60000]\n",
"loss: 0.662322 [25664/60000]\n",
"loss: 0.638469 [32064/60000]\n",
"loss: 0.712931 [38464/60000]\n",
"loss: 0.732589 [44864/60000]\n",
"loss: 0.697231 [51264/60000]\n",
"loss: 0.659461 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 76.4%, Avg loss: 0.665802 \n",
"\n",
"Epoch 17\n",
"-------------------------------\n",
"loss: 0.614198 [ 64/60000]\n",
"loss: 0.726068 [ 6464/60000]\n",
"loss: 0.497212 [12864/60000]\n",
"loss: 0.718895 [19264/60000]\n",
"loss: 0.651337 [25664/60000]\n",
"loss: 0.627272 [32064/60000]\n",
"loss: 0.696128 [38464/60000]\n",
"loss: 0.723555 [44864/60000]\n",
"loss: 0.685712 [51264/60000]\n",
"loss: 0.645445 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 77.1%, Avg loss: 0.652078 \n",
"\n",
"Epoch 18\n",
"-------------------------------\n",
"loss: 0.594977 [ 64/60000]\n",
"loss: 0.708432 [ 6464/60000]\n",
"loss: 0.482356 [12864/60000]\n",
"loss: 0.707420 [19264/60000]\n",
"loss: 0.641477 [25664/60000]\n",
"loss: 0.617510 [32064/60000]\n",
"loss: 0.680286 [38464/60000]\n",
"loss: 0.715482 [44864/60000]\n",
"loss: 0.675653 [51264/60000]\n",
"loss: 0.632358 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 77.6%, Avg loss: 0.639424 \n",
"\n",
"Epoch 19\n",
"-------------------------------\n",
"loss: 0.577329 [ 64/60000]\n",
"loss: 0.692126 [ 6464/60000]\n",
"loss: 0.468950 [12864/60000]\n",
"loss: 0.696804 [19264/60000]\n",
"loss: 0.632574 [25664/60000]\n",
"loss: 0.608918 [32064/60000]\n",
"loss: 0.665394 [38464/60000]\n",
"loss: 0.708410 [44864/60000]\n",
"loss: 0.666947 [51264/60000]\n",
"loss: 0.620008 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 78.0%, Avg loss: 0.627768 \n",
"\n",
"Epoch 20\n",
"-------------------------------\n",
"loss: 0.561282 [ 64/60000]\n",
"loss: 0.677028 [ 6464/60000]\n",
"loss: 0.456854 [12864/60000]\n",
"loss: 0.686913 [19264/60000]\n",
"loss: 0.624497 [25664/60000]\n",
"loss: 0.601413 [32064/60000]\n",
"loss: 0.651427 [38464/60000]\n",
"loss: 0.702277 [44864/60000]\n",
"loss: 0.659496 [51264/60000]\n",
"loss: 0.608285 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 78.6%, Avg loss: 0.617033 \n",
"\n",
"Epoch 21\n",
"-------------------------------\n",
"loss: 0.546620 [ 64/60000]\n",
"loss: 0.663023 [ 6464/60000]\n",
"loss: 0.445868 [12864/60000]\n",
"loss: 0.677680 [19264/60000]\n",
"loss: 0.617140 [25664/60000]\n",
"loss: 0.594627 [32064/60000]\n",
"loss: 0.638270 [38464/60000]\n",
"loss: 0.696984 [44864/60000]\n",
"loss: 0.653212 [51264/60000]\n",
"loss: 0.597209 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 79.0%, Avg loss: 0.607146 \n",
"\n",
"Epoch 22\n",
"-------------------------------\n",
"loss: 0.533109 [ 64/60000]\n",
"loss: 0.650054 [ 6464/60000]\n",
"loss: 0.435824 [12864/60000]\n",
"loss: 0.669078 [19264/60000]\n",
"loss: 0.610376 [25664/60000]\n",
"loss: 0.588472 [32064/60000]\n",
"loss: 0.626010 [38464/60000]\n",
"loss: 0.692407 [44864/60000]\n",
"loss: 0.647964 [51264/60000]\n",
"loss: 0.586720 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 79.3%, Avg loss: 0.598045 \n",
"\n",
"Epoch 23\n",
"-------------------------------\n",
"loss: 0.520617 [ 64/60000]\n",
"loss: 0.638000 [ 6464/60000]\n",
"loss: 0.426609 [12864/60000]\n",
"loss: 0.661047 [19264/60000]\n",
"loss: 0.603926 [25664/60000]\n",
"loss: 0.582840 [32064/60000]\n",
"loss: 0.614518 [38464/60000]\n",
"loss: 0.688568 [44864/60000]\n",
"loss: 0.643534 [51264/60000]\n",
"loss: 0.576754 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 79.6%, Avg loss: 0.589667 \n",
"\n",
"Epoch 24\n",
"-------------------------------\n",
"loss: 0.509014 [ 64/60000]\n",
"loss: 0.626770 [ 6464/60000]\n",
"loss: 0.418196 [12864/60000]\n",
"loss: 0.653482 [19264/60000]\n",
"loss: 0.597626 [25664/60000]\n",
"loss: 0.577576 [32064/60000]\n",
"loss: 0.603678 [38464/60000]\n",
"loss: 0.685424 [44864/60000]\n",
"loss: 0.639777 [51264/60000]\n",
"loss: 0.567313 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 79.9%, Avg loss: 0.581940 \n",
"\n",
"Epoch 25\n",
"-------------------------------\n",
"loss: 0.498259 [ 64/60000]\n",
"loss: 0.616298 [ 6464/60000]\n",
"loss: 0.410437 [12864/60000]\n",
"loss: 0.646339 [19264/60000]\n",
"loss: 0.591561 [25664/60000]\n",
"loss: 0.572627 [32064/60000]\n",
"loss: 0.593537 [38464/60000]\n",
"loss: 0.682846 [44864/60000]\n",
"loss: 0.636692 [51264/60000]\n",
"loss: 0.558259 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 80.1%, Avg loss: 0.574811 \n",
"\n",
"Epoch 26\n",
"-------------------------------\n",
"loss: 0.488200 [ 64/60000]\n",
"loss: 0.606588 [ 6464/60000]\n",
"loss: 0.403253 [12864/60000]\n",
"loss: 0.639518 [19264/60000]\n",
"loss: 0.585773 [25664/60000]\n",
"loss: 0.567957 [32064/60000]\n",
"loss: 0.584037 [38464/60000]\n",
"loss: 0.680813 [44864/60000]\n",
"loss: 0.634074 [51264/60000]\n",
"loss: 0.549578 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 80.2%, Avg loss: 0.568215 \n",
"\n",
"Epoch 27\n",
"-------------------------------\n",
"loss: 0.478822 [ 64/60000]\n",
"loss: 0.597548 [ 6464/60000]\n",
"loss: 0.396629 [12864/60000]\n",
"loss: 0.633016 [19264/60000]\n",
"loss: 0.580057 [25664/60000]\n",
"loss: 0.563566 [32064/60000]\n",
"loss: 0.575109 [38464/60000]\n",
"loss: 0.679224 [44864/60000]\n",
"loss: 0.631918 [51264/60000]\n",
"loss: 0.541276 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 80.4%, Avg loss: 0.562098 \n",
"\n",
"Epoch 28\n",
"-------------------------------\n",
"loss: 0.470002 [ 64/60000]\n",
"loss: 0.589186 [ 6464/60000]\n",
"loss: 0.390406 [12864/60000]\n",
"loss: 0.626801 [19264/60000]\n",
"loss: 0.574257 [25664/60000]\n",
"loss: 0.559269 [32064/60000]\n",
"loss: 0.566882 [38464/60000]\n",
"loss: 0.677975 [44864/60000]\n",
"loss: 0.629997 [51264/60000]\n",
"loss: 0.533314 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 80.6%, Avg loss: 0.556419 \n",
"\n",
"Epoch 29\n",
"-------------------------------\n",
"loss: 0.461641 [ 64/60000]\n",
"loss: 0.581337 [ 6464/60000]\n",
"loss: 0.384643 [12864/60000]\n",
"loss: 0.620840 [19264/60000]\n",
"loss: 0.568534 [25664/60000]\n",
"loss: 0.554984 [32064/60000]\n",
"loss: 0.559268 [38464/60000]\n",
"loss: 0.676945 [44864/60000]\n",
"loss: 0.628250 [51264/60000]\n",
"loss: 0.525668 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 80.8%, Avg loss: 0.551136 \n",
"\n",
"Epoch 30\n",
"-------------------------------\n",
"loss: 0.453716 [ 64/60000]\n",
"loss: 0.574053 [ 6464/60000]\n",
"loss: 0.379212 [12864/60000]\n",
"loss: 0.615121 [19264/60000]\n",
"loss: 0.562905 [25664/60000]\n",
"loss: 0.550725 [32064/60000]\n",
"loss: 0.552154 [38464/60000]\n",
"loss: 0.676232 [44864/60000]\n",
"loss: 0.626677 [51264/60000]\n",
"loss: 0.518330 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.0%, Avg loss: 0.546218 \n",
"\n",
"Epoch 31\n",
"-------------------------------\n",
"loss: 0.446166 [ 64/60000]\n",
"loss: 0.567288 [ 6464/60000]\n",
"loss: 0.374094 [12864/60000]\n",
"loss: 0.609605 [19264/60000]\n",
"loss: 0.557286 [25664/60000]\n",
"loss: 0.546430 [32064/60000]\n",
"loss: 0.545452 [38464/60000]\n",
"loss: 0.675633 [44864/60000]\n",
"loss: 0.625234 [51264/60000]\n",
"loss: 0.511247 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.1%, Avg loss: 0.541623 \n",
"\n",
"Epoch 32\n",
"-------------------------------\n",
"loss: 0.439007 [ 64/60000]\n",
"loss: 0.561025 [ 6464/60000]\n",
"loss: 0.369243 [12864/60000]\n",
"loss: 0.604281 [19264/60000]\n",
"loss: 0.551683 [25664/60000]\n",
"loss: 0.542216 [32064/60000]\n",
"loss: 0.539204 [38464/60000]\n",
"loss: 0.675125 [44864/60000]\n",
"loss: 0.623876 [51264/60000]\n",
"loss: 0.504428 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.2%, Avg loss: 0.537320 \n",
"\n",
"Epoch 33\n",
"-------------------------------\n",
"loss: 0.432186 [ 64/60000]\n",
"loss: 0.555212 [ 6464/60000]\n",
"loss: 0.364681 [12864/60000]\n",
"loss: 0.599194 [19264/60000]\n",
"loss: 0.546189 [25664/60000]\n",
"loss: 0.537963 [32064/60000]\n",
"loss: 0.533403 [38464/60000]\n",
"loss: 0.674624 [44864/60000]\n",
"loss: 0.622553 [51264/60000]\n",
"loss: 0.497925 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.4%, Avg loss: 0.533272 \n",
"\n",
"Epoch 34\n",
"-------------------------------\n",
"loss: 0.425672 [ 64/60000]\n",
"loss: 0.549768 [ 6464/60000]\n",
"loss: 0.360433 [12864/60000]\n",
"loss: 0.594322 [19264/60000]\n",
"loss: 0.540800 [25664/60000]\n",
"loss: 0.533762 [32064/60000]\n",
"loss: 0.527987 [38464/60000]\n",
"loss: 0.674103 [44864/60000]\n",
"loss: 0.621166 [51264/60000]\n",
"loss: 0.491655 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.5%, Avg loss: 0.529465 \n",
"\n",
"Epoch 35\n",
"-------------------------------\n",
"loss: 0.419437 [ 64/60000]\n",
"loss: 0.544650 [ 6464/60000]\n",
"loss: 0.356402 [12864/60000]\n",
"loss: 0.589605 [19264/60000]\n",
"loss: 0.535532 [25664/60000]\n",
"loss: 0.529567 [32064/60000]\n",
"loss: 0.522878 [38464/60000]\n",
"loss: 0.673613 [44864/60000]\n",
"loss: 0.619837 [51264/60000]\n",
"loss: 0.485710 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.6%, Avg loss: 0.525883 \n",
"\n",
"Epoch 36\n",
"-------------------------------\n",
"loss: 0.413498 [ 64/60000]\n",
"loss: 0.539885 [ 6464/60000]\n",
"loss: 0.352539 [12864/60000]\n",
"loss: 0.585062 [19264/60000]\n",
"loss: 0.530344 [25664/60000]\n",
"loss: 0.525445 [32064/60000]\n",
"loss: 0.518064 [38464/60000]\n",
"loss: 0.673092 [44864/60000]\n",
"loss: 0.618495 [51264/60000]\n",
"loss: 0.480051 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.6%, Avg loss: 0.522500 \n",
"\n",
"Epoch 37\n",
"-------------------------------\n",
"loss: 0.407799 [ 64/60000]\n",
"loss: 0.535391 [ 6464/60000]\n",
"loss: 0.348870 [12864/60000]\n",
"loss: 0.580659 [19264/60000]\n",
"loss: 0.525235 [25664/60000]\n",
"loss: 0.521353 [32064/60000]\n",
"loss: 0.513480 [38464/60000]\n",
"loss: 0.672483 [44864/60000]\n",
"loss: 0.617048 [51264/60000]\n",
"loss: 0.474671 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.7%, Avg loss: 0.519301 \n",
"\n",
"Epoch 38\n",
"-------------------------------\n",
"loss: 0.402323 [ 64/60000]\n",
"loss: 0.531171 [ 6464/60000]\n",
"loss: 0.345384 [12864/60000]\n",
"loss: 0.576379 [19264/60000]\n",
"loss: 0.520239 [25664/60000]\n",
"loss: 0.517299 [32064/60000]\n",
"loss: 0.509134 [38464/60000]\n",
"loss: 0.671813 [44864/60000]\n",
"loss: 0.615612 [51264/60000]\n",
"loss: 0.469559 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.9%, Avg loss: 0.516264 \n",
"\n",
"Epoch 39\n",
"-------------------------------\n",
"loss: 0.397061 [ 64/60000]\n",
"loss: 0.527188 [ 6464/60000]\n",
"loss: 0.342034 [12864/60000]\n",
"loss: 0.572297 [19264/60000]\n",
"loss: 0.515315 [25664/60000]\n",
"loss: 0.513353 [32064/60000]\n",
"loss: 0.505046 [38464/60000]\n",
"loss: 0.671109 [44864/60000]\n",
"loss: 0.614133 [51264/60000]\n",
"loss: 0.464718 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 81.9%, Avg loss: 0.513386 \n",
"\n",
"Epoch 40\n",
"-------------------------------\n",
"loss: 0.391973 [ 64/60000]\n",
"loss: 0.523437 [ 6464/60000]\n",
"loss: 0.338854 [12864/60000]\n",
"loss: 0.568338 [19264/60000]\n",
"loss: 0.510549 [25664/60000]\n",
"loss: 0.509420 [32064/60000]\n",
"loss: 0.501144 [38464/60000]\n",
"loss: 0.670262 [44864/60000]\n",
"loss: 0.612625 [51264/60000]\n",
"loss: 0.460128 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.0%, Avg loss: 0.510651 \n",
"\n",
"Epoch 41\n",
"-------------------------------\n",
"loss: 0.387064 [ 64/60000]\n",
"loss: 0.519877 [ 6464/60000]\n",
"loss: 0.335795 [12864/60000]\n",
"loss: 0.564523 [19264/60000]\n",
"loss: 0.505914 [25664/60000]\n",
"loss: 0.505606 [32064/60000]\n",
"loss: 0.497429 [38464/60000]\n",
"loss: 0.669315 [44864/60000]\n",
"loss: 0.611076 [51264/60000]\n",
"loss: 0.455752 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.0%, Avg loss: 0.508043 \n",
"\n",
"Epoch 42\n",
"-------------------------------\n",
"loss: 0.382324 [ 64/60000]\n",
"loss: 0.516514 [ 6464/60000]\n",
"loss: 0.332889 [12864/60000]\n",
"loss: 0.560811 [19264/60000]\n",
"loss: 0.501391 [25664/60000]\n",
"loss: 0.501844 [32064/60000]\n",
"loss: 0.493890 [38464/60000]\n",
"loss: 0.668269 [44864/60000]\n",
"loss: 0.609468 [51264/60000]\n",
"loss: 0.451606 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.1%, Avg loss: 0.505556 \n",
"\n",
"Epoch 43\n",
"-------------------------------\n",
"loss: 0.377740 [ 64/60000]\n",
"loss: 0.513336 [ 6464/60000]\n",
"loss: 0.330095 [12864/60000]\n",
"loss: 0.557230 [19264/60000]\n",
"loss: 0.496991 [25664/60000]\n",
"loss: 0.498149 [32064/60000]\n",
"loss: 0.490533 [38464/60000]\n",
"loss: 0.667148 [44864/60000]\n",
"loss: 0.607838 [51264/60000]\n",
"loss: 0.447620 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.2%, Avg loss: 0.503179 \n",
"\n",
"Epoch 44\n",
"-------------------------------\n",
"loss: 0.373325 [ 64/60000]\n",
"loss: 0.510294 [ 6464/60000]\n",
"loss: 0.327440 [12864/60000]\n",
"loss: 0.553777 [19264/60000]\n",
"loss: 0.492716 [25664/60000]\n",
"loss: 0.494609 [32064/60000]\n",
"loss: 0.487340 [38464/60000]\n",
"loss: 0.665944 [44864/60000]\n",
"loss: 0.606142 [51264/60000]\n",
"loss: 0.443893 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.3%, Avg loss: 0.500906 \n",
"\n",
"Epoch 45\n",
"-------------------------------\n",
"loss: 0.369023 [ 64/60000]\n",
"loss: 0.507392 [ 6464/60000]\n",
"loss: 0.324879 [12864/60000]\n",
"loss: 0.550461 [19264/60000]\n",
"loss: 0.488557 [25664/60000]\n",
"loss: 0.491236 [32064/60000]\n",
"loss: 0.484261 [38464/60000]\n",
"loss: 0.664669 [44864/60000]\n",
"loss: 0.604463 [51264/60000]\n",
"loss: 0.440351 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.4%, Avg loss: 0.498726 \n",
"\n",
"Epoch 46\n",
"-------------------------------\n",
"loss: 0.364832 [ 64/60000]\n",
"loss: 0.504612 [ 6464/60000]\n",
"loss: 0.322402 [12864/60000]\n",
"loss: 0.547232 [19264/60000]\n",
"loss: 0.484497 [25664/60000]\n",
"loss: 0.487927 [32064/60000]\n",
"loss: 0.481346 [38464/60000]\n",
"loss: 0.663265 [44864/60000]\n",
"loss: 0.602736 [51264/60000]\n",
"loss: 0.436972 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.5%, Avg loss: 0.496639 \n",
"\n",
"Epoch 47\n",
"-------------------------------\n",
"loss: 0.360754 [ 64/60000]\n",
"loss: 0.501936 [ 6464/60000]\n",
"loss: 0.319981 [12864/60000]\n",
"loss: 0.544117 [19264/60000]\n",
"loss: 0.480597 [25664/60000]\n",
"loss: 0.484709 [32064/60000]\n",
"loss: 0.478559 [38464/60000]\n",
"loss: 0.661760 [44864/60000]\n",
"loss: 0.600944 [51264/60000]\n",
"loss: 0.433797 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.6%, Avg loss: 0.494633 \n",
"\n",
"Epoch 48\n",
"-------------------------------\n",
"loss: 0.356818 [ 64/60000]\n",
"loss: 0.499409 [ 6464/60000]\n",
"loss: 0.317665 [12864/60000]\n",
"loss: 0.541060 [19264/60000]\n",
"loss: 0.476794 [25664/60000]\n",
"loss: 0.481562 [32064/60000]\n",
"loss: 0.475892 [38464/60000]\n",
"loss: 0.660203 [44864/60000]\n",
"loss: 0.599235 [51264/60000]\n",
"loss: 0.430792 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.7%, Avg loss: 0.492701 \n",
"\n",
"Epoch 49\n",
"-------------------------------\n",
"loss: 0.353056 [ 64/60000]\n",
"loss: 0.496988 [ 6464/60000]\n",
"loss: 0.315405 [12864/60000]\n",
"loss: 0.538113 [19264/60000]\n",
"loss: 0.473144 [25664/60000]\n",
"loss: 0.478582 [32064/60000]\n",
"loss: 0.473277 [38464/60000]\n",
"loss: 0.658563 [44864/60000]\n",
"loss: 0.597550 [51264/60000]\n",
"loss: 0.427926 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.7%, Avg loss: 0.490837 \n",
"\n",
"Epoch 50\n",
"-------------------------------\n",
"loss: 0.349353 [ 64/60000]\n",
"loss: 0.494682 [ 6464/60000]\n",
"loss: 0.313216 [12864/60000]\n",
"loss: 0.535247 [19264/60000]\n",
"loss: 0.469566 [25664/60000]\n",
"loss: 0.475692 [32064/60000]\n",
"loss: 0.470816 [38464/60000]\n",
"loss: 0.656758 [44864/60000]\n",
"loss: 0.595820 [51264/60000]\n",
"loss: 0.425227 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.8%, Avg loss: 0.489036 \n",
"\n",
"Epoch 51\n",
"-------------------------------\n",
"loss: 0.345757 [ 64/60000]\n",
"loss: 0.492491 [ 6464/60000]\n",
"loss: 0.311069 [12864/60000]\n",
"loss: 0.532528 [19264/60000]\n",
"loss: 0.466094 [25664/60000]\n",
"loss: 0.472902 [32064/60000]\n",
"loss: 0.468441 [38464/60000]\n",
"loss: 0.654984 [44864/60000]\n",
"loss: 0.594080 [51264/60000]\n",
"loss: 0.422612 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.8%, Avg loss: 0.487296 \n",
"\n",
"Epoch 52\n",
"-------------------------------\n",
"loss: 0.342263 [ 64/60000]\n",
"loss: 0.490367 [ 6464/60000]\n",
"loss: 0.308987 [12864/60000]\n",
"loss: 0.529865 [19264/60000]\n",
"loss: 0.462766 [25664/60000]\n",
"loss: 0.470273 [32064/60000]\n",
"loss: 0.466082 [38464/60000]\n",
"loss: 0.653179 [44864/60000]\n",
"loss: 0.592359 [51264/60000]\n",
"loss: 0.420103 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.8%, Avg loss: 0.485619 \n",
"\n",
"Epoch 53\n",
"-------------------------------\n",
"loss: 0.338853 [ 64/60000]\n",
"loss: 0.488292 [ 6464/60000]\n",
"loss: 0.306925 [12864/60000]\n",
"loss: 0.527329 [19264/60000]\n",
"loss: 0.459472 [25664/60000]\n",
"loss: 0.467760 [32064/60000]\n",
"loss: 0.463849 [38464/60000]\n",
"loss: 0.651313 [44864/60000]\n",
"loss: 0.590723 [51264/60000]\n",
"loss: 0.417756 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 82.9%, Avg loss: 0.483996 \n",
"\n",
"Epoch 54\n",
"-------------------------------\n",
"loss: 0.335551 [ 64/60000]\n",
"loss: 0.486273 [ 6464/60000]\n",
"loss: 0.304979 [12864/60000]\n",
"loss: 0.524847 [19264/60000]\n",
"loss: 0.456233 [25664/60000]\n",
"loss: 0.465289 [32064/60000]\n",
"loss: 0.461765 [38464/60000]\n",
"loss: 0.649427 [44864/60000]\n",
"loss: 0.589112 [51264/60000]\n",
"loss: 0.415524 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.0%, Avg loss: 0.482419 \n",
"\n",
"Epoch 55\n",
"-------------------------------\n",
"loss: 0.332382 [ 64/60000]\n",
"loss: 0.484308 [ 6464/60000]\n",
"loss: 0.303058 [12864/60000]\n",
"loss: 0.522521 [19264/60000]\n",
"loss: 0.453092 [25664/60000]\n",
"loss: 0.462991 [32064/60000]\n",
"loss: 0.459777 [38464/60000]\n",
"loss: 0.647514 [44864/60000]\n",
"loss: 0.587518 [51264/60000]\n",
"loss: 0.413467 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.1%, Avg loss: 0.480894 \n",
"\n",
"Epoch 56\n",
"-------------------------------\n",
"loss: 0.329310 [ 64/60000]\n",
"loss: 0.482378 [ 6464/60000]\n",
"loss: 0.301225 [12864/60000]\n",
"loss: 0.520266 [19264/60000]\n",
"loss: 0.450069 [25664/60000]\n",
"loss: 0.460753 [32064/60000]\n",
"loss: 0.457853 [38464/60000]\n",
"loss: 0.645581 [44864/60000]\n",
"loss: 0.585949 [51264/60000]\n",
"loss: 0.411519 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.1%, Avg loss: 0.479414 \n",
"\n",
"Epoch 57\n",
"-------------------------------\n",
"loss: 0.326321 [ 64/60000]\n",
"loss: 0.480451 [ 6464/60000]\n",
"loss: 0.299462 [12864/60000]\n",
"loss: 0.518094 [19264/60000]\n",
"loss: 0.447126 [25664/60000]\n",
"loss: 0.458697 [32064/60000]\n",
"loss: 0.455995 [38464/60000]\n",
"loss: 0.643693 [44864/60000]\n",
"loss: 0.584398 [51264/60000]\n",
"loss: 0.409660 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.1%, Avg loss: 0.477979 \n",
"\n",
"Epoch 58\n",
"-------------------------------\n",
"loss: 0.323418 [ 64/60000]\n",
"loss: 0.478570 [ 6464/60000]\n",
"loss: 0.297740 [12864/60000]\n",
"loss: 0.515970 [19264/60000]\n",
"loss: 0.444312 [25664/60000]\n",
"loss: 0.456664 [32064/60000]\n",
"loss: 0.454142 [38464/60000]\n",
"loss: 0.641751 [44864/60000]\n",
"loss: 0.582845 [51264/60000]\n",
"loss: 0.407896 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.1%, Avg loss: 0.476582 \n",
"\n",
"Epoch 59\n",
"-------------------------------\n",
"loss: 0.320634 [ 64/60000]\n",
"loss: 0.476725 [ 6464/60000]\n",
"loss: 0.296042 [12864/60000]\n",
"loss: 0.513921 [19264/60000]\n",
"loss: 0.441596 [25664/60000]\n",
"loss: 0.454676 [32064/60000]\n",
"loss: 0.452355 [38464/60000]\n",
"loss: 0.639842 [44864/60000]\n",
"loss: 0.581317 [51264/60000]\n",
"loss: 0.406206 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.1%, Avg loss: 0.475222 \n",
"\n",
"Epoch 60\n",
"-------------------------------\n",
"loss: 0.317941 [ 64/60000]\n",
"loss: 0.474918 [ 6464/60000]\n",
"loss: 0.294424 [12864/60000]\n",
"loss: 0.511960 [19264/60000]\n",
"loss: 0.438956 [25664/60000]\n",
"loss: 0.452752 [32064/60000]\n",
"loss: 0.450616 [38464/60000]\n",
"loss: 0.637881 [44864/60000]\n",
"loss: 0.579778 [51264/60000]\n",
"loss: 0.404595 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.2%, Avg loss: 0.473891 \n",
"\n",
"Epoch 61\n",
"-------------------------------\n",
"loss: 0.315354 [ 64/60000]\n",
"loss: 0.473169 [ 6464/60000]\n",
"loss: 0.292859 [12864/60000]\n",
"loss: 0.510092 [19264/60000]\n",
"loss: 0.436366 [25664/60000]\n",
"loss: 0.450889 [32064/60000]\n",
"loss: 0.448871 [38464/60000]\n",
"loss: 0.635946 [44864/60000]\n",
"loss: 0.578202 [51264/60000]\n",
"loss: 0.403056 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.3%, Avg loss: 0.472591 \n",
"\n",
"Epoch 62\n",
"-------------------------------\n",
"loss: 0.312814 [ 64/60000]\n",
"loss: 0.471399 [ 6464/60000]\n",
"loss: 0.291309 [12864/60000]\n",
"loss: 0.508244 [19264/60000]\n",
"loss: 0.433882 [25664/60000]\n",
"loss: 0.449061 [32064/60000]\n",
"loss: 0.447202 [38464/60000]\n",
"loss: 0.634043 [44864/60000]\n",
"loss: 0.576671 [51264/60000]\n",
"loss: 0.401604 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.4%, Avg loss: 0.471328 \n",
"\n",
"Epoch 63\n",
"-------------------------------\n",
"loss: 0.310327 [ 64/60000]\n",
"loss: 0.469621 [ 6464/60000]\n",
"loss: 0.289869 [12864/60000]\n",
"loss: 0.506510 [19264/60000]\n",
"loss: 0.431465 [25664/60000]\n",
"loss: 0.447359 [32064/60000]\n",
"loss: 0.445648 [38464/60000]\n",
"loss: 0.632109 [44864/60000]\n",
"loss: 0.575194 [51264/60000]\n",
"loss: 0.400220 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.4%, Avg loss: 0.470092 \n",
"\n",
"Epoch 64\n",
"-------------------------------\n",
"loss: 0.307934 [ 64/60000]\n",
"loss: 0.467874 [ 6464/60000]\n",
"loss: 0.288371 [12864/60000]\n",
"loss: 0.504769 [19264/60000]\n",
"loss: 0.429059 [25664/60000]\n",
"loss: 0.445699 [32064/60000]\n",
"loss: 0.444159 [38464/60000]\n",
"loss: 0.630172 [44864/60000]\n",
"loss: 0.573773 [51264/60000]\n",
"loss: 0.398884 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.4%, Avg loss: 0.468888 \n",
"\n",
"Epoch 65\n",
"-------------------------------\n",
"loss: 0.305590 [ 64/60000]\n",
"loss: 0.466199 [ 6464/60000]\n",
"loss: 0.286952 [12864/60000]\n",
"loss: 0.503127 [19264/60000]\n",
"loss: 0.426700 [25664/60000]\n",
"loss: 0.444055 [32064/60000]\n",
"loss: 0.442674 [38464/60000]\n",
"loss: 0.628258 [44864/60000]\n",
"loss: 0.572357 [51264/60000]\n",
"loss: 0.397590 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.5%, Avg loss: 0.467707 \n",
"\n",
"Epoch 66\n",
"-------------------------------\n",
"loss: 0.303334 [ 64/60000]\n",
"loss: 0.464632 [ 6464/60000]\n",
"loss: 0.285564 [12864/60000]\n",
"loss: 0.501573 [19264/60000]\n",
"loss: 0.424368 [25664/60000]\n",
"loss: 0.442408 [32064/60000]\n",
"loss: 0.441304 [38464/60000]\n",
"loss: 0.626372 [44864/60000]\n",
"loss: 0.570899 [51264/60000]\n",
"loss: 0.396384 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.5%, Avg loss: 0.466553 \n",
"\n",
"Epoch 67\n",
"-------------------------------\n",
"loss: 0.301153 [ 64/60000]\n",
"loss: 0.463068 [ 6464/60000]\n",
"loss: 0.284238 [12864/60000]\n",
"loss: 0.500059 [19264/60000]\n",
"loss: 0.422121 [25664/60000]\n",
"loss: 0.440843 [32064/60000]\n",
"loss: 0.439893 [38464/60000]\n",
"loss: 0.624485 [44864/60000]\n",
"loss: 0.569421 [51264/60000]\n",
"loss: 0.395213 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.5%, Avg loss: 0.465425 \n",
"\n",
"Epoch 68\n",
"-------------------------------\n",
"loss: 0.299054 [ 64/60000]\n",
"loss: 0.461508 [ 6464/60000]\n",
"loss: 0.282930 [12864/60000]\n",
"loss: 0.498558 [19264/60000]\n",
"loss: 0.419881 [25664/60000]\n",
"loss: 0.439315 [32064/60000]\n",
"loss: 0.438488 [38464/60000]\n",
"loss: 0.622598 [44864/60000]\n",
"loss: 0.567905 [51264/60000]\n",
"loss: 0.394085 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.5%, Avg loss: 0.464326 \n",
"\n",
"Epoch 69\n",
"-------------------------------\n",
"loss: 0.297019 [ 64/60000]\n",
"loss: 0.460005 [ 6464/60000]\n",
"loss: 0.281680 [12864/60000]\n",
"loss: 0.497166 [19264/60000]\n",
"loss: 0.417704 [25664/60000]\n",
"loss: 0.437840 [32064/60000]\n",
"loss: 0.437086 [38464/60000]\n",
"loss: 0.620670 [44864/60000]\n",
"loss: 0.566433 [51264/60000]\n",
"loss: 0.392990 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.5%, Avg loss: 0.463251 \n",
"\n",
"Epoch 70\n",
"-------------------------------\n",
"loss: 0.295058 [ 64/60000]\n",
"loss: 0.458569 [ 6464/60000]\n",
"loss: 0.280487 [12864/60000]\n",
"loss: 0.495762 [19264/60000]\n",
"loss: 0.415654 [25664/60000]\n",
"loss: 0.436464 [32064/60000]\n",
"loss: 0.435709 [38464/60000]\n",
"loss: 0.618734 [44864/60000]\n",
"loss: 0.564913 [51264/60000]\n",
"loss: 0.391943 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.6%, Avg loss: 0.462199 \n",
"\n",
"Epoch 71\n",
"-------------------------------\n",
"loss: 0.293159 [ 64/60000]\n",
"loss: 0.457152 [ 6464/60000]\n",
"loss: 0.279345 [12864/60000]\n",
"loss: 0.494372 [19264/60000]\n",
"loss: 0.413608 [25664/60000]\n",
"loss: 0.435065 [32064/60000]\n",
"loss: 0.434385 [38464/60000]\n",
"loss: 0.616807 [44864/60000]\n",
"loss: 0.563371 [51264/60000]\n",
"loss: 0.390915 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.6%, Avg loss: 0.461166 \n",
"\n",
"Epoch 72\n",
"-------------------------------\n",
"loss: 0.291307 [ 64/60000]\n",
"loss: 0.455769 [ 6464/60000]\n",
"loss: 0.278250 [12864/60000]\n",
"loss: 0.493022 [19264/60000]\n",
"loss: 0.411582 [25664/60000]\n",
"loss: 0.433743 [32064/60000]\n",
"loss: 0.433103 [38464/60000]\n",
"loss: 0.614904 [44864/60000]\n",
"loss: 0.561844 [51264/60000]\n",
"loss: 0.389942 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.7%, Avg loss: 0.460157 \n",
"\n",
"Epoch 73\n",
"-------------------------------\n",
"loss: 0.289506 [ 64/60000]\n",
"loss: 0.454390 [ 6464/60000]\n",
"loss: 0.277180 [12864/60000]\n",
"loss: 0.491688 [19264/60000]\n",
"loss: 0.409602 [25664/60000]\n",
"loss: 0.432444 [32064/60000]\n",
"loss: 0.431869 [38464/60000]\n",
"loss: 0.613108 [44864/60000]\n",
"loss: 0.560340 [51264/60000]\n",
"loss: 0.388964 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.7%, Avg loss: 0.459169 \n",
"\n",
"Epoch 74\n",
"-------------------------------\n",
"loss: 0.287768 [ 64/60000]\n",
"loss: 0.452984 [ 6464/60000]\n",
"loss: 0.276152 [12864/60000]\n",
"loss: 0.490388 [19264/60000]\n",
"loss: 0.407663 [25664/60000]\n",
"loss: 0.431170 [32064/60000]\n",
"loss: 0.430620 [38464/60000]\n",
"loss: 0.611292 [44864/60000]\n",
"loss: 0.558917 [51264/60000]\n",
"loss: 0.387971 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.7%, Avg loss: 0.458198 \n",
"\n",
"Epoch 75\n",
"-------------------------------\n",
"loss: 0.286065 [ 64/60000]\n",
"loss: 0.451654 [ 6464/60000]\n",
"loss: 0.275135 [12864/60000]\n",
"loss: 0.489104 [19264/60000]\n",
"loss: 0.405743 [25664/60000]\n",
"loss: 0.429890 [32064/60000]\n",
"loss: 0.429387 [38464/60000]\n",
"loss: 0.609502 [44864/60000]\n",
"loss: 0.557519 [51264/60000]\n",
"loss: 0.387040 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.7%, Avg loss: 0.457236 \n",
"\n",
"Epoch 76\n",
"-------------------------------\n",
"loss: 0.284433 [ 64/60000]\n",
"loss: 0.450397 [ 6464/60000]\n",
"loss: 0.274165 [12864/60000]\n",
"loss: 0.487882 [19264/60000]\n",
"loss: 0.403845 [25664/60000]\n",
"loss: 0.428636 [32064/60000]\n",
"loss: 0.428182 [38464/60000]\n",
"loss: 0.607730 [44864/60000]\n",
"loss: 0.556097 [51264/60000]\n",
"loss: 0.386184 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.8%, Avg loss: 0.456292 \n",
"\n",
"Epoch 77\n",
"-------------------------------\n",
"loss: 0.282841 [ 64/60000]\n",
"loss: 0.449110 [ 6464/60000]\n",
"loss: 0.273197 [12864/60000]\n",
"loss: 0.486648 [19264/60000]\n",
"loss: 0.401883 [25664/60000]\n",
"loss: 0.427384 [32064/60000]\n",
"loss: 0.426992 [38464/60000]\n",
"loss: 0.605967 [44864/60000]\n",
"loss: 0.554665 [51264/60000]\n",
"loss: 0.385385 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.8%, Avg loss: 0.455365 \n",
"\n",
"Epoch 78\n",
"-------------------------------\n",
"loss: 0.281263 [ 64/60000]\n",
"loss: 0.447836 [ 6464/60000]\n",
"loss: 0.272210 [12864/60000]\n",
"loss: 0.485453 [19264/60000]\n",
"loss: 0.399975 [25664/60000]\n",
"loss: 0.426161 [32064/60000]\n",
"loss: 0.425830 [38464/60000]\n",
"loss: 0.604246 [44864/60000]\n",
"loss: 0.553228 [51264/60000]\n",
"loss: 0.384594 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.8%, Avg loss: 0.454451 \n",
"\n",
"Epoch 79\n",
"-------------------------------\n",
"loss: 0.279766 [ 64/60000]\n",
"loss: 0.446586 [ 6464/60000]\n",
"loss: 0.271286 [12864/60000]\n",
"loss: 0.484287 [19264/60000]\n",
"loss: 0.398106 [25664/60000]\n",
"loss: 0.425061 [32064/60000]\n",
"loss: 0.424674 [38464/60000]\n",
"loss: 0.602439 [44864/60000]\n",
"loss: 0.551811 [51264/60000]\n",
"loss: 0.383852 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.8%, Avg loss: 0.453564 \n",
"\n",
"Epoch 80\n",
"-------------------------------\n",
"loss: 0.278299 [ 64/60000]\n",
"loss: 0.445358 [ 6464/60000]\n",
"loss: 0.270351 [12864/60000]\n",
"loss: 0.483141 [19264/60000]\n",
"loss: 0.396297 [25664/60000]\n",
"loss: 0.423920 [32064/60000]\n",
"loss: 0.423496 [38464/60000]\n",
"loss: 0.600713 [44864/60000]\n",
"loss: 0.550473 [51264/60000]\n",
"loss: 0.383113 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.8%, Avg loss: 0.452678 \n",
"\n",
"Epoch 81\n",
"-------------------------------\n",
"loss: 0.276881 [ 64/60000]\n",
"loss: 0.444190 [ 6464/60000]\n",
"loss: 0.269415 [12864/60000]\n",
"loss: 0.482046 [19264/60000]\n",
"loss: 0.394519 [25664/60000]\n",
"loss: 0.422776 [32064/60000]\n",
"loss: 0.422323 [38464/60000]\n",
"loss: 0.599012 [44864/60000]\n",
"loss: 0.549099 [51264/60000]\n",
"loss: 0.382355 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.8%, Avg loss: 0.451804 \n",
"\n",
"Epoch 82\n",
"-------------------------------\n",
"loss: 0.275486 [ 64/60000]\n",
"loss: 0.442999 [ 6464/60000]\n",
"loss: 0.268485 [12864/60000]\n",
"loss: 0.480958 [19264/60000]\n",
"loss: 0.392827 [25664/60000]\n",
"loss: 0.421650 [32064/60000]\n",
"loss: 0.421187 [38464/60000]\n",
"loss: 0.597310 [44864/60000]\n",
"loss: 0.547844 [51264/60000]\n",
"loss: 0.381603 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.9%, Avg loss: 0.450946 \n",
"\n",
"Epoch 83\n",
"-------------------------------\n",
"loss: 0.274145 [ 64/60000]\n",
"loss: 0.441848 [ 6464/60000]\n",
"loss: 0.267610 [12864/60000]\n",
"loss: 0.479846 [19264/60000]\n",
"loss: 0.391081 [25664/60000]\n",
"loss: 0.420557 [32064/60000]\n",
"loss: 0.420084 [38464/60000]\n",
"loss: 0.595690 [44864/60000]\n",
"loss: 0.546578 [51264/60000]\n",
"loss: 0.380925 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.9%, Avg loss: 0.450101 \n",
"\n",
"Epoch 84\n",
"-------------------------------\n",
"loss: 0.272871 [ 64/60000]\n",
"loss: 0.440700 [ 6464/60000]\n",
"loss: 0.266744 [12864/60000]\n",
"loss: 0.478726 [19264/60000]\n",
"loss: 0.389383 [25664/60000]\n",
"loss: 0.419525 [32064/60000]\n",
"loss: 0.419045 [38464/60000]\n",
"loss: 0.594137 [44864/60000]\n",
"loss: 0.545326 [51264/60000]\n",
"loss: 0.380205 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 83.9%, Avg loss: 0.449273 \n",
"\n",
"Epoch 85\n",
"-------------------------------\n",
"loss: 0.271628 [ 64/60000]\n",
"loss: 0.439552 [ 6464/60000]\n",
"loss: 0.265876 [12864/60000]\n",
"loss: 0.477599 [19264/60000]\n",
"loss: 0.387731 [25664/60000]\n",
"loss: 0.418448 [32064/60000]\n",
"loss: 0.417979 [38464/60000]\n",
"loss: 0.592620 [44864/60000]\n",
"loss: 0.544029 [51264/60000]\n",
"loss: 0.379485 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.0%, Avg loss: 0.448453 \n",
"\n",
"Epoch 86\n",
"-------------------------------\n",
"loss: 0.270426 [ 64/60000]\n",
"loss: 0.438424 [ 6464/60000]\n",
"loss: 0.265030 [12864/60000]\n",
"loss: 0.476489 [19264/60000]\n",
"loss: 0.386051 [25664/60000]\n",
"loss: 0.417407 [32064/60000]\n",
"loss: 0.416970 [38464/60000]\n",
"loss: 0.591168 [44864/60000]\n",
"loss: 0.542710 [51264/60000]\n",
"loss: 0.378828 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.0%, Avg loss: 0.447647 \n",
"\n",
"Epoch 87\n",
"-------------------------------\n",
"loss: 0.269208 [ 64/60000]\n",
"loss: 0.437199 [ 6464/60000]\n",
"loss: 0.264230 [12864/60000]\n",
"loss: 0.475392 [19264/60000]\n",
"loss: 0.384515 [25664/60000]\n",
"loss: 0.416350 [32064/60000]\n",
"loss: 0.415911 [38464/60000]\n",
"loss: 0.589722 [44864/60000]\n",
"loss: 0.541381 [51264/60000]\n",
"loss: 0.378159 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.0%, Avg loss: 0.446851 \n",
"\n",
"Epoch 88\n",
"-------------------------------\n",
"loss: 0.268064 [ 64/60000]\n",
"loss: 0.436019 [ 6464/60000]\n",
"loss: 0.263430 [12864/60000]\n",
"loss: 0.474290 [19264/60000]\n",
"loss: 0.383011 [25664/60000]\n",
"loss: 0.415293 [32064/60000]\n",
"loss: 0.414872 [38464/60000]\n",
"loss: 0.588241 [44864/60000]\n",
"loss: 0.540016 [51264/60000]\n",
"loss: 0.377488 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.1%, Avg loss: 0.446069 \n",
"\n",
"Epoch 89\n",
"-------------------------------\n",
"loss: 0.266918 [ 64/60000]\n",
"loss: 0.434826 [ 6464/60000]\n",
"loss: 0.262644 [12864/60000]\n",
"loss: 0.473225 [19264/60000]\n",
"loss: 0.381503 [25664/60000]\n",
"loss: 0.414239 [32064/60000]\n",
"loss: 0.413861 [38464/60000]\n",
"loss: 0.586801 [44864/60000]\n",
"loss: 0.538736 [51264/60000]\n",
"loss: 0.376820 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.1%, Avg loss: 0.445294 \n",
"\n",
"Epoch 90\n",
"-------------------------------\n",
"loss: 0.265819 [ 64/60000]\n",
"loss: 0.433672 [ 6464/60000]\n",
"loss: 0.261907 [12864/60000]\n",
"loss: 0.472169 [19264/60000]\n",
"loss: 0.380001 [25664/60000]\n",
"loss: 0.413226 [32064/60000]\n",
"loss: 0.412902 [38464/60000]\n",
"loss: 0.585358 [44864/60000]\n",
"loss: 0.537522 [51264/60000]\n",
"loss: 0.376200 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.1%, Avg loss: 0.444530 \n",
"\n",
"Epoch 91\n",
"-------------------------------\n",
"loss: 0.264763 [ 64/60000]\n",
"loss: 0.432531 [ 6464/60000]\n",
"loss: 0.261169 [12864/60000]\n",
"loss: 0.471114 [19264/60000]\n",
"loss: 0.378527 [25664/60000]\n",
"loss: 0.412258 [32064/60000]\n",
"loss: 0.411938 [38464/60000]\n",
"loss: 0.583972 [44864/60000]\n",
"loss: 0.536335 [51264/60000]\n",
"loss: 0.375551 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.1%, Avg loss: 0.443776 \n",
"\n",
"Epoch 92\n",
"-------------------------------\n",
"loss: 0.263738 [ 64/60000]\n",
"loss: 0.431407 [ 6464/60000]\n",
"loss: 0.260475 [12864/60000]\n",
"loss: 0.470076 [19264/60000]\n",
"loss: 0.377092 [25664/60000]\n",
"loss: 0.411267 [32064/60000]\n",
"loss: 0.411031 [38464/60000]\n",
"loss: 0.582597 [44864/60000]\n",
"loss: 0.535108 [51264/60000]\n",
"loss: 0.374957 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.1%, Avg loss: 0.443027 \n",
"\n",
"Epoch 93\n",
"-------------------------------\n",
"loss: 0.262753 [ 64/60000]\n",
"loss: 0.430249 [ 6464/60000]\n",
"loss: 0.259790 [12864/60000]\n",
"loss: 0.469021 [19264/60000]\n",
"loss: 0.375670 [25664/60000]\n",
"loss: 0.410296 [32064/60000]\n",
"loss: 0.410095 [38464/60000]\n",
"loss: 0.581194 [44864/60000]\n",
"loss: 0.533844 [51264/60000]\n",
"loss: 0.374435 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.1%, Avg loss: 0.442285 \n",
"\n",
"Epoch 94\n",
"-------------------------------\n",
"loss: 0.261779 [ 64/60000]\n",
"loss: 0.429117 [ 6464/60000]\n",
"loss: 0.259063 [12864/60000]\n",
"loss: 0.467997 [19264/60000]\n",
"loss: 0.374277 [25664/60000]\n",
"loss: 0.409352 [32064/60000]\n",
"loss: 0.409260 [38464/60000]\n",
"loss: 0.579793 [44864/60000]\n",
"loss: 0.532563 [51264/60000]\n",
"loss: 0.373894 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.2%, Avg loss: 0.441553 \n",
"\n",
"Epoch 95\n",
"-------------------------------\n",
"loss: 0.260850 [ 64/60000]\n",
"loss: 0.427904 [ 6464/60000]\n",
"loss: 0.258308 [12864/60000]\n",
"loss: 0.466945 [19264/60000]\n",
"loss: 0.372853 [25664/60000]\n",
"loss: 0.408465 [32064/60000]\n",
"loss: 0.408386 [38464/60000]\n",
"loss: 0.578515 [44864/60000]\n",
"loss: 0.531247 [51264/60000]\n",
"loss: 0.373333 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.2%, Avg loss: 0.440825 \n",
"\n",
"Epoch 96\n",
"-------------------------------\n",
"loss: 0.259948 [ 64/60000]\n",
"loss: 0.426755 [ 6464/60000]\n",
"loss: 0.257635 [12864/60000]\n",
"loss: 0.465871 [19264/60000]\n",
"loss: 0.371553 [25664/60000]\n",
"loss: 0.407616 [32064/60000]\n",
"loss: 0.407448 [38464/60000]\n",
"loss: 0.577212 [44864/60000]\n",
"loss: 0.530006 [51264/60000]\n",
"loss: 0.372752 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.2%, Avg loss: 0.440108 \n",
"\n",
"Epoch 97\n",
"-------------------------------\n",
"loss: 0.259065 [ 64/60000]\n",
"loss: 0.425620 [ 6464/60000]\n",
"loss: 0.256952 [12864/60000]\n",
"loss: 0.464839 [19264/60000]\n",
"loss: 0.370332 [25664/60000]\n",
"loss: 0.406737 [32064/60000]\n",
"loss: 0.406582 [38464/60000]\n",
"loss: 0.575959 [44864/60000]\n",
"loss: 0.528807 [51264/60000]\n",
"loss: 0.372216 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.2%, Avg loss: 0.439403 \n",
"\n",
"Epoch 98\n",
"-------------------------------\n",
"loss: 0.258211 [ 64/60000]\n",
"loss: 0.424507 [ 6464/60000]\n",
"loss: 0.256278 [12864/60000]\n",
"loss: 0.463804 [19264/60000]\n",
"loss: 0.369029 [25664/60000]\n",
"loss: 0.405877 [32064/60000]\n",
"loss: 0.405647 [38464/60000]\n",
"loss: 0.574729 [44864/60000]\n",
"loss: 0.527638 [51264/60000]\n",
"loss: 0.371645 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.3%, Avg loss: 0.438701 \n",
"\n",
"Epoch 99\n",
"-------------------------------\n",
"loss: 0.257358 [ 64/60000]\n",
"loss: 0.423397 [ 6464/60000]\n",
"loss: 0.255560 [12864/60000]\n",
"loss: 0.462754 [19264/60000]\n",
"loss: 0.367813 [25664/60000]\n",
"loss: 0.404977 [32064/60000]\n",
"loss: 0.404711 [38464/60000]\n",
"loss: 0.573509 [44864/60000]\n",
"loss: 0.526371 [51264/60000]\n",
"loss: 0.371045 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.3%, Avg loss: 0.437999 \n",
"\n",
"Epoch 100\n",
"-------------------------------\n",
"loss: 0.256548 [ 64/60000]\n",
"loss: 0.422351 [ 6464/60000]\n",
"loss: 0.254838 [12864/60000]\n",
"loss: 0.461698 [19264/60000]\n",
"loss: 0.366585 [25664/60000]\n",
"loss: 0.404191 [32064/60000]\n",
"loss: 0.403847 [38464/60000]\n",
"loss: 0.572284 [44864/60000]\n",
"loss: 0.524943 [51264/60000]\n",
"loss: 0.370223 [57664/60000]\n",
"Test Error: \n",
" Accuracy: 84.3%, Avg loss: 0.437296 \n",
"\n",
"Done!\n"
]
}
],
"source": [
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
"\n",
"epochs = 100\n",
"for t in range(epochs):\n",
" print(f\"Epoch {t+1}\\n-------------------------------\")\n",
" train_loop(train_dataloader, model, loss_fn, optimizer)\n",
" test_loop(test_dataloader, model, loss_fn)\n",
"print(\"Done!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "duCB2pWscw5-"
},
"source": [
"Further Reading\n",
"===============\n",
"\n",
"- [Loss\n",
" Functions](https://pytorch.org/docs/stable/nn.html#loss-functions)\n",
"- [torch.optim](https://pytorch.org/docs/stable/optim.html)\n",
"- [Warmstart Training a\n",
" Model](https://pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
},
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}