Created
February 13, 2021 05:05
-
-
Save z-a-f/8f285e45052a75c4920d2ee47b698fd4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "extraordinary-korean", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.7.1\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from torch import nn\n", | |
"\n", | |
"print(torch.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "developing-remains", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"data_path = os.path.expanduser('~/data')\n", | |
"\n", | |
"from torchvision.datasets import FashionMNIST\n", | |
"from torchvision import transforms\n", | |
"\n", | |
"transform = transforms.ToTensor()\n", | |
"\n", | |
"train_set = FashionMNIST(data_path, train=True, download=True, transform=transform)\n", | |
"test_set = FashionMNIST(data_path, train=False, download=True, transform=transform)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "running-pontiac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils.data import DataLoader\n", | |
"\n", | |
"batch_size = 2048\n", | |
"\n", | |
"train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)\n", | |
"test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "supported-dakota", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = nn.Sequential(\n", | |
" nn.Flatten(),\n", | |
" nn.Linear(28 * 28, 128),\n", | |
" nn.ReLU(),\n", | |
" nn.Linear(128, 10)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "affiliated-beads", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch import optim\n", | |
"\n", | |
"optimizer = optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1e-07)\n", | |
"criterion = nn.CrossEntropyLoss()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "manufactured-bolivia", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"01 / 10 .............................. - loss: 1.2759 - accuracy: 63.2250%\n", | |
"02 / 10 .............................. - loss: 0.6633 - accuracy: 77.4350%\n", | |
"03 / 10 .............................. - loss: 0.5554 - accuracy: 81.3700%\n", | |
"04 / 10 .............................. - loss: 0.5054 - accuracy: 82.9217%\n", | |
"05 / 10 .............................. - loss: 0.4770 - accuracy: 83.6900%\n", | |
"06 / 10 .............................. - loss: 0.4547 - accuracy: 84.4750%\n", | |
"07 / 10 .............................. - loss: 0.4358 - accuracy: 85.1000%\n", | |
"08 / 10 .............................. - loss: 0.4213 - accuracy: 85.5683%\n", | |
"09 / 10 .............................. - loss: 0.4105 - accuracy: 85.8400%\n", | |
"10 / 10 .............................. - loss: 0.3982 - accuracy: 86.2833%\n" | |
] | |
} | |
], | |
"source": [ | |
"model.train()\n", | |
"\n", | |
"epochs = 10\n", | |
"\n", | |
"for epoch in range(epochs):\n", | |
" print(f'{epoch+1:0>2} / {epochs}', end=' ')\n", | |
" running_loss = 0.0\n", | |
" running_accuracy = 0.0\n", | |
" for imgs, tgts in train_loader:\n", | |
" tgts_logits = model(imgs)\n", | |
" loss = criterion(tgts_logits, tgts)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" optimizer.zero_grad()\n", | |
" \n", | |
" predictions = tgts_logits.argmax(axis=1)\n", | |
" accuracy = (predictions == tgts).float().mean()\n", | |
" \n", | |
" running_loss += loss.item() * len(imgs)\n", | |
" running_accuracy += accuracy.item() * len(imgs)\n", | |
" \n", | |
" print('.', end='')\n", | |
"\n", | |
" running_loss /= len(train_set)\n", | |
" running_accuracy /= len(train_set)\n", | |
" print(f' - loss: {running_loss:.4f} - accuracy: {running_accuracy:.4%}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "reliable-tattoo", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"..... - loss: 0.4345 - accuracy: 84.8000%\n" | |
] | |
} | |
], | |
"source": [ | |
"model.eval()\n", | |
"\n", | |
"running_loss = 0.0\n", | |
"running_accuracy = 0.0\n", | |
"for imgs, tgts in test_loader:\n", | |
" tgts_logits = model(imgs)\n", | |
" loss = criterion(tgts_logits, tgts)\n", | |
"\n", | |
" predictions = tgts_logits.argmax(axis=1)\n", | |
" accuracy = (predictions == tgts).float().mean()\n", | |
"\n", | |
" running_loss += loss.item() * len(imgs)\n", | |
" running_accuracy += accuracy.item() * len(imgs)\n", | |
"\n", | |
" print('.', end='')\n", | |
"\n", | |
"running_loss /= len(test_set)\n", | |
"running_accuracy /= len(test_set)\n", | |
"print(f' - loss: {running_loss:.4f} - accuracy: {running_accuracy:.4%}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "compatible-jesus", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Tensorflow\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"model_tf = tf.keras.Sequential([\n", | |
" tf.keras.layers.Flatten(input_shape=(28, 28)),\n", | |
" tf.keras.layers.Dense(128, activation='relu'),\n", | |
" tf.keras.layers.Dense(10)\n", | |
"])\n", | |
"\n", | |
"model_tf.compile(optimizer='adam',\n", | |
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", | |
" metrics=['accuracy'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "coupled-mapping", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.4929 - accuracy: 0.8268\n", | |
"Epoch 2/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3727 - accuracy: 0.8655\n", | |
"Epoch 3/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3349 - accuracy: 0.8787\n", | |
"Epoch 4/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3106 - accuracy: 0.8857\n", | |
"Epoch 5/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2916 - accuracy: 0.8922\n", | |
"Epoch 6/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2772 - accuracy: 0.8976\n", | |
"Epoch 7/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2659 - accuracy: 0.9018\n", | |
"Epoch 8/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2557 - accuracy: 0.9046\n", | |
"Epoch 9/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2464 - accuracy: 0.9087\n", | |
"Epoch 10/10\n", | |
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2375 - accuracy: 0.9117\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<tensorflow.python.keras.callbacks.History at 0x7fa8482441d0>" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"imgs = train_set.data.numpy() / 255.0\n", | |
"lbls = train_set.targets.numpy()\n", | |
"\n", | |
"model_tf.fit(imgs, lbls, epochs=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "compliant-thinking", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"313/313 - 0s - loss: 0.3318 - accuracy: 0.8857\n", | |
"\n", | |
"Test accuracy: 0.885699987411499\n" | |
] | |
} | |
], | |
"source": [ | |
"imgs = test_set.data.numpy() / 255.0\n", | |
"lbls = test_set.targets.numpy()\n", | |
"\n", | |
"test_loss, test_acc = model_tf.evaluate(imgs, lbls, verbose=2)\n", | |
"print('\\nTest accuracy:', test_acc)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "subject-custom", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment