Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Curt-Park/2e7b9f2cea5aa259055e4696ee194717 to your computer and use it in GitHub Desktop.
Save Curt-Park/2e7b9f2cea5aa259055e4696ee194717 to your computer and use it in GitHub Desktop.
mnist_classification_with_two_layer_network.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNHvWTcRDiBYriS5KnRJVLK",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Curt-Park/2e7b9f2cea5aa259055e4696ee194717/mnist_classification_with_two_layer_network.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
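"# MNIST Classification with a Two-Layer Network\n",
"\n",
"Trains a NumPy-only two-layer fully connected network (affine - ReLU - affine - softmax) on MNIST digits with minibatch SGD and L2 regularization."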
],
"metadata": {
"id": "oy1HDOWNfeou"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Md3gftPufV7J",
"outputId": "ada234ed-fb4d-43fd-f619-81b94fca5b9c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: mnist in /usr/local/lib/python3.10/dist-packages (0.2.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from mnist) (1.25.2)\n"
]
}
],
"source": [
"%pip install mnist==0.2.2"
]
},
{
"cell_type": "code",
"source": [
"import mnist\n",
"\n",
"\n",
"# Flatten every 28x28 image into a 784-dimensional row vector.\n",
"all_train_images = mnist.train_images().reshape(-1, 28 * 28)\n",
"all_train_labels = mnist.train_labels()\n",
"\n",
"# Hold out the first 10,000 training samples as the validation split.\n",
"validation_images, train_images = all_train_images[:10000], all_train_images[10000:]\n",
"validation_labels, train_labels = all_train_labels[:10000], all_train_labels[10000:]\n",
"\n",
"test_images = mnist.test_images().reshape(-1, 28 * 28)\n",
"test_labels = mnist.test_labels()"
],
"metadata": {
"id": "FAee4ThqfsqF"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for split_name, split_images in (\n",
"    (\"train\", train_images),\n",
"    (\"validation\", validation_images),\n",
"    (\"test\", test_images),\n",
"):\n",
"    print(f\"{split_name} image shape:\", split_images.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "awdv-oNMgAWr",
"outputId": "537700c0-35af-4370-f31b-c4252e0d92cf"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"train image shape: (50000, 784)\n",
"validation image shape: (10000, 784)\n",
"test image shape: (10000, 784)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np  # also required by the model cells below\n",
"\n",
"\n",
"NUM_SAMPLES = 10\n",
"NUM_COLS = 5\n",
"\n",
"x_test = test_images[:NUM_SAMPLES]\n",
"y_test = test_labels[:NUM_SAMPLES]\n",
"\n",
"# Reshape each flattened 784-pixel row back into a 28x28 image for display.\n",
"# No rotation is applied; the original variable name is kept in case later cells use it.\n",
"rotated_images = x_test.reshape(-1, 28, 28)\n",
"labels_mapping = [str(i) for i in range(10)]\n",
"\n",
"# Size the grid to the number of samples instead of a fixed 5x5 grid on a 30x30-inch canvas.\n",
"num_rows = -(-len(x_test) // NUM_COLS)  # ceiling division\n",
"fig, axes = plt.subplots(\n",
"    num_rows, NUM_COLS, figsize=(3 * NUM_COLS, 3 * num_rows), squeeze=False\n",
")\n",
"for ax in axes.flat:\n",
"    ax.axis(\"off\")\n",
"for ax, image, label in zip(axes.flat, rotated_images, y_test):\n",
"    ax.imshow(image, cmap=\"gray\")\n",
"    ax.set_title(f\"Label: {labels_mapping[label]}\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 566
},
"id": "UX1g2asjk3-p",
"outputId": "c6f46ac2-a885-4a73-f43e-38b5a38c8891"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 3000x3000 with 10 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"from typing import Optional, Union\n",
"\n",
"\n",
"class TwoLayerNet:\n",
"    \"\"\"A two-layer fully-connected neural network.\n",
"\n",
"    The net has an input dimension of D, a hidden layer dimension of H,\n",
"    and performs classification over C classes.\n",
"    We train the network with a softmax loss function and L2 regularization on the\n",
"    weight matrices. The network uses a ReLU nonlinearity after the first fully\n",
"    connected layer.\n",
"\n",
"    In other words, the network has the following architecture:\n",
"\n",
"    input - fully connected layer - ReLU - fully connected layer - softmax\n",
"\n",
"    The outputs of the second fully-connected layer are the scores for each class.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self, input_size: int, hidden_size: int, output_size: int, std: float = 1e-4\n",
"    ) -> None:\n",
"        \"\"\"Initialize the model. Weights are initialized to small random values and\n",
"        biases are initialized to zero. Weights and biases are stored in the\n",
"        variable self.params, which is a dictionary with the following keys:\n",
"\n",
"        W1: First layer weights; has shape (D, H)\n",
"        b1: First layer biases; has shape (H,)\n",
"        W2: Second layer weights; has shape (H, C)\n",
"        b2: Second layer biases; has shape (C,)\n",
"\n",
"        Inputs:\n",
"        - input_size: The dimension D of the input data.\n",
"        - hidden_size: The number of neurons H in the hidden layer.\n",
"        - output_size: The number of classes C.\n",
"        - std: Standard deviation of the zero-mean Gaussian used for the weights.\n",
"        \"\"\"\n",
"        self.params = {}\n",
"        self.params[\"W1\"] = std * np.random.randn(input_size, hidden_size)\n",
"        self.params[\"b1\"] = np.zeros(hidden_size)\n",
"        self.params[\"W2\"] = std * np.random.randn(hidden_size, output_size)\n",
"        self.params[\"b2\"] = np.zeros(output_size)\n",
"\n",
"    def loss(\n",
"        self, X: np.ndarray, y: Optional[np.ndarray] = None, reg: float = 0.\n",
"    ) -> Union[np.ndarray, tuple[float, dict[str, np.ndarray]]]:\n",
"        \"\"\"Compute the loss and gradients for a two layer fully connected neural network.\n",
"\n",
"        Inputs:\n",
"        - X: Input data of shape (N, D). Each X[i] is a training sample.\n",
"        - y: Vector of training labels. y[i] is the label for X[i], and each y[i] is\n",
"          an integer in the range 0 <= y[i] < C. This parameter is optional; if it\n",
"          is not passed then we only return scores, and if it is passed then we\n",
"          instead return the loss and gradients.\n",
"        - reg: Regularization strength.\n",
"\n",
"        Returns:\n",
"        If y is None, return a matrix scores of shape (N, C) where scores[i, c] is\n",
"        the score for class c on input X[i].\n",
"\n",
"        If y is not None, instead return a tuple of:\n",
"        - loss: Mean softmax cross-entropy over the batch plus the regularization\n",
"          term 0.5 * reg * (||W1||^2 + ||W2||^2).\n",
"        - grads: Dictionary mapping parameter names to gradients of those parameters\n",
"          with respect to the loss function; has the same keys as self.params.\n",
"        \"\"\"\n",
"        # Unpack variables from the params dictionary\n",
"        W1, b1 = self.params[\"W1\"], self.params[\"b1\"]\n",
"        W2, b2 = self.params[\"W2\"], self.params[\"b2\"]\n",
"        N = X.shape[0]\n",
"\n",
"        # Forward pass\n",
"        l1 = np.maximum(0, X.dot(W1) + b1)  # ReLU. l1: (N, H)\n",
"        scores = l1.dot(W2) + b2  # (N, C)\n",
"\n",
"        # If the targets are not given then jump out, we're done\n",
"        if y is None:\n",
"            return scores\n",
"\n",
"        # Softmax loss in log space: shifting by the row max keeps exp() from\n",
"        # overflowing, and subtracting the log-normalizer (rather than taking the\n",
"        # log of a probability that may have underflowed to 0) keeps the loss finite.\n",
"        shifted = scores - np.max(scores, axis=1, keepdims=True)\n",
"        log_p = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
"        loss = -log_p[np.arange(N), y].mean()\n",
"        loss += 0.5 * reg * (np.sum(W1 * W1) + np.sum(W2 * W2))\n",
"\n",
"        # Backward pass: compute gradients\n",
"        grads: dict[str, np.ndarray] = {}\n",
"        # d(mean softmax loss)/d(scores) = (softmax probabilities - one_hot(y)) / N\n",
"        dscores = np.exp(log_p)\n",
"        dscores[np.arange(N), y] -= 1\n",
"        dscores /= N\n",
"        grads[\"W2\"] = l1.T.dot(dscores)\n",
"        grads[\"b2\"] = dscores.sum(axis=0)\n",
"\n",
"        # Backpropagate into the hidden layer\n",
"        dl1 = dscores.dot(W2.T)\n",
"        # ReLU passes gradient only where its output was positive\n",
"        dl1[l1 == 0] = 0\n",
"\n",
"        grads[\"W1\"] = X.T.dot(dl1)\n",
"        grads[\"b1\"] = dl1.sum(axis=0)\n",
"\n",
"        # Gradient of the 0.5 * reg * ||W||^2 regularization term\n",
"        grads[\"W2\"] += reg * W2\n",
"        grads[\"W1\"] += reg * W1\n",
"\n",
"        return loss, grads\n",
"\n",
"    def train(\n",
"        self,\n",
"        X: np.ndarray,\n",
"        y: np.ndarray,\n",
"        X_val: np.ndarray,\n",
"        y_val: np.ndarray,\n",
"        learning_rate: float = 1e-3,\n",
"        learning_rate_decay: float = 0.95,\n",
"        reg: float = 5e-6,\n",
"        num_iters: int = 100,\n",
"        batch_size: int = 200,\n",
"        verbose: bool = False\n",
"    ) -> dict[str, list[float]]:\n",
"        \"\"\"Train this neural network using minibatch stochastic gradient descent.\n",
"\n",
"        Inputs:\n",
"        - X: A numpy array of shape (N, D) giving training data.\n",
"        - y: A numpy array of shape (N,) giving training labels; y[i] = c means that\n",
"          X[i] has label c, where 0 <= c < C.\n",
"        - X_val: A numpy array of shape (N_val, D) giving validation data.\n",
"        - y_val: A numpy array of shape (N_val,) giving validation labels.\n",
"        - learning_rate: Scalar giving learning rate for optimization.\n",
"        - learning_rate_decay: Scalar giving factor used to decay the learning rate\n",
"          once per epoch (the first decay happens at iteration 0).\n",
"        - reg: Scalar giving regularization strength.\n",
"        - num_iters: Number of steps to take when optimizing.\n",
"        - batch_size: Number of training examples to use per step; samples are\n",
"          drawn uniformly with replacement.\n",
"        - verbose: boolean; if true print the loss every 100 iterations and at the\n",
"          final iteration.\n",
"\n",
"        Returns:\n",
"        A dictionary with:\n",
"        - loss_history: Training loss at every iteration.\n",
"        - train_acc_history: Accuracy on the current minibatch, once per epoch.\n",
"        - val_acc_history: Accuracy on (X_val, y_val), once per epoch.\n",
"        \"\"\"\n",
"        num_train = X.shape[0]\n",
"        # Must be an integer: a float such as 50000 / 32 = 1562.5 would make\n",
"        # `it % iterations_per_epoch == 0` fire only every other epoch.\n",
"        iterations_per_epoch = max(num_train // batch_size, 1)\n",
"\n",
"        # Use SGD to optimize the parameters in self.params\n",
"        loss_history = []\n",
"        train_acc_history = []\n",
"        val_acc_history = []\n",
"\n",
"        for it in range(num_iters):\n",
"            sampling_indices = np.random.choice(num_train, batch_size)\n",
"            X_batch = X[sampling_indices]\n",
"            y_batch = y[sampling_indices]\n",
"\n",
"            # Compute loss and gradients using the current minibatch\n",
"            loss, grads = self.loss(X_batch, y=y_batch, reg=reg)\n",
"            loss_history.append(loss)\n",
"\n",
"            # Vanilla SGD parameter update\n",
"            for name in (\"W2\", \"b2\", \"W1\", \"b1\"):\n",
"                self.params[name] -= learning_rate * grads[name]\n",
"\n",
"            # Parenthesized so that nothing is printed when verbose is False\n",
"            if verbose and (it % 100 == 0 or it + 1 == num_iters):\n",
"                print(f\"iteration {it} / {num_iters}: loss {loss}\")\n",
"\n",
"            # Every epoch, check train and val accuracy and decay learning rate.\n",
"            if it % iterations_per_epoch == 0:\n",
"                # Train accuracy is measured on the current minibatch only\n",
"                train_acc = (self.predict(X_batch) == y_batch).mean()\n",
"                val_acc = (self.predict(X_val) == y_val).mean()\n",
"                train_acc_history.append(train_acc)\n",
"                val_acc_history.append(val_acc)\n",
"\n",
"                # Decay learning rate\n",
"                learning_rate *= learning_rate_decay\n",
"\n",
"        return {\n",
"            \"loss_history\": loss_history,\n",
"            \"train_acc_history\": train_acc_history,\n",
"            \"val_acc_history\": val_acc_history,\n",
"        }\n",
"\n",
"    def predict(self, X: np.ndarray) -> np.ndarray:\n",
"        \"\"\"Use the trained weights of this two-layer network to predict labels for\n",
"        data points. For each data point we predict scores for each of the C\n",
"        classes, and assign each data point to the class with the highest score.\n",
"\n",
"        Inputs:\n",
"        - X: A numpy array of shape (N, D) giving N D-dimensional data points to\n",
"          classify.\n",
"\n",
"        Returns:\n",
"        - y_pred: A numpy array of shape (N,) giving predicted labels for each of\n",
"          the elements of X. For all i, y_pred[i] = c means that X[i] is predicted\n",
"          to have class c, where 0 <= c < C.\n",
"        \"\"\"\n",
"        scores = self.loss(X)\n",
"        return np.argmax(scores, axis=1)"
],
"metadata": {
"id": "nBuuVZgTgGQW"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"SEED = 0\n",
"# Seed NumPy's global RNG so weight initialization and minibatch sampling are reproducible.\n",
"np.random.seed(SEED)\n",
"\n",
"net = TwoLayerNet(input_size=28*28, hidden_size=50, output_size=10, std=1e-1)\n",
"stats = net.train(\n",
"    train_images,\n",
"    train_labels,\n",
"    validation_images,\n",
"    validation_labels,\n",
"    num_iters=50000,\n",
"    batch_size=32,\n",
"    learning_rate=1e-4,\n",
"    learning_rate_decay=0.95,\n",
"    reg=0.25,\n",
"    verbose=True\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x0FUH6uCgU4q",
"outputId": "b1a987f3-c9ab-4b23-8b09-1247b9d8f42a"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"iteration 0 / 50000: loss 174.53135037310625\n",
"iteration 100 / 50000: loss 69.48747236188227\n",
"iteration 200 / 50000: loss 56.49037147912324\n",
"iteration 300 / 50000: loss 55.089730076216675\n",
"iteration 400 / 50000: loss 53.36511864384791\n",
"iteration 500 / 50000: loss 53.81747712002853\n",
"iteration 600 / 50000: loss 54.47308871264755\n",
"iteration 700 / 50000: loss 48.857464900515225\n",
"iteration 800 / 50000: loss 50.368133951829776\n",
"iteration 900 / 50000: loss 51.23669198011844\n",
"iteration 1000 / 50000: loss 51.14886770388517\n",
"iteration 1100 / 50000: loss 54.9678566815431\n",
"iteration 1200 / 50000: loss 48.18517976759856\n",
"iteration 1300 / 50000: loss 49.9726430580723\n",
"iteration 1400 / 50000: loss 49.18964790207137\n",
"iteration 1500 / 50000: loss 46.99336147176874\n",
"iteration 1600 / 50000: loss 45.891550518507785\n",
"iteration 1700 / 50000: loss 48.61488855707326\n",
"iteration 1800 / 50000: loss 46.10700031317256\n",
"iteration 1900 / 50000: loss 46.95774963008768\n",
"iteration 2000 / 50000: loss 46.428578383379445\n",
"iteration 2100 / 50000: loss 45.63466912068143\n",
"iteration 2200 / 50000: loss 46.53266795466356\n",
"iteration 2300 / 50000: loss 44.94858555355828\n",
"iteration 2400 / 50000: loss 45.60557814755142\n",
"iteration 2500 / 50000: loss 45.39935230009701\n",
"iteration 2600 / 50000: loss 44.67134675483038\n",
"iteration 2700 / 50000: loss 44.77288920410572\n",
"iteration 2800 / 50000: loss 43.69977332688362\n",
"iteration 2900 / 50000: loss 43.730446656985066\n",
"iteration 3000 / 50000: loss 43.560656333147215\n",
"iteration 3100 / 50000: loss 44.420026111542064\n",
"iteration 3200 / 50000: loss 43.27341088354223\n",
"iteration 3300 / 50000: loss 42.6917082043561\n",
"iteration 3400 / 50000: loss 42.99590946087199\n",
"iteration 3500 / 50000: loss 42.074454274905015\n",
"iteration 3600 / 50000: loss 42.34553108107577\n",
"iteration 3700 / 50000: loss 42.19451871876929\n",
"iteration 3800 / 50000: loss 42.30068082867391\n",
"iteration 3900 / 50000: loss 41.59060395331328\n",
"iteration 4000 / 50000: loss 41.837978585274996\n",
"iteration 4100 / 50000: loss 41.661003320080354\n",
"iteration 4200 / 50000: loss 41.56867064092525\n",
"iteration 4300 / 50000: loss 42.265296452640214\n",
"iteration 4400 / 50000: loss 40.74161597772827\n",
"iteration 4500 / 50000: loss 40.686929046416815\n",
"iteration 4600 / 50000: loss 41.38809501428429\n",
"iteration 4700 / 50000: loss 39.92590750434553\n",
"iteration 4800 / 50000: loss 39.88043198325949\n",
"iteration 4900 / 50000: loss 40.50956039318266\n",
"iteration 5000 / 50000: loss 39.84197097002049\n",
"iteration 5100 / 50000: loss 39.64462218143873\n",
"iteration 5200 / 50000: loss 39.41114249441797\n",
"iteration 5300 / 50000: loss 38.5653154055042\n",
"iteration 5400 / 50000: loss 38.69729347046285\n",
"iteration 5500 / 50000: loss 38.686656362672764\n",
"iteration 5600 / 50000: loss 38.49321014870525\n",
"iteration 5700 / 50000: loss 38.337155326339015\n",
"iteration 5800 / 50000: loss 38.9238241645377\n",
"iteration 5900 / 50000: loss 38.184190026416196\n",
"iteration 6000 / 50000: loss 38.17614867043852\n",
"iteration 6100 / 50000: loss 37.3392965852913\n",
"iteration 6200 / 50000: loss 37.11046340680676\n",
"iteration 6300 / 50000: loss 36.88536839567879\n",
"iteration 6400 / 50000: loss 37.036160886820696\n",
"iteration 6500 / 50000: loss 37.49150654153678\n",
"iteration 6600 / 50000: loss 36.93653620441786\n",
"iteration 6700 / 50000: loss 36.591272291210004\n",
"iteration 6800 / 50000: loss 36.450949192373\n",
"iteration 6900 / 50000: loss 36.81271931330199\n",
"iteration 7000 / 50000: loss 36.29706613913226\n",
"iteration 7100 / 50000: loss 36.09016551195607\n",
"iteration 7200 / 50000: loss 35.6288939025018\n",
"iteration 7300 / 50000: loss 35.84212875651818\n",
"iteration 7400 / 50000: loss 35.841453475823656\n",
"iteration 7500 / 50000: loss 35.18575208018029\n",
"iteration 7600 / 50000: loss 35.44639979859994\n",
"iteration 7700 / 50000: loss 34.59538986968303\n",
"iteration 7800 / 50000: loss 34.73558738083211\n",
"iteration 7900 / 50000: loss 34.684809641588565\n",
"iteration 8000 / 50000: loss 34.48921078680681\n",
"iteration 8100 / 50000: loss 35.072609939174036\n",
"iteration 8200 / 50000: loss 35.38801885322622\n",
"iteration 8300 / 50000: loss 34.31061782509506\n",
"iteration 8400 / 50000: loss 33.65215742062905\n",
"iteration 8500 / 50000: loss 33.72227128692343\n",
"iteration 8600 / 50000: loss 33.76233694603763\n",
"iteration 8700 / 50000: loss 33.64317693187531\n",
"iteration 8800 / 50000: loss 33.357455512988174\n",
"iteration 8900 / 50000: loss 33.221809735582895\n",
"iteration 9000 / 50000: loss 32.87529874204728\n",
"iteration 9100 / 50000: loss 32.88536365723364\n",
"iteration 9200 / 50000: loss 32.543740138042175\n",
"iteration 9300 / 50000: loss 32.56281179596056\n",
"iteration 9400 / 50000: loss 32.72395247376798\n",
"iteration 9500 / 50000: loss 32.61088082207595\n",
"iteration 9600 / 50000: loss 32.143058555372114\n",
"iteration 9700 / 50000: loss 32.80067113135746\n",
"iteration 9800 / 50000: loss 32.48957475185233\n",
"iteration 9900 / 50000: loss 31.78440727988436\n",
"iteration 10000 / 50000: loss 31.526016528653052\n",
"iteration 10100 / 50000: loss 31.78920176731781\n",
"iteration 10200 / 50000: loss 31.60032011274851\n",
"iteration 10300 / 50000: loss 31.133023480494856\n",
"iteration 10400 / 50000: loss 31.245463720300588\n",
"iteration 10500 / 50000: loss 31.079334811258377\n",
"iteration 10600 / 50000: loss 30.875925471138935\n",
"iteration 10700 / 50000: loss 30.441750583147844\n",
"iteration 10800 / 50000: loss 30.703381630755775\n",
"iteration 10900 / 50000: loss 30.57814908482637\n",
"iteration 11000 / 50000: loss 30.858829863699242\n",
"iteration 11100 / 50000: loss 30.291443623972512\n",
"iteration 11200 / 50000: loss 30.206963118117297\n",
"iteration 11300 / 50000: loss 30.27649842696651\n",
"iteration 11400 / 50000: loss 30.204706967800494\n",
"iteration 11500 / 50000: loss 29.9142981599229\n",
"iteration 11600 / 50000: loss 29.644663854494137\n",
"iteration 11700 / 50000: loss 29.680160708757132\n",
"iteration 11800 / 50000: loss 29.67622285676191\n",
"iteration 11900 / 50000: loss 29.33334949728313\n",
"iteration 12000 / 50000: loss 29.402814493866746\n",
"iteration 12100 / 50000: loss 29.208229129404774\n",
"iteration 12200 / 50000: loss 28.83510901797269\n",
"iteration 12300 / 50000: loss 28.84027846823807\n",
"iteration 12400 / 50000: loss 28.622786906209154\n",
"iteration 12500 / 50000: loss 28.575478569291846\n",
"iteration 12600 / 50000: loss 28.916354680943677\n",
"iteration 12700 / 50000: loss 28.24033424862754\n",
"iteration 12800 / 50000: loss 28.31824555972571\n",
"iteration 12900 / 50000: loss 28.847874705888767\n",
"iteration 13000 / 50000: loss 28.068666364730205\n",
"iteration 13100 / 50000: loss 27.944238954665025\n",
"iteration 13200 / 50000: loss 27.9484474395685\n",
"iteration 13300 / 50000: loss 27.663986617309074\n",
"iteration 13400 / 50000: loss 28.199930340778288\n",
"iteration 13500 / 50000: loss 27.53568960020903\n",
"iteration 13600 / 50000: loss 27.363529670383517\n",
"iteration 13700 / 50000: loss 27.415858374826534\n",
"iteration 13800 / 50000: loss 27.100141829910413\n",
"iteration 13900 / 50000: loss 27.17122090922195\n",
"iteration 14000 / 50000: loss 26.826437168795408\n",
"iteration 14100 / 50000: loss 27.229019671922106\n",
"iteration 14200 / 50000: loss 26.95957634242687\n",
"iteration 14300 / 50000: loss 26.5642566278667\n",
"iteration 14400 / 50000: loss 26.50204041127803\n",
"iteration 14500 / 50000: loss 26.43840406491342\n",
"iteration 14600 / 50000: loss 26.468097916826938\n",
"iteration 14700 / 50000: loss 26.30407585882863\n",
"iteration 14800 / 50000: loss 25.974956510190715\n",
"iteration 14900 / 50000: loss 26.37115220297027\n",
"iteration 15000 / 50000: loss 26.25095832548515\n",
"iteration 15100 / 50000: loss 25.940151536085168\n",
"iteration 15200 / 50000: loss 25.691713870249266\n",
"iteration 15300 / 50000: loss 25.442519242122653\n",
"iteration 15400 / 50000: loss 26.096506551853054\n",
"iteration 15500 / 50000: loss 25.37812424066362\n",
"iteration 15600 / 50000: loss 25.620272871604293\n",
"iteration 15700 / 50000: loss 25.377123575147184\n",
"iteration 15800 / 50000: loss 25.26732941541315\n",
"iteration 15900 / 50000: loss 25.299466727875412\n",
"iteration 16000 / 50000: loss 25.550157332625506\n",
"iteration 16100 / 50000: loss 25.050930908220298\n",
"iteration 16200 / 50000: loss 24.92453455726036\n",
"iteration 16300 / 50000: loss 24.559535345044697\n",
"iteration 16400 / 50000: loss 24.64625098019252\n",
"iteration 16500 / 50000: loss 24.50257314353376\n",
"iteration 16600 / 50000: loss 24.763459169587247\n",
"iteration 16700 / 50000: loss 24.42004009227719\n",
"iteration 16800 / 50000: loss 24.528082614956286\n",
"iteration 16900 / 50000: loss 24.333323735883404\n",
"iteration 17000 / 50000: loss 23.999346704756945\n",
"iteration 17100 / 50000: loss 24.19406909617855\n",
"iteration 17200 / 50000: loss 23.875362185362857\n",
"iteration 17300 / 50000: loss 23.912330338060517\n",
"iteration 17400 / 50000: loss 23.564423734427567\n",
"iteration 17500 / 50000: loss 24.378111624505866\n",
"iteration 17600 / 50000: loss 23.363505372846266\n",
"iteration 17700 / 50000: loss 23.619977744128594\n",
"iteration 17800 / 50000: loss 23.381916206626023\n",
"iteration 17900 / 50000: loss 23.341791613730418\n",
"iteration 18000 / 50000: loss 23.441877505571423\n",
"iteration 18100 / 50000: loss 22.93485761379101\n",
"iteration 18200 / 50000: loss 23.19392494522184\n",
"iteration 18300 / 50000: loss 23.49375124367761\n",
"iteration 18400 / 50000: loss 22.77725146402904\n",
"iteration 18500 / 50000: loss 22.888552238044646\n",
"iteration 18600 / 50000: loss 22.600348480911457\n",
"iteration 18700 / 50000: loss 22.723411469130752\n",
"iteration 18800 / 50000: loss 22.522981465449796\n",
"iteration 18900 / 50000: loss 22.76708625723702\n",
"iteration 19000 / 50000: loss 22.293490803015153\n",
"iteration 19100 / 50000: loss 22.373873778074675\n",
"iteration 19200 / 50000: loss 22.09919847409843\n",
"iteration 19300 / 50000: loss 22.021546743061922\n",
"iteration 19400 / 50000: loss 22.288039981968744\n",
"iteration 19500 / 50000: loss 21.912841482273592\n",
"iteration 19600 / 50000: loss 22.122216523254327\n",
"iteration 19700 / 50000: loss 21.79926823885294\n",
"iteration 19800 / 50000: loss 21.973400439939418\n",
"iteration 19900 / 50000: loss 21.693223067974696\n",
"iteration 20000 / 50000: loss 21.5357618299404\n",
"iteration 20100 / 50000: loss 21.733154764682837\n",
"iteration 20200 / 50000: loss 21.515381833778605\n",
"iteration 20300 / 50000: loss 21.288954062052184\n",
"iteration 20400 / 50000: loss 21.474196463816565\n",
"iteration 20500 / 50000: loss 21.396617104101\n",
"iteration 20600 / 50000: loss 21.225382441901214\n",
"iteration 20700 / 50000: loss 21.040838895380908\n",
"iteration 20800 / 50000: loss 20.98119561252095\n",
"iteration 20900 / 50000: loss 20.885136500129207\n",
"iteration 21000 / 50000: loss 21.29386793015197\n",
"iteration 21100 / 50000: loss 20.65337077392958\n",
"iteration 21200 / 50000: loss 20.700829941914943\n",
"iteration 21300 / 50000: loss 21.089748687005883\n",
"iteration 21400 / 50000: loss 20.48501583364378\n",
"iteration 21500 / 50000: loss 20.52728917744923\n",
"iteration 21600 / 50000: loss 20.43997813081168\n",
"iteration 21700 / 50000: loss 20.785447885988603\n",
"iteration 21800 / 50000: loss 20.175151700753293\n",
"iteration 21900 / 50000: loss 20.414612805176006\n",
"iteration 22000 / 50000: loss 20.721257273452814\n",
"iteration 22100 / 50000: loss 20.25864479198273\n",
"iteration 22200 / 50000: loss 20.236838210501766\n",
"iteration 22300 / 50000: loss 20.059261341080653\n",
"iteration 22400 / 50000: loss 20.06079076472855\n",
"iteration 22500 / 50000: loss 20.24745556718755\n",
"iteration 22600 / 50000: loss 19.861843953450208\n",
"iteration 22700 / 50000: loss 20.159227469424497\n",
"iteration 22800 / 50000: loss 19.7930628797671\n",
"iteration 22900 / 50000: loss 19.515848930969728\n",
"iteration 23000 / 50000: loss 19.55254930783799\n",
"iteration 23100 / 50000: loss 19.644214714939633\n",
"iteration 23200 / 50000: loss 19.3388394854959\n",
"iteration 23300 / 50000: loss 19.258019350816735\n",
"iteration 23400 / 50000: loss 20.030911560423025\n",
"iteration 23500 / 50000: loss 19.37825154872941\n",
"iteration 23600 / 50000: loss 19.35676823094238\n",
"iteration 23700 / 50000: loss 19.11747955734764\n",
"iteration 23800 / 50000: loss 19.155303748096692\n",
"iteration 23900 / 50000: loss 18.887526484324404\n",
"iteration 24000 / 50000: loss 18.80173517490366\n",
"iteration 24100 / 50000: loss 19.200334383064234\n",
"iteration 24200 / 50000: loss 18.64276469035253\n",
"iteration 24300 / 50000: loss 18.537646256312225\n",
"iteration 24400 / 50000: loss 18.7812033707665\n",
"iteration 24500 / 50000: loss 19.05305726168593\n",
"iteration 24600 / 50000: loss 18.510650187762455\n",
"iteration 24700 / 50000: loss 18.346333107487148\n",
"iteration 24800 / 50000: loss 18.189578058220917\n",
"iteration 24900 / 50000: loss 18.37889213998917\n",
"iteration 25000 / 50000: loss 18.694072270445176\n",
"iteration 25100 / 50000: loss 18.11415416484163\n",
"iteration 25200 / 50000: loss 18.116118265087437\n",
"iteration 25300 / 50000: loss 17.96523290765853\n",
"iteration 25400 / 50000: loss 18.238540921577496\n",
"iteration 25500 / 50000: loss 18.01414919399693\n",
"iteration 25600 / 50000: loss 18.081026719715062\n",
"iteration 25700 / 50000: loss 17.859681079908594\n",
"iteration 25800 / 50000: loss 18.16374696349705\n",
"iteration 25900 / 50000: loss 17.81782234277618\n",
"iteration 26000 / 50000: loss 17.838217287279264\n",
"iteration 26100 / 50000: loss 17.592752282509927\n",
"iteration 26200 / 50000: loss 17.54736623407586\n",
"iteration 26300 / 50000: loss 17.457736458720017\n",
"iteration 26400 / 50000: loss 17.641076583693543\n",
"iteration 26500 / 50000: loss 17.41586135263364\n",
"iteration 26600 / 50000: loss 17.805173335471395\n",
"iteration 26700 / 50000: loss 17.381677060894692\n",
"iteration 26800 / 50000: loss 17.502249245041902\n",
"iteration 26900 / 50000: loss 17.48867751796187\n",
"iteration 27000 / 50000: loss 17.21685228651793\n",
"iteration 27100 / 50000: loss 17.017113820573332\n",
"iteration 27200 / 50000: loss 16.979267837975094\n",
"iteration 27300 / 50000: loss 17.276139448714048\n",
"iteration 27400 / 50000: loss 16.930110879903324\n",
"iteration 27500 / 50000: loss 16.946028854770308\n",
"iteration 27600 / 50000: loss 16.76155758305447\n",
"iteration 27700 / 50000: loss 16.703881884897925\n",
"iteration 27800 / 50000: loss 16.573052796258153\n",
"iteration 27900 / 50000: loss 16.709117087560024\n",
"iteration 28000 / 50000: loss 16.56640596623367\n",
"iteration 28100 / 50000: loss 16.45867944179294\n",
"iteration 28200 / 50000: loss 16.778757344718322\n",
"iteration 28300 / 50000: loss 16.58289811991296\n",
"iteration 28400 / 50000: loss 16.33565660832728\n",
"iteration 28500 / 50000: loss 16.695583116295715\n",
"iteration 28600 / 50000: loss 16.154086727988382\n",
"iteration 28700 / 50000: loss 16.380319476931493\n",
"iteration 28800 / 50000: loss 16.275732942536013\n",
"iteration 28900 / 50000: loss 16.184171971832182\n",
"iteration 29000 / 50000: loss 16.375291494031543\n",
"iteration 29100 / 50000: loss 16.763224751049083\n",
"iteration 29200 / 50000: loss 16.021672561212622\n",
"iteration 29300 / 50000: loss 15.795271512826254\n",
"iteration 29400 / 50000: loss 15.87836620709595\n",
"iteration 29500 / 50000: loss 15.805290496503073\n",
"iteration 29600 / 50000: loss 16.152529689220675\n",
"iteration 29700 / 50000: loss 15.944506918981642\n",
"iteration 29800 / 50000: loss 15.893673063906528\n",
"iteration 29900 / 50000: loss 15.528772865453309\n",
"iteration 30000 / 50000: loss 16.756379930359024\n",
"iteration 30100 / 50000: loss 15.467145021070547\n",
"iteration 30200 / 50000: loss 15.55196550500803\n",
"iteration 30300 / 50000: loss 15.481967712893997\n",
"iteration 30400 / 50000: loss 15.276883805723196\n",
"iteration 30500 / 50000: loss 15.3796924120657\n",
"iteration 30600 / 50000: loss 15.279912361697642\n",
"iteration 30700 / 50000: loss 15.648586935529044\n",
"iteration 30800 / 50000: loss 15.444319767909679\n",
"iteration 30900 / 50000: loss 15.201917758104246\n",
"iteration 31000 / 50000: loss 15.06743264942836\n",
"iteration 31100 / 50000: loss 15.160527563445566\n",
"iteration 31200 / 50000: loss 15.161275398258155\n",
"iteration 31300 / 50000: loss 15.13289950807995\n",
"iteration 31400 / 50000: loss 15.06853577758919\n",
"iteration 31500 / 50000: loss 15.260713427422026\n",
"iteration 31600 / 50000: loss 15.492774260437615\n",
"iteration 31700 / 50000: loss 14.850658746492925\n",
"iteration 31800 / 50000: loss 15.11631322443573\n",
"iteration 31900 / 50000: loss 14.736472966730114\n",
"iteration 32000 / 50000: loss 15.099770723070765\n",
"iteration 32100 / 50000: loss 15.111730671015495\n",
"iteration 32200 / 50000: loss 14.693458659585392\n",
"iteration 32300 / 50000: loss 14.666989430104637\n",
"iteration 32400 / 50000: loss 14.5936762169567\n",
"iteration 32500 / 50000: loss 14.868163403346193\n",
"iteration 32600 / 50000: loss 14.660362100106335\n",
"iteration 32700 / 50000: loss 14.646755165645112\n",
"iteration 32800 / 50000: loss 14.909298460220077\n",
"iteration 32900 / 50000: loss 14.645707171664155\n",
"iteration 33000 / 50000: loss 14.24298172668276\n",
"iteration 33100 / 50000: loss 14.293340229599098\n",
"iteration 33200 / 50000: loss 14.120379417409692\n",
"iteration 33300 / 50000: loss 14.596219655433016\n",
"iteration 33400 / 50000: loss 14.294569877063452\n",
"iteration 33500 / 50000: loss 14.30047939188503\n",
"iteration 33600 / 50000: loss 14.188473392150062\n",
"iteration 33700 / 50000: loss 14.184761455292737\n",
"iteration 33800 / 50000: loss 14.158177248830349\n",
"iteration 33900 / 50000: loss 13.900904315480991\n",
"iteration 34000 / 50000: loss 14.079951184400997\n",
"iteration 34100 / 50000: loss 14.141673338448653\n",
"iteration 34200 / 50000: loss 14.046554166060083\n",
"iteration 34300 / 50000: loss 13.778012964398286\n",
"iteration 34400 / 50000: loss 13.821220262244594\n",
"iteration 34500 / 50000: loss 13.771860419283632\n",
"iteration 34600 / 50000: loss 13.752856961008808\n",
"iteration 34700 / 50000: loss 13.533652374969042\n",
"iteration 34800 / 50000: loss 13.612305372104938\n",
"iteration 34900 / 50000: loss 13.583113347288338\n",
"iteration 35000 / 50000: loss 13.51031463943107\n",
"iteration 35100 / 50000: loss 13.575914096973545\n",
"iteration 35200 / 50000: loss 13.50956226798035\n",
"iteration 35300 / 50000: loss 13.40834186499299\n",
"iteration 35400 / 50000: loss 13.373483009719536\n",
"iteration 35500 / 50000: loss 13.695215712000191\n",
"iteration 35600 / 50000: loss 13.66697313977794\n",
"iteration 35700 / 50000: loss 13.442263041362459\n",
"iteration 35800 / 50000: loss 13.327568740000768\n",
"iteration 35900 / 50000: loss 13.387306301184324\n",
"iteration 36000 / 50000: loss 13.657648034297296\n",
"iteration 36100 / 50000: loss 13.4092552476501\n",
"iteration 36200 / 50000: loss 13.304469455836545\n",
"iteration 36300 / 50000: loss 12.966378196973036\n",
"iteration 36400 / 50000: loss 13.285886673885354\n",
"iteration 36500 / 50000: loss 13.037495893730776\n",
"iteration 36600 / 50000: loss 13.211636469875812\n",
"iteration 36700 / 50000: loss 12.87835305156886\n",
"iteration 36800 / 50000: loss 12.804866051202886\n",
"iteration 36900 / 50000: loss 13.010371116158849\n",
"iteration 37000 / 50000: loss 12.982230758218206\n",
"iteration 37100 / 50000: loss 12.884902615121534\n",
"iteration 37200 / 50000: loss 12.887174093554265\n",
"iteration 37300 / 50000: loss 12.687705287256664\n",
"iteration 37400 / 50000: loss 12.876812717149905\n",
"iteration 37500 / 50000: loss 12.64519754378002\n",
"iteration 37600 / 50000: loss 12.728348116742756\n",
"iteration 37700 / 50000: loss 12.50175203646306\n",
"iteration 37800 / 50000: loss 12.706259859217992\n",
"iteration 37900 / 50000: loss 12.523945493688803\n",
"iteration 38000 / 50000: loss 12.887242551092488\n",
"iteration 38100 / 50000: loss 12.651737347101651\n",
"iteration 38200 / 50000: loss 12.590085631462564\n",
"iteration 38300 / 50000: loss 12.512936360615527\n",
"iteration 38400 / 50000: loss 12.439284777579154\n",
"iteration 38500 / 50000: loss 12.559243165295204\n",
"iteration 38600 / 50000: loss 12.557331176611184\n",
"iteration 38700 / 50000: loss 12.400126595583867\n",
"iteration 38800 / 50000: loss 12.44607841529448\n",
"iteration 38900 / 50000: loss 12.44318444430552\n",
"iteration 39000 / 50000: loss 12.613034267994305\n",
"iteration 39100 / 50000: loss 12.222154788680026\n",
"iteration 39200 / 50000: loss 12.034536549977151\n",
"iteration 39300 / 50000: loss 12.280831910131164\n",
"iteration 39400 / 50000: loss 12.084445785798037\n",
"iteration 39500 / 50000: loss 12.182947591781806\n",
"iteration 39600 / 50000: loss 11.982446234491936\n",
"iteration 39700 / 50000: loss 11.987624264944149\n",
"iteration 39800 / 50000: loss 11.949112400340578\n",
"iteration 39900 / 50000: loss 12.234889112861032\n",
"iteration 40000 / 50000: loss 12.325508888815847\n",
"iteration 40100 / 50000: loss 11.86887952845391\n",
"iteration 40200 / 50000: loss 11.803253684311468\n",
"iteration 40300 / 50000: loss 11.857970626144022\n",
"iteration 40400 / 50000: loss 11.97765743542806\n",
"iteration 40500 / 50000: loss 11.912676109766462\n",
"iteration 40600 / 50000: loss 11.728872226326937\n",
"iteration 40700 / 50000: loss 11.825321781377474\n",
"iteration 40800 / 50000: loss 11.715736692380377\n",
"iteration 40900 / 50000: loss 11.63228653121234\n",
"iteration 41000 / 50000: loss 11.588617686418484\n",
"iteration 41100 / 50000: loss 11.863518965085886\n",
"iteration 41200 / 50000: loss 11.886012511976189\n",
"iteration 41300 / 50000: loss 11.720163653388825\n",
"iteration 41400 / 50000: loss 11.571585463372982\n",
"iteration 41500 / 50000: loss 11.48496097718125\n",
"iteration 41600 / 50000: loss 11.54391417289108\n",
"iteration 41700 / 50000: loss 11.452147479129307\n",
"iteration 41800 / 50000: loss 11.521430544390183\n",
"iteration 41900 / 50000: loss 11.274533921193326\n",
"iteration 42000 / 50000: loss 11.37100264491573\n",
"iteration 42100 / 50000: loss 11.59258718960164\n",
"iteration 42200 / 50000: loss 11.259954198743984\n",
"iteration 42300 / 50000: loss 11.28144751925919\n",
"iteration 42400 / 50000: loss 11.312711697614738\n",
"iteration 42500 / 50000: loss 11.142592027452736\n",
"iteration 42600 / 50000: loss 11.098667816440797\n",
"iteration 42700 / 50000: loss 11.257872500665412\n",
"iteration 42800 / 50000: loss 11.415369266497562\n",
"iteration 42900 / 50000: loss 11.150721463401233\n",
"iteration 43000 / 50000: loss 11.26987531703047\n",
"iteration 43100 / 50000: loss 10.988064787517237\n",
"iteration 43200 / 50000: loss 10.950464791967285\n",
"iteration 43300 / 50000: loss 11.172560229984843\n",
"iteration 43400 / 50000: loss 11.061127738808873\n",
"iteration 43500 / 50000: loss 10.951866436782726\n",
"iteration 43600 / 50000: loss 10.853426570062815\n",
"iteration 43700 / 50000: loss 10.816874501203216\n",
"iteration 43800 / 50000: loss 11.19049237175604\n",
"iteration 43900 / 50000: loss 10.822360825638848\n",
"iteration 44000 / 50000: loss 11.10811506360901\n",
"iteration 44100 / 50000: loss 10.799356117751772\n",
"iteration 44200 / 50000: loss 10.91294513651592\n",
"iteration 44300 / 50000: loss 10.680198773004307\n",
"iteration 44400 / 50000: loss 10.660830927726312\n",
"iteration 44500 / 50000: loss 10.919051300885332\n",
"iteration 44600 / 50000: loss 10.875928614776988\n",
"iteration 44700 / 50000: loss 10.506314600343849\n",
"iteration 44800 / 50000: loss 10.646036351796594\n",
"iteration 44900 / 50000: loss 10.604054863181824\n",
"iteration 45000 / 50000: loss 10.443009073579772\n",
"iteration 45100 / 50000: loss 10.80263168389864\n",
"iteration 45200 / 50000: loss 10.483164796743166\n",
"iteration 45300 / 50000: loss 10.94555926101149\n",
"iteration 45400 / 50000: loss 10.38574934972442\n",
"iteration 45500 / 50000: loss 10.410069160032558\n",
"iteration 45600 / 50000: loss 10.647413503241273\n",
"iteration 45700 / 50000: loss 10.368145965990134\n",
"iteration 45800 / 50000: loss 10.283998056989505\n",
"iteration 45900 / 50000: loss 10.597760335818439\n",
"iteration 46000 / 50000: loss 10.421992530001722\n",
"iteration 46100 / 50000: loss 10.529896100274234\n",
"iteration 46200 / 50000: loss 10.291705374402628\n",
"iteration 46300 / 50000: loss 10.348052447820118\n",
"iteration 46400 / 50000: loss 10.41952728017204\n",
"iteration 46500 / 50000: loss 10.24739255808342\n",
"iteration 46600 / 50000: loss 10.237543234527113\n",
"iteration 46700 / 50000: loss 10.377697013557263\n",
"iteration 46800 / 50000: loss 10.074917156907643\n",
"iteration 46900 / 50000: loss 10.028152668869403\n",
"iteration 47000 / 50000: loss 10.50884267683656\n",
"iteration 47100 / 50000: loss 9.985853745203897\n",
"iteration 47200 / 50000: loss 10.157402291954302\n",
"iteration 47300 / 50000: loss 10.31320363457757\n",
"iteration 47400 / 50000: loss 10.076072474852781\n",
"iteration 47500 / 50000: loss 10.158361756895623\n",
"iteration 47600 / 50000: loss 10.34663936996847\n",
"iteration 47700 / 50000: loss 9.866624717206411\n",
"iteration 47800 / 50000: loss 9.844635672517189\n",
"iteration 47900 / 50000: loss 9.928906689366547\n",
"iteration 48000 / 50000: loss 9.951519816382103\n",
"iteration 48100 / 50000: loss 9.773064510465177\n",
"iteration 48200 / 50000: loss 9.832071061520283\n",
"iteration 48300 / 50000: loss 10.054238843593746\n",
"iteration 48400 / 50000: loss 9.924989220031197\n",
"iteration 48500 / 50000: loss 9.958717780867367\n",
"iteration 48600 / 50000: loss 9.797941956852611\n",
"iteration 48700 / 50000: loss 9.649116243793362\n",
"iteration 48800 / 50000: loss 9.82644873167511\n",
"iteration 48900 / 50000: loss 9.598817224677786\n",
"iteration 49000 / 50000: loss 9.708477925028642\n",
"iteration 49100 / 50000: loss 9.764272354996994\n",
"iteration 49200 / 50000: loss 9.606586381901913\n",
"iteration 49300 / 50000: loss 10.037900593533568\n",
"iteration 49400 / 50000: loss 9.696623042912284\n",
"iteration 49500 / 50000: loss 9.73634403306151\n",
"iteration 49600 / 50000: loss 9.617579384648753\n",
"iteration 49700 / 50000: loss 9.517357451722274\n",
"iteration 49800 / 50000: loss 9.417539972454424\n",
"iteration 49900 / 50000: loss 9.673809198022692\n",
"iteration 49999 / 50000: loss 9.519560238120139\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Plot the loss history (per logged iteration) and the train / validation\n",
"# accuracy histories on two stacked panels.\n",
"fig, (loss_ax, acc_ax) = plt.subplots(2, 1, constrained_layout=True)\n",
"\n",
"loss_ax.plot(stats[\"loss_history\"])\n",
"loss_ax.set(title=\"Loss history\", xlabel=\"Iteration\", ylabel=\"Loss\")\n",
"\n",
"acc_ax.plot(stats[\"train_acc_history\"], label=\"train\")\n",
"acc_ax.plot(stats[\"val_acc_history\"], label=\"val\")\n",
"acc_ax.set(title=\"Classification accuracy history\", xlabel=\"Epoch\", ylabel=\"Classification accuracy\")\n",
"# The \"train\" / \"val\" labels are only rendered once a legend is drawn.\n",
"acc_ax.legend()\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 472
},
"id": "geCpl7jBm9ew",
"outputId": "44a854e1-e159-48e1-9791-4fefef500348"
},
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Top-1 accuracy of the trained network on the test images, printed as a percentage.\n",
"predictions = net.predict(test_images)\n",
"correct_mask = predictions == test_labels\n",
"num_correct = sum(correct_mask)\n",
"print(\"Accuracy: \", num_correct / len(predictions) * 100, \"%\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rs9WKsLCnheg",
"outputId": "5cd5586c-9af2-4796-978e-fe8c0cb8de59"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Accuracy:  91.62 %\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Show a few random test images with their true label and the network's prediction.\n",
"NUM_SAMPLES = 10\n",
"N_COLS = 5\n",
"rng = np.random.default_rng(0)  # fixed seed so every run shows the same samples\n",
"# Sample over the whole test set, without replacement, so no image is shown twice.\n",
"sampling_indices = rng.choice(len(test_images), size=NUM_SAMPLES, replace=False)\n",
"x_test = test_images[sampling_indices]\n",
"y_test = test_labels[sampling_indices]\n",
"predictions_sample = predictions[sampling_indices]\n",
"\n",
"n_rows = -(-NUM_SAMPLES // N_COLS)  # ceil(NUM_SAMPLES / N_COLS)\n",
"fig, axes = plt.subplots(n_rows, N_COLS, figsize=(3 * N_COLS, 3.5 * n_rows), squeeze=False)\n",
"for ax in axes.flat:\n",
"    ax.axis(\"off\")\n",
"for ax, image, label, prediction in zip(axes.flat, x_test, y_test, predictions_sample):\n",
"    # Images are stored flattened; restore the 28x28 grid for display.\n",
"    ax.imshow(image.reshape(28, 28), cmap=\"gray\")\n",
"    ax.set_title(f\"Label: {label} / Prediction: {prediction}\")\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 566
},
"id": "aMy6TkGqoGU1",
"outputId": "23caf72a-7153-450b-a9b1-1b1b8b489378"
},
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 3000x3000 with 10 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Usage\n",
"\n",
"Interactive MNIST drawing demo: <https://mco-mnist-draw-rwpxka3zaa-ue.a.run.app/>"
],
"metadata": {
"id": "8C7gUhiIpZZE"
}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment