Created
August 29, 2021 22:49
-
-
Save tttamaki/18dc500448fccc351624e7c27f49f154 to your computer and use it in GitHub Desktop.
pytorch dataset addition.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "pytorch dataset addition.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMCXS/+39ZR5jm3KSp6pm5m", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/tttamaki/18dc500448fccc351624e7c27f49f154/pytorch-dataset-addition.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Moxr499jf8Rf" | |
}, | |
"source": [ | |
"import torch\n", | |
"import torchvision\n", | |
"import torchvision.transforms as transforms\n", | |
"\n", | |
"from torchvision.datasets import CIFAR10, STL10\n", | |
"from torch.utils.data import DataLoader" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "gFKip8yRgCVU", | |
"outputId": "91af5163-42ae-40a5-a064-d036d788cb65" | |
}, | |
"source": [ | |
"use_cuda = torch.cuda.is_available()\n", | |
"device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", | |
"\n", | |
"# Reach cuDNN through the `torch` namespace: this notebook never imports a\n", | |
"# standalone `cudnn` name, so `cudnn.benchmark` raised NameError on a fresh kernel.\n", | |
"# benchmark=True lets cuDNN autotune convolution algorithms (only matters on CUDA).\n", | |
"torch.backends.cudnn.benchmark = True\n", | |
"print('Use CUDA:', use_cuda)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Use CUDA: False\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ipF8x-Mkjkwr" | |
}, | |
"source": [ | |
"# Resize the shorter image side to 64 px, convert to a float tensor in [0, 1],\n", | |
"# then map each RGB channel to [-1, 1] via (x - mean) / std.\n", | |
"channel_mean = (0.5, 0.5, 0.5)\n", | |
"channel_std = (0.5, 0.5, 0.5)\n", | |
"\n", | |
"transform = transforms.Compose([\n", | |
"    transforms.Resize(64),\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize(channel_mean, channel_std),\n", | |
"])" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4cTpC03-gEyU" | |
}, | |
"source": [ | |
"# download=True lets a fresh runtime fetch the data (with download=False the\n", | |
"# constructors fail when ./data is empty); files already present are not\n", | |
"# downloaded again.\n", | |
"train_CIFAR10 = CIFAR10(root='./data',\n", | |
"                        train=True,\n", | |
"                        transform=transform,\n", | |
"                        download=True)\n", | |
"train_STL10 = STL10(root='./data',\n", | |
"                    split='train',\n", | |
"                    transform=transform,\n", | |
"                    download=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nT0kn4bqgagm" | |
}, | |
"source": [ | |
"# `dataset_a + dataset_b` builds a ConcatDataset: the first len(train_CIFAR10)\n", | |
"# indices map to CIFAR10 samples, the remaining ones to STL10 samples.\n", | |
"# NOTE(review): CIFAR10 and STL10 appear to number their classes differently\n", | |
"# (compare `train_CIFAR10.classes` with `train_STL10.classes`), so the same\n", | |
"# integer label may denote different classes; remap before training on this.\n", | |
"add_dataset = train_CIFAR10 + train_STL10\n", | |
"\n", | |
"batch_size = 2\n", | |
"train_loader = DataLoader(add_dataset, batch_size=batch_size, shuffle=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HO5cIIDdhL78", | |
"outputId": "754031ef-4979-4bda-afa1-63de90d97ad6" | |
}, | |
"source": [ | |
"# Peek at the first 22 shuffled batches (batch indices 0..21).\n", | |
"for batch_idx, (images, labels) in enumerate(train_loader):\n", | |
"    print(images.shape, labels)\n", | |
"    if batch_idx == 21:\n", | |
"        break" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([2, 3, 64, 64]) tensor([1, 6])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([6, 4])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([4, 4])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([5, 3])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 5])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([1, 9])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 3])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 8])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 3])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 1])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 8])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([4, 8])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 3])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 2])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([4, 1])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 1])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([1, 0])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 5])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 1])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 5])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([0, 0])\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 8])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "V3OD_PiqiSDM" | |
}, | |
"source": [ | |
"class DatasetNameMixin:\n", | |
"    \"\"\"Append a fixed dataset name to every (img, target) sample.\n", | |
"\n", | |
"    List this mixin *before* the torchvision dataset class in the bases so\n", | |
"    that `super().__getitem__` resolves to that dataset's own implementation.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    dataset_name = None  # overridden by each concrete subclass below\n", | |
"\n", | |
"    def __getitem__(self, index):\n", | |
"        img, target = super().__getitem__(index)\n", | |
"        return img, target, self.dataset_name\n", | |
"\n", | |
"\n", | |
"class MyCIFAR10(DatasetNameMixin, torchvision.datasets.CIFAR10):\n", | |
"    \"\"\"CIFAR10 whose samples are (img, target, 'CIFAR10').\"\"\"\n", | |
"    dataset_name = 'CIFAR10'\n", | |
"\n", | |
"\n", | |
"class MySTL10(DatasetNameMixin, torchvision.datasets.STL10):\n", | |
"    \"\"\"STL10 whose samples are (img, target, 'STL10').\"\"\"\n", | |
"    dataset_name = 'STL10'" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kYxzxGw_kGam" | |
}, | |
"source": [ | |
"# download=True lets a fresh runtime fetch the data (with download=False the\n", | |
"# constructors fail when ./data is empty); files already present are not\n", | |
"# downloaded again.\n", | |
"train_MyCIFAR10 = MyCIFAR10(root='./data',\n", | |
"                            train=True,\n", | |
"                            transform=transform,\n", | |
"                            download=True)\n", | |
"train_MySTL10 = MySTL10(root='./data',\n", | |
"                        split='train',\n", | |
"                        transform=transform,\n", | |
"                        download=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2W9l-gimkS3u", | |
"outputId": "6b28c20a-0c76-4a63-daef-ac1cfa00a1b1" | |
}, | |
"source": [ | |
"# Number of samples in each training dataset.\n", | |
"tuple(len(ds) for ds in (train_MyCIFAR10, train_MySTL10))" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(50000, 5000)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 90 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nsHbeBSslgY-" | |
}, | |
"source": [ | |
"# `+` builds a ConcatDataset over the two named datasets, so every sample\n", | |
"# drawn by the loader is an (img, target, dataset_name) triple.\n", | |
"add_dataset = train_MyCIFAR10 + train_MySTL10\n", | |
"\n", | |
"batch_size = 2\n", | |
"train_loader = DataLoader(add_dataset, batch_size=batch_size, shuffle=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "19-iuNFekUz4", | |
"outputId": "f74b56e9-b8cf-407c-c570-73dcdecee051" | |
}, | |
"source": [ | |
"# Each batch now also carries the per-sample dataset names, collated into a\n", | |
"# tuple of strings, so CIFAR10 and STL10 samples can be told apart.\n", | |
"for i, (data, label, dataset_name) in enumerate(train_loader):\n", | |
"    print(data.shape, label, dataset_name)\n", | |
"    if i > 40:\n", | |
"        break" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([2, 3, 64, 64]) tensor([4, 9]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 6]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 9]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 2]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([5, 8]) ('STL10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([1, 9]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('STL10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([6, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 3]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 8]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 5]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 2]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([1, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'STL10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 9]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 1]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 1]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([4, 2]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([1, 3]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 0]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([1, 1]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([6, 0]) ('CIFAR10', 'STL10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([6, 6]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([5, 1]) ('CIFAR10', 'STL10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([7, 6]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([5, 0]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([8, 8]) ('STL10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')\n", | |
"torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'CIFAR10')\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment