Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Last active May 4, 2021 12:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ariG23498/af76b2b0b2c59cb6eb9aaf90fa75793d to your computer and use it in GitHub Desktop.
Save ariG23498/af76b2b0b2c59cb6eb9aaf90fa75793d to your computer and use it in GitHub Desktop.
ConcatDataset
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/af76b2b0b2c59cb6eb9aaf90fa75793d/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lIYdn1woOS1n"
},
"source": [
"%%bash\n",
"# Toy layout: three roots (A, B, C), each holding three class sub-folders.\n",
"# `-p` creates missing parents and does not fail when the folders already exist,\n",
"# so this cell can be re-run (e.g. Restart & Run All) without a \"File exists\" error.\n",
"mkdir -p A/sub_1 A/sub_2 A/sub_3\n",
"mkdir -p B/sub_1 B/sub_2 B/sub_3\n",
"mkdir -p C/sub_1 C/sub_2 C/sub_3"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hMENWXMNOoIs"
},
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torchvision"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PhGdo-IUOlTZ"
},
"source": [
"# Fill every class folder with two random 20x20 images.\n",
"# A seeded generator keeps the images identical across Restart & Run All.\n",
"rng = np.random.default_rng(0)\n",
"\n",
"for folder in [\"A\", \"B\", \"C\"]:\n",
"    # os.listdir order is arbitrary; sorting fixes which image lands where\n",
"    for sub_folder in sorted(os.listdir(folder)):\n",
"        for i in range(2):\n",
"            img = rng.random((20, 20))\n",
"            plt.imsave(arr=img, fname=os.path.join(folder, sub_folder, f\"img_{i}.png\"))"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nEKVFXFhOwYD"
},
"source": [
"# One ImageFolder per root folder; its sub-folders become that dataset's classes.\n",
"A_dataset, B_dataset, C_dataset = (\n",
"    torchvision.datasets.ImageFolder(root=root, transform=torchvision.transforms.ToTensor())\n",
"    for root in (\"A\", \"B\", \"C\")\n",
")\n",
"\n",
"# ConcatDataset serves every sample of A, then of B, then of C.\n",
"all_datasets = [A_dataset, B_dataset, C_dataset]\n",
"final_training_dataset = torch.utils.data.ConcatDataset(all_datasets)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iuoV2AAKVQSC"
},
"source": [
"# Prefix each class name with its root folder so labels from different roots\n",
"# stay distinguishable after concatenation (e.g. \"sub_1\" -> \"A_sub_1\").\n",
"# The names are rebuilt from `class_to_idx`, which still holds the original folder\n",
"# names in label order, instead of being edited in place. Re-running this cell\n",
"# therefore does not produce \"A_A_sub_1\". `class_to_idx` is deliberately left untouched.\n",
"for prefix, dataset in ((\"A\", A_dataset), (\"B\", B_dataset), (\"C\", C_dataset)):\n",
"    original_names = sorted(dataset.class_to_idx, key=dataset.class_to_idx.get)\n",
"    dataset.classes = [f\"{prefix}_{name}\" for name in original_names]"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eOFGZN60VzvT",
"outputId": "6b8cc188-96b9-4823-9a16-c9bf63babe31"
},
"source": [
"# Display the renamed class lists of the three roots side by side.\n",
"(A_dataset.classes, B_dataset.classes, C_dataset.classes)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(['A_sub_1', 'A_sub_2', 'A_sub_3'],\n",
" ['B_sub_1', 'B_sub_2', 'B_sub_3'],\n",
" ['C_sub_1', 'C_sub_2', 'C_sub_3'])"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3OaidV76T_3L"
},
"source": [
"# One sample per batch, in concatenation order, so the batch index seen by the\n",
"# loop below equals the sample's position in `final_training_dataset`.\n",
"full_dl = torch.utils.data.DataLoader(\n",
"    final_training_dataset,\n",
"    batch_size=1,\n",
"    shuffle=False,\n",
")"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "r9pI7wKmO-5s",
"outputId": "29084755-8320-4456-e220-90556dad8353"
},
"source": [
"# Map every sample back to the dataset it came from and print its prefixed class name.\n",
"# `cumulative_sizes[k]` is the exclusive end index of dataset k inside the\n",
"# concatenation, so sample `idx` belongs to the first dataset with idx < end.\n",
"# Boundaries are exclusive: index len(A_dataset) is already the first sample of B.\n",
"for idx, (img, label) in enumerate(full_dl):\n",
"    source = next(\n",
"        dataset\n",
"        for dataset, end in zip(final_training_dataset.datasets, final_training_dataset.cumulative_sizes)\n",
"        if idx < end\n",
"    )\n",
"    # batch_size=1, so `label` is a one-element tensor; .item() gives the int index\n",
"    print(source.classes[label.item()])"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"A_sub_1\n",
"A_sub_1\n",
"A_sub_2\n",
"A_sub_2\n",
"A_sub_3\n",
"A_sub_3\n",
"B_sub_1\n",
"B_sub_1\n",
"B_sub_2\n",
"B_sub_2\n",
"B_sub_3\n",
"B_sub_3\n",
"C_sub_1\n",
"C_sub_1\n",
"C_sub_2\n",
"C_sub_2\n",
"C_sub_3\n",
"C_sub_3\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment