Skip to content

Instantly share code, notes, and snippets.

@tttamaki
Created August 29, 2021 22:49
Show Gist options
  • Save tttamaki/18dc500448fccc351624e7c27f49f154 to your computer and use it in GitHub Desktop.
Save tttamaki/18dc500448fccc351624e7c27f49f154 to your computer and use it in GitHub Desktop.
pytorch dataset addition.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "pytorch dataset addition.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMCXS/+39ZR5jm3KSp6pm5m",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tttamaki/18dc500448fccc351624e7c27f49f154/pytorch-dataset-addition.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Moxr499jf8Rf"
},
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"from torchvision.datasets import CIFAR10, STL10\n",
"from torch.utils.data import DataLoader"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gFKip8yRgCVU",
"outputId": "91af5163-42ae-40a5-a064-d036d788cb65"
},
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"use_cuda = torch.cuda.is_available()\n",
"cudnn.benchmark = True\n",
"print('Use CUDA:', use_cuda)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Use CUDA: False\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ipF8x-Mkjkwr"
},
"source": [
"transform = transforms.Compose([\n",
" transforms.Resize(64),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), \n",
" (0.5, 0.5, 0.5)),\n",
"])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4cTpC03-gEyU"
},
"source": [
"train_CIFAR10 = CIFAR10(root='./data/',\n",
" train=True,\n",
" transform=transform,\n",
" download=False)\n",
"train_STL10 = STL10(root='./data',\n",
" download=False,\n",
" split='train',\n",
" transform=transform\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nT0kn4bqgagm"
},
"source": [
"add_dataset = train_CIFAR10 + train_STL10\n",
"\n",
"batch_size = 2\n",
"\n",
"train_loader = DataLoader(add_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True)\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HO5cIIDdhL78",
"outputId": "754031ef-4979-4bda-afa1-63de90d97ad6"
},
"source": [
"for i, (data, label) in enumerate(train_loader):\n",
" print(data.shape, label)\n",
" if i > 20:\n",
" break"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([2, 3, 64, 64]) tensor([1, 6])\n",
"torch.Size([2, 3, 64, 64]) tensor([6, 4])\n",
"torch.Size([2, 3, 64, 64]) tensor([4, 4])\n",
"torch.Size([2, 3, 64, 64]) tensor([5, 3])\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 5])\n",
"torch.Size([2, 3, 64, 64]) tensor([1, 9])\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 3])\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 8])\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 3])\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 1])\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 8])\n",
"torch.Size([2, 3, 64, 64]) tensor([4, 8])\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 3])\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 2])\n",
"torch.Size([2, 3, 64, 64]) tensor([4, 1])\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 1])\n",
"torch.Size([2, 3, 64, 64]) tensor([1, 0])\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 5])\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 1])\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 5])\n",
"torch.Size([2, 3, 64, 64]) tensor([0, 0])\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 8])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "V3OD_PiqiSDM"
},
"source": [
"class MyCIFAR10(torchvision.datasets.CIFAR10):\n",
"\n",
" def __getitem__(self, index):\n",
" img, target = super().__getitem__(index)\n",
" return img, target, 'CIFAR10'\n",
"\n",
"class MySTL10(torchvision.datasets.STL10):\n",
"\n",
" def __getitem__(self, index):\n",
" img, target = super().__getitem__(index)\n",
" return img, target, 'STL10'"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kYxzxGw_kGam"
},
"source": [
"train_MyCIFAR10 = MyCIFAR10(root='./data/',\n",
" train=True, \n",
" transform=transform,\n",
" download=False)\n",
"train_MySTL10 = MySTL10(root='./data', \n",
" download=False, \n",
" split='train',\n",
" transform=transform\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2W9l-gimkS3u",
"outputId": "6b28c20a-0c76-4a63-daef-ac1cfa00a1b1"
},
"source": [
"len(train_MyCIFAR10), len(train_MySTL10)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(50000, 5000)"
]
},
"metadata": {
"tags": []
},
"execution_count": 90
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "nsHbeBSslgY-"
},
"source": [
"add_dataset = train_MyCIFAR10 + train_MySTL10\n",
"\n",
"batch_size = 2\n",
"\n",
"train_loader = DataLoader(add_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19-iuNFekUz4",
"outputId": "f74b56e9-b8cf-407c-c570-73dcdecee051"
},
"source": [
"for i, (data, label, dataset_name) in enumerate(train_loader):\n",
" print(data.shape, label, dataset_name)\n",
" if i > 40:\n",
" break"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([2, 3, 64, 64]) tensor([4, 9]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 6]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 9]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 2]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([5, 8]) ('STL10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([1, 9]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('STL10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([6, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 3]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 8]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 5]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 2]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([1, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'STL10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 9]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 1]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 1]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([4, 2]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([1, 3]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 0]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([1, 1]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([6, 0]) ('CIFAR10', 'STL10')\n",
"torch.Size([2, 3, 64, 64]) tensor([6, 6]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([5, 1]) ('CIFAR10', 'STL10')\n",
"torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([7, 6]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([5, 0]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([8, 8]) ('STL10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')\n",
"torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'CIFAR10')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QMNe23LslF1u"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment