Skip to content

Instantly share code, notes, and snippets.

@Unbinilium
Created February 28, 2022 15:10
Show Gist options
  • Save Unbinilium/4cf7acd83e9c0f3bb0833a896c905fb5 to your computer and use it in GitHub Desktop.
Save Unbinilium/4cf7acd83e9c0f3bb0833a896c905fb5 to your computer and use it in GitHub Desktop.
Create a TorchScript Model from MNIST
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "aab63ec9",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms\n",
"from torch.optim.lr_scheduler import StepLR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e898a950",
"metadata": {},
"outputs": [],
"source": [
"# Run on the first CUDA GPU when one is present, otherwise on the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda:0\")\n",
"else:\n",
"    device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23f7c1a7",
"metadata": {},
"outputs": [],
"source": [
"# Shuffling is a property of the split, not of the device: training batches are\n",
"# reshuffled every epoch everywhere, the test set is evaluated in a fixed order.\n",
"train_kwargs = {'batch_size': 64, 'shuffle': True}\n",
"test_kwargs = {'batch_size': 1000, 'shuffle': False}\n",
"\n",
"if torch.cuda.is_available():\n",
"    # Loader options that only help when batches are copied to a GPU.\n",
"    cuda_kwargs = {\n",
"        'num_workers': 1,\n",
"        'pin_memory': True,\n",
"    }\n",
"    train_kwargs.update(cuda_kwargs)\n",
"    test_kwargs.update(cuda_kwargs)\n",
"\n",
"# Normalize with the commonly used MNIST mean/std (0.1307, 0.3081).\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.1307,), (0.3081,))\n",
"])\n",
"\n",
"dataset1 = datasets.MNIST(\n",
"    root='./data',\n",
"    train=True,\n",
"    download=True,\n",
"    transform=transform\n",
")\n",
"dataset2 = datasets.MNIST(\n",
"    root='./data',\n",
"    train=False,\n",
"    download=True,\n",
"    transform=transform\n",
")\n",
"\n",
"train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\n",
"test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d9caaa6",
"metadata": {},
"outputs": [],
"source": [
"def im_convert(tensor, mean=(0.1307,), std=(0.3081,)):\n",
"    \"\"\"Convert a normalized (C, H, W) image tensor to an (H, W, 3) float array for imshow.\n",
"\n",
"    Undoes ``transforms.Normalize(mean, std)``; the defaults are the values used by\n",
"    this notebook's MNIST transform. A single-channel image is repeated to three\n",
"    channels because matplotlib cannot display an (H, W, 1) array. Values are\n",
"    clipped to [0, 1].\n",
"    \"\"\"\n",
"    image = tensor.detach().cpu().numpy()\n",
"    image = image.transpose(1, 2, 0)  # CHW -> HWC\n",
"    image = image * np.asarray(std) + np.asarray(mean)  # new array; tensor is untouched\n",
"    if image.shape[-1] == 1:\n",
"        image = np.repeat(image, 3, axis=-1)\n",
"    return image.clip(0, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cea475ea",
"metadata": {},
"outputs": [],
"source": [
"# Use the builtin next(): newer PyTorch DataLoader iterators have no .next() method.\n",
"data_iter = iter(train_loader)\n",
"images, labels = next(data_iter)\n",
"fig = plt.figure(figsize=(25, 4))\n",
"\n",
"# Show up to 20 images of the batch, titled with their ground-truth label.\n",
"for idx in range(min(20, len(images))):\n",
"    ax = fig.add_subplot(2, 10, idx + 1, xticks=[], yticks=[])\n",
"    ax.imshow(im_convert(images[idx]))\n",
"    ax.set_title(str(labels[idx].item()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e85a14e2",
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
"    \"\"\"Two 3x3 conv layers and two fully connected layers producing log-probabilities.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # For a 1x28x28 input: 32x26x26 -> 64x24x24 -> max-pool -> 64x12x12 = 9216 features.\n",
"        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
"        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
"        self.dropout1 = nn.Dropout(0.25)\n",
"        self.dropout2 = nn.Dropout(0.5)\n",
"        self.fc1 = nn.Linear(9216, 128)\n",
"        self.fc2 = nn.Linear(128, 10)\n",
"\n",
"    def forward(self, x):\n",
"        features = F.relu(self.conv1(x))\n",
"        features = F.relu(self.conv2(features))\n",
"        features = self.dropout1(F.max_pool2d(features, 2))\n",
"        hidden = torch.flatten(features, 1)\n",
"        hidden = self.dropout2(F.relu(self.fc1(hidden)))\n",
"        # Log-probabilities over the 10 classes, paired with F.nll_loss in training.\n",
"        return F.log_softmax(self.fc2(hidden), dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9db664f3",
"metadata": {},
"outputs": [],
"source": [
"model = Net()\n",
"model.to(device)  # nn.Module.to moves the parameters in place\n",
"optimizer = optim.Adadelta(params=model.parameters(), lr=1.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba484953",
"metadata": {},
"outputs": [],
"source": [
"def train(model, device, train_loader, optimizer, epoch):\n",
"    \"\"\"Train for one epoch with NLL loss, logging progress every 30 batches.\"\"\"\n",
"    model.train()\n",
"    dataset_size = len(train_loader.dataset)\n",
"    num_batches = len(train_loader)\n",
"    for step, (inputs, targets) in enumerate(train_loader):\n",
"        inputs = inputs.to(device)\n",
"        targets = targets.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss = F.nll_loss(model(inputs), targets)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        if step % 30 != 0:\n",
"            continue\n",
"        print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
"            epoch,\n",
"            step * len(inputs),\n",
"            dataset_size,\n",
"            100. * step / num_batches,\n",
"            loss.item()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b329dab1",
"metadata": {},
"outputs": [],
"source": [
"def test(model, device, test_loader):\n",
"    \"\"\"Evaluate on test_loader and print the mean NLL loss and the accuracy.\"\"\"\n",
"    model.eval()\n",
"    total_loss, num_correct = 0, 0\n",
"    with torch.no_grad():\n",
"        for inputs, targets in test_loader:\n",
"            inputs, targets = inputs.to(device), targets.to(device)\n",
"            log_probs = model(inputs)\n",
"            # Sum (not mean) per batch so the final division gives a per-sample average.\n",
"            total_loss += F.nll_loss(log_probs, targets, reduction='sum').item()\n",
"            predicted = log_probs.argmax(dim=1)\n",
"            num_correct += (predicted == targets).sum().item()\n",
"\n",
"    n_samples = len(test_loader.dataset)\n",
"    mean_loss = total_loss / n_samples\n",
"    accuracy = 100. * num_correct / n_samples\n",
"    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
"        mean_loss, num_correct, n_samples, accuracy))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e43edf6f",
"metadata": {},
"outputs": [],
"source": [
"epochs = 10\n",
"scheduler = StepLR(optimizer, gamma=0.7, step_size=1)  # lr *= 0.7 after every epoch\n",
"\n",
"for epoch in range(1, epochs + 1):  # epochs are numbered from 1 in the log\n",
"    train(model, device, train_loader, optimizer, epoch)\n",
"    test(model, device, test_loader)\n",
"    scheduler.step()  # advance the learning-rate schedule once per epoch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cca9d30",
"metadata": {},
"outputs": [],
"source": [
"# Use the builtin next(): newer PyTorch DataLoader iterators have no .next() method.\n",
"data_iter = iter(test_loader)\n",
"images, labels = next(data_iter)\n",
"images = images.to(device)\n",
"labels = labels.to(device)\n",
"\n",
"# Inference only: set eval mode explicitly (dropout off) instead of relying on\n",
"# test() having run last, and skip building an autograd graph.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    output = model(images)\n",
"_, preds = torch.max(output, 1)\n",
"\n",
"fig = plt.figure(figsize=(25, 4))\n",
"\n",
"# Title is \"prediction (label)\": green when correct, red when wrong.\n",
"for idx in range(min(20, len(images))):\n",
"    pred, label = preds[idx].item(), labels[idx].item()\n",
"    ax = fig.add_subplot(2, 10, idx + 1, xticks=[], yticks=[])\n",
"    ax.imshow(im_convert(images[idx]))\n",
"    ax.set_title(\"{} ({})\".format(pred, label), color=(\"green\" if pred == label else \"red\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0bbb9e1",
"metadata": {},
"outputs": [],
"source": [
"# Example input shaped like a single MNIST image: (batch=1, channels=1, 28, 28).\n",
"blob_input = torch.zeros(1, 1, 28, 28, device=device)\n",
"# Trace in eval mode so the dropout layers are inactive in the exported model;\n",
"# tracing records whatever mode the module is currently in.\n",
"trained_network = model.to(device).eval()\n",
"traced_model = torch.jit.trace(trained_network, blob_input)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8190ce4c",
"metadata": {},
"outputs": [],
"source": [
"# TorchScript source recorded by the trace.\n",
"torchscript_source = traced_model.code\n",
"print(torchscript_source)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1336208b",
"metadata": {},
"outputs": [],
"source": [
"# Serialize the traced module so it can be loaded later with torch.jit.load.\n",
"torch.jit.save(traced_model, \"./mnist_traced.pt\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment