Created
December 15, 2020 16:28
-
-
Save nmilosev/0b6e9600bf0a4b0a16b19ce9d8da12b2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Testing tinygrad in Carnets on an iPad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Warning: ignoring SSL_CERT_DIR environment variable, not supported by libcurl\n", | |
" % Total % Received % Xferd Average Speed Time Time Time Current\n", | |
" Dload Upload Total Spent Left Speed\n", | |
"100 99742 0 99742 0 0 175k 0 --:--:-- --:--:-- --:--:-- 175k\n", | |
"\r" | |
] | |
} | |
], | |
"source": [ | |
"!curl https://codeload.github.com/geohot/tinygrad/tar.gz/master > master.tar.gz\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"tar: Failed to set default locale\n", | |
"\r" | |
] | |
} | |
], | |
"source": [ | |
"!tar xf master.tar.gz" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!mv tinygrad-master/tinygrad ." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!mv tinygrad-master/extra ." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Defaulting to user installation because normal site-packages is not writeable\n", | |
"Requirement already satisfied: tqdm in /private/var/mobile/Containers/Data/Application/D5CAF77E-50A8-4B78-B158-D589ED88E5E8/Library/lib/python3.9/site-packages (4.54.1)\n", | |
"Requirement already satisfied: requests in /private/var/mobile/Containers/Data/Application/D5CAF77E-50A8-4B78-B158-D589ED88E5E8/Library/lib/python3.9/site-packages (2.25.0)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /private/var/mobile/Containers/Data/Application/D5CAF77E-50A8-4B78-B158-D589ED88E5E8/Library/lib/python3.9/site-packages (from requests) (3.0.4)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /private/var/mobile/Containers/Data/Application/D5CAF77E-50A8-4B78-B158-D589ED88E5E8/Library/lib/python3.9/site-packages (from requests) (2.10)\n", | |
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /private/var/containers/Bundle/Application/04AC0AE0-FF40-48AE-9289-E63AC4E1631E/Carnets.app/Library/lib/python3.9/site-packages (from requests) (1.26.2)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /private/var/containers/Bundle/Application/04AC0AE0-FF40-48AE-9289-E63AC4E1631E/Carnets.app/Library/lib/python3.9/site-packages (from requests) (2020.6.20)\n", | |
"\r" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[33mWARNING: You are using pip version 20.3; however, version 20.3.1 is available.\n", | |
"You should consider upgrading via the '/private/var/containers/Bundle/Application/04AC0AE0-FF40-48AE-9289-E63AC4E1631E/Carnets.app/Library/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n", | |
"\r" | |
] | |
} | |
], | |
"source": [ | |
"!pip install tqdm requests" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tinygrad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Tensor array([[ 2., 2., 2.],\n", | |
" [ 0., 0., 0.],\n", | |
" [-2., -2., -2.]], dtype=float32) with grad None\n", | |
"Tensor array([[1., 1., 1.]], dtype=float32) with grad None\n" | |
] | |
} | |
], | |
"source": [ | |
"from tinygrad.tensor import Tensor\n", | |
"\n", | |
"x = Tensor.eye(3)\n", | |
"y = Tensor([[2.0,0,-2.0]])\n", | |
"z = y.matmul(x).sum()\n", | |
"z.backward()\n", | |
"\n", | |
"print(x.grad) # dz/dx\n", | |
"print(y.grad) # dz/dy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import unittest\n", | |
"import numpy as np\n", | |
"from tinygrad.tensor import Tensor, GPU\n", | |
"import tinygrad.optim as optim\n", | |
"from extra.training import train, evaluate\n", | |
"from extra.utils import fetch, get_parameters\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"fetching http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", | |
"fetching http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", | |
"fetching http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", | |
"fetching http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n" | |
] | |
} | |
], | |
"source": [ | |
"# mnist loader\n", | |
"def fetch_mnist():\n", | |
" import gzip\n", | |
" parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()\n", | |
" X_train = parse(fetch(\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\"))[0x10:].reshape((-1, 28, 28))\n", | |
" Y_train = parse(fetch(\"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"))[8:]\n", | |
" X_test = parse(fetch(\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\"))[0x10:].reshape((-1, 28, 28))\n", | |
" Y_test = parse(fetch(\"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"))[8:]\n", | |
" return X_train, Y_train, X_test, Y_test\n", | |
"\n", | |
"# load the mnist dataset\n", | |
"X_train, Y_train, X_test, Y_test = fetch_mnist()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# create a model\n", | |
"class TinyBobNet:\n", | |
"\n", | |
" def __init__(self):\n", | |
" self.l1 = Tensor.uniform(784, 128)\n", | |
" self.l2 = Tensor.uniform(128, 10)\n", | |
"\n", | |
" def parameters(self):\n", | |
" return get_parameters(self)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" return x.dot(self.l1).relu().dot(self.l2).logsoftmax()\n", | |
"\n", | |
"# create a model with a conv layer\n", | |
"class TinyConvNet:\n", | |
" def __init__(self):\n", | |
" # https://keras.io/examples/vision/mnist_convnet/\n", | |
" conv = 3\n", | |
" #inter_chan, out_chan = 32, 64\n", | |
" inter_chan, out_chan = 8, 16 # for speed\n", | |
" self.c1 = Tensor.uniform(inter_chan,1,conv,conv)\n", | |
" self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)\n", | |
" self.l1 = Tensor.uniform(out_chan*5*5, 10)\n", | |
"\n", | |
" def parameters(self):\n", | |
" return get_parameters(self)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = x.reshape(shape=(-1, 1, 28, 28)) # hacks\n", | |
" x = x.conv2d(self.c1).relu().max_pool2d()\n", | |
" x = x.conv2d(self.c2).relu().max_pool2d()\n", | |
" x = x.reshape(shape=[x.shape[0], -1])\n", | |
" return x.dot(self.l1).logsoftmax()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"loss 0.06 accuracy 0.98: 100%|██████████| 1000/1000 [01:25<00:00, 11.74it/s]\n", | |
"100%|██████████| 78/78 [00:02<00:00, 37.23it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"test set accuracy is 0.962200\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.9622" | |
] | |
}, | |
"execution_count": 39, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.random.seed(1337)\n", | |
"model = TinyBobNet()\n", | |
"optimizer = optim.SGD(model.parameters(), lr=0.001)\n", | |
"train(model, X_train, Y_train, optimizer, steps=1000)\n", | |
"evaluate(model, X_test, Y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"loss 0.08 accuracy 0.98: 100%|██████████| 200/200 [00:47<00:00, 4.23it/s]\n", | |
"100%|██████████| 78/78 [00:05<00:00, 14.76it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"test set accuracy is 0.974000\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.974" | |
] | |
}, | |
"execution_count": 41, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.random.seed(1337)\n", | |
"model = TinyConvNet()\n", | |
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n", | |
"train(model, X_train, Y_train, optimizer, steps=200)\n", | |
"evaluate(model, X_test, Y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0x127b54ca0>" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAPXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24wK3Vua25vd24sIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcv+is0xgAAAAlwSFlzAAALEwAACxMBAJqcGAAADd9JREFUeJzt3X+s1fV9x/HXS7yggDrAgUxZca2zpdu86i2uYV1p2RpL0qLLXCRpRzc3mqyautiuRrPiH0tqtrW165wZVlba+CNuirDFbBJGYpu2zCtSfgjOX1DRG7BlK7RVBO57f9yvyy3e8zmX8xvez0dyc875vs/3fN/5hhff7zmf7zkfR4QAnPpO63YDADqDsANJEHYgCcIOJEHYgSRO7+TGJnpSnKEpndwkkMrr+qneiMMeq9ZU2G1fKekrkiZI+lpE3F56/hmaoiu8qJlNAijYFBtq1ho+jbc9QdKdkj4saZ6kpbbnNfp6ANqrmffs8yU9FxEvRMQbkh6QtKQ1bQFotWbCfr6kl0Y93lst+zm2l9setD14RIeb2ByAZjQT9rE+BHjLtbcRsTIiBiJioE+TmtgcgGY0E/a9kuaMenyBpFeaawdAuzQT9ickXWT7QtsTJV0raV1r2gLQag0PvUXEUdvXS/oPjQy9rYqIHS3rDEBLNTXOHhGPSnq0Rb0AaCMulwWSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dSUzbZ3Szok6ZikoxEx0IqmALReU2GvfCAiftiC1wHQRpzGA0k0G/aQ9JjtJ20vH+sJtpfbHrQ9eESHm9wcgEY1exq/ICJesT1T0nrbuyLi8dFPiIiVklZK0tmeHk1uD0CDmjqyR8Qr1e1+SWskzW9FUwBar+Gw255i+6w370v6kKTtrWoMQGs1cxo/S9Ia22++zn0R8e8t6QpAyzUc9oh4QdIlLewFQBsx9AYkQdiBJAg7kARhB5Ig7EASrfgiDHrYsYWXFeunf35fsf6vF68r1vs8oVg/Esdq1hZsuba47oxb+4p17365WP/RR+bVrE1/pHxJyPChQ8X6yYgjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTj7ScCTJhXrhz7aX7O24guriuu+/8yfFevDxap0pM5vDw0XXuFb/fcV173sLz9RrF9yXvlYtXbu39esvecXbiiuO+ur3ynWT0Yc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZTwKHF/56sf6fd9QeT65n42tTi/XP/9UfF+t9P2t8kp+DbysfayaWLwHQX3ymfA3Bj4eP1qxNHar9PftTFUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYeEO8tT4b7hbv+seHXXvr84mL94Io5xfq0jd9teNv1nPOOC4v1/n9+vlh/18Tyseqda/+8Zu1X/2VTcd1TUd0ju+1Vtvfb3j5q2XTb620/W91Oa2+bAJo1ntP4r0u68rhlN0vaEBEXSdpQPQbQw+qGPSIel3TguMVLJK2u7q+WdFVr2wLQao1+QDcrIoYkqbqdWeuJtpfbHrQ9eESHG9wcgGa1/dP4iFgZEQMRMdCn8g8nAmifRsO+z/ZsSapu97euJQDt0GjY10laVt1fJmlta9oB0C51x9lt3y9poaRzbe+VtELS7ZIetH2dpB9IuqadTZ7q/ufW14r1y+u8+1m86/dq1iZ85uziuhOe2lx+8Tb638tnFesrZj7Y1OvPeayp1U85dcMeEUtrlBa1uBcAbcTlskAShB1IgrADSRB2IAnCDiTBV1w74MUHfqNY33HpPxXre4+Wh+ZOu7X2lw7jqa3FddutNN30O258urjuaXWORX+0pzwgdOYj/1WsZ8ORHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJy9A/5wXnm8d1jDxfqeo+Wvqep73RtLL42jS9Izd9T+mey1v3xncd3yXpH2/M3Fxfpk5fu56BKO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsKJrw7vJY9s4bzinWd32kPJZesvG1qcX6Wd95sVg/1vCWT00c2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZO+ChF/uL9c/O2FasXzrpp8X6+7a+fqItjdv8yQ8X6x84s7ztet9JL7np+79frF+wb0cTr55P3SO77VW299vePmrZbbZftr2l+lvc3jYBNGs8p/Ffl3TlGMu/HBH91d+jrW0LQKvVDXtEPC7pQAd6AdBGzXxAd73trdVpfs3Jxmwvtz1oe/CIDjexOQDNaDTsd0l6u6R+SUOSvljriRGxMiIGImKgT+UfJwTQPg2FPSL2RcSxiBiWdLek+a1tC0CrNRR227NHPbxa0vZazwXQG+qOs9u+X9JCSefa3itphaSFtvslhaTdkj7ZvhZPfud97OVi/aOPXF2s/9s71xbr9cbp2+l9n7uhWB9e+qOatW/131dcd+bdkxvqCWOrG/aIWDrG4nva0AuANuJyWSAJwg4kQdiBJAg7kARhB5LgK64dMHzoUPkJi8r1D179Z8X6/ssb/z972s4o1s+593vF+qvfLF8Cvav/gZq1e348t7ju5B1DxfrRYhXH48gOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzn4SmLxmU7E+d02HGhnDrg9+rVgfLvyY9J3PvL+47i+99HRDPWFsHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2VE04d0X13nGk8XqnqNv1KzN+rszGugIjeLIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6OohdWTGxq/Wue+pOatfM2bm7qtXFi6h7Zbc+xvdH2Tts7bH+6Wj7d9nrbz1a309rfLoBGjec0/qikmyLiXZJ+U9KnbM+TdLOkDRFxkaQN1WMAPapu2CNiKCI2V/cPSdop6XxJSyStrp62WtJVbeoRQAuc0Ad0tudKulTSJkmzImJIGvkPQdLMGusstz1oe/CIyvOCAWifcYfd9lRJD0m6MSIOjne9iFgZEQMRMdCnSY30CKAFxhV2230aCfq9EfFwtXif7dlVfbak/e1pEUAr1B16s21J90jaGRFfGlVaJ2mZpNur27Vt6RBtFe+9pFhfd8U/1HmF8tdUvYFBml4xnnH2BZI+Lmmb7S3Vsls0EvIHbV8n6QeSrmlLhwBaom7YI+LbklyjvKi17QBoFy6XBZIg7EAShB1IgrADSRB2IAm+4prc/vdMKdYvPL08jl6aklmSTn89TrgntAdHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH25F4/tzwOXm8c/Y4D84r1GXd/94R7QntwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnT+5jV21sav1Va3+nWJ8rxtl7BUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUhiPPOzz5H0DUnnSRqWtDIivmL7Nkl/KunV6qm3RMSj7WoU7fHQi/3F+mdnbOtMI2i78VxUc1TSTRGx2fZZkp60vb6qfTki/rZ97QFolfHMzz4kaai6f8j2Tknnt7sxAK11Qu/Zbc+VdKmkTdWi621vtb3K9rQa6yy3PWh78IgON9ctgIaNO+y2p0p6SNKNEXFQ0l2S3i6pXyNH/i+OtV5ErIyIgYgY6NOk5jsG0JBxhd12n0aCfm9EPCxJEbEvIo5FxLCkuyXNb1+bAJpVN+y2LekeSTsj4kujls8e9bSrJW1vfXsAWmU8n8YvkPRxSdtsb6mW3SJpqe1+SSFpt6RPtqE/tFlsmF6s33LBFcX6rMFjrWwHbTSeT+O/LcljlBhTB04iXEEHJEHYgSQIO5AEYQeSIOxAEoQdSMIR5Sl7W+lsT48rvKhj2wOy2RQbdDAOjDVUzpEdyIKwA0kQdiAJwg4kQdiBJAg7kARhB5Lo6Di77Vcl7Rm16FxJP+xYAyemV3vr1b4kemtUK3t7W0T84liFjob9LRu3ByNioGsNFPRqb73al0RvjepUb5zGA0kQdiCJbod9ZZe3X9KrvfVqXxK9NaojvXX1PTuAzun2kR1AhxB2IImuhN32lbafsf2c7Zu70UMttnfb3mZ7i+3BLveyyvZ+29tHLZtue73tZ6vbMefY61Jvt9l+udp3W2wv7lJvc2xvtL3T9g7bn66Wd3XfFfrqyH7r+Ht22xMk/bek35W0V9ITkpZGxNMdbaQG27slDURE1y/AsP3bkn4i6RsR8WvVsr+WdCAibq/+o5wWEZ/rkd5uk/STbk/jXc1WNHv0NOOSrpL0CXVx3xX6+gN1YL9148g+X9JzEfFCRLwh6QFJS7rQR8+LiMclHThu8RJJq6v7qzXyj6XjavTWEyJiKCI2V/cPSXpzmvGu7rtCXx3RjbCfL+mlUY/3qrfmew9Jj9l+0vbybjczhlkRMSSN/OORNLPL/Ryv7jTenXTcNOM9s+8amf68Wd0I+1i/j9VL438LIuIySR+W9KnqdBXjM65pvDtljGnGe0Kj0583qxth3ytpzqjHF0h6pQt9jCkiXqlu90tao96binrfmzPoVrf7u9zP/+ulabzHmmZcPbDvujn9eTfC/oSki2xfaHuipGslretCH29he0r1wYlsT5H0IfXeVNTrJC2r7i+TtLaLvfycXpnGu9Y04+ryvuv69OcR0fE/SYs18on885Ju7UYPNfr6FUnfr/52dLs3Sfdr5LTuiEbOiK6TNEPSBknPVrfTe6i3b0raJmmrRoI1u0u9/ZZG3hpulbSl+lvc7X1X6Ksj+43LZYEkuIIOSIKwA0kQdiAJwg4kQdiBJAg7kARhB5L4P//OEG7udV7mAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"from matplotlib import pyplot as plt\n", | |
"plt.imshow(X_test[12])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"9" | |
] | |
}, | |
"execution_count": 36, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Y_test[12]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = Tensor(X_test[12].reshape((-1, 28*28)).astype(np.float32))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([9])" | |
] | |
}, | |
"execution_count": 40, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.argmax(model.forward(x).cpu().data, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.0+" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment