Skip to content

Instantly share code, notes, and snippets.

@dienhoa
Last active October 15, 2021 20:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dienhoa/109c02f1bdf6bfbd6b69d4c5f4455804 to your computer and use it in GitHub Desktop.
Save dienhoa/109c02f1bdf6bfbd6b69d4c5f4455804 to your computer and use it in GitHub Desktop.
How to take advantage of fastai with your pytorch training pipeline
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "2021-02-14-Pytorchtofastai.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "6mDXFUG4sopG"
},
"source": [
"#hide_input\n",
"!pip install -Uqq fastai"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qai8h9caUti_",
"outputId": "c839117b-91cb-4295-bfe7-036ba5040cd4"
},
"source": [
"import fastai\n",
"print(fastai.__version__)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.5.2\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KGaGdX3DtLr6"
},
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MP02xowOvrMs"
},
"source": [
"transform = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FNHLOjJIv1uq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "590e636d-d6c4-4d8d-caa0-3c5168e666ab"
},
"source": [
"dset_train = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
" download=True, transform=transform)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Files already downloaded and verified\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "La_ik28xv7tH",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7e621445-8c5e-483e-a722-830a5afb1346"
},
"source": [
"dset_test = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
" download=True, transform=transform)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Files already downloaded and verified\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AH9tRI9mv9il"
},
"source": [
"trainloader = torch.utils.data.DataLoader(dset_train, batch_size=4,\n",
" shuffle=True, num_workers=2)\n",
"testloader = torch.utils.data.DataLoader(dset_test, batch_size=4,\n",
" shuffle=False, num_workers=2)"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_2SI6qIvwwvN"
},
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LrXOQM5CxRYh"
},
"source": [
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = x.view(-1, 16 * 5 * 5)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)\n",
"        return x"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EC9Fc6A-xU96"
},
"source": [
"net = Net()"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "54jt7ghrzOiP"
},
"source": [
"criterion = nn.CrossEntropyLoss()"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XaBZiuQKza4o"
},
"source": [
"from fastai.optimizer import OptimWrapper\n",
"from torch import optim"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qcx1A4DTaW9a"
},
"source": [
"def opt_func(params, **kwargs): return OptimWrapper(params, torch.optim.SGD, lr=0.001)"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ElZ85UkqxkBV"
},
"source": [
"from fastai.data.core import DataLoaders"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iCJuAvO7x8U1"
},
"source": [
"dls = DataLoaders(trainloader, testloader)"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "kKAndhNwraya"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "_t9OdBCWyt0s"
},
"source": [
"from fastai.learner import Learner"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8v6vrG9IMtT_"
},
"source": [
"from fastai.callback.progress import ProgressCallback"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5f-StN4PyjlS"
},
"source": [
"learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func)"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 134
},
"id": "e7tOMYCv8G4j",
"outputId": "378691be-f5d4-436f-a151-eb5336342d8b"
},
"source": [
"learn.fit(1)"
],
"execution_count": 19,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.111611</td>\n",
" <td>2.077898</td>\n",
" <td>01:04</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "g7c6GjfnoW-C"
},
"source": [
"torch.save(learn.model.state_dict(), './cifar_net.pth')"
],
"execution_count": 20,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment