Last active
March 13, 2021 13:18
-
-
Save rahulvigneswaran/767e869ed4f0b410ffad4501d2252451 to your computer and use it in GitHub Desktop.
Custom_Subset [PyTorchForum].ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Custom_Subset [PyTorchForum].ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyONn1ybLf40q8KsNLvnGA3Z", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"70b56c14e8764e1a8bed97f37fd01db6": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_ab4117eece134a458a73c16d222523d3", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_2ec05e4ae3bb43a78c362e0920c60d0d", | |
"IPY_MODEL_1a16348b19234e3dbe5f05f6bd883a68" | |
] | |
} | |
}, | |
"ab4117eece134a458a73c16d222523d3": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"2ec05e4ae3bb43a78c362e0920c60d0d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_f50ab97ea82447abbf528c38d8e7eef5", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "success", | |
"max": 170498071, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 170498071, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_52be1de3b42449d78a293483b276b2ab" | |
} | |
}, | |
"1a16348b19234e3dbe5f05f6bd883a68": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_fa756216656b4158b947ee350282c9ff", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 170499072/? [00:05<00:00, 31798468.14it/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_10fd55cae20a44d283c2ddd59652c73c" | |
} | |
}, | |
"f50ab97ea82447abbf528c38d8e7eef5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"52be1de3b42449d78a293483b276b2ab": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"fa756216656b4158b947ee350282c9ff": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"10fd55cae20a44d283c2ddd59652c73c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/rahulvigneswaran/767e869ed4f0b410ffad4501d2252451/custom_subset-pytorchforum.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 101, | |
"referenced_widgets": [ | |
"70b56c14e8764e1a8bed97f37fd01db6", | |
"ab4117eece134a458a73c16d222523d3", | |
"2ec05e4ae3bb43a78c362e0920c60d0d", | |
"1a16348b19234e3dbe5f05f6bd883a68", | |
"f50ab97ea82447abbf528c38d8e7eef5", | |
"52be1de3b42449d78a293483b276b2ab", | |
"fa756216656b4158b947ee350282c9ff", | |
"10fd55cae20a44d283c2ddd59652c73c" | |
] | |
}, | |
"id": "rbaxZcvj8XK0", | |
"outputId": "078c798b-e8d0-4ee0-a2e4-919da99fbca9" | |
}, | |
"source": [ | |
"import torch\r\n", | |
"import torchvision\r\n", | |
"from torchvision import transforms\r\n", | |
"\r\n", | |
"# ToTensor() turns the PIL images CIFAR10 yields into tensors with values in [0, 1];\r\n", | |
"# Normalize with mean = std = 0.5 then maps each RGB channel to [-1, 1] via (x - 0.5) / 0.5.\r\n", | |
"transform = transforms.Compose([\r\n", | |
"    transforms.ToTensor(),\r\n", | |
"    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\r\n", | |
"])\r\n", | |
"\r\n", | |
"# CIFAR-10 training split, stored under ./data (downloaded if not already present).\r\n", | |
"trainset = torchvision.datasets.CIFAR10(\r\n", | |
"    root='./data', train=True, download=True, transform=transform,\r\n", | |
")" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "70b56c14e8764e1a8bed97f37fd01db6", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Extracting ./data/cifar-10-python.tar.gz to ./data\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "T4Tij92F-uvp" | |
}, | |
"source": [ | |
"class YourCustomSubset(torch.utils.data.Subset):\r\n", | |
"    r\"\"\"\r\n", | |
"    Same as torch.utils.data.Subset, but each item is returned as ``(sample, dataset_index)``.\r\n", | |
"\r\n", | |
"    ``dataset_index`` is ``self.indices[idx]``: the sample's index in the *underlying* dataset,\r\n", | |
"    not its position inside this subset. Indices gathered from a batch can therefore be mixed\r\n", | |
"    with other indices into the same dataset to build the next subset.\r\n", | |
"    \"\"\"\r\n", | |
"\r\n", | |
"    def __getitem__(self, idx):\r\n", | |
"        # Return the index into self.dataset, not the subset-local position idx.\r\n", | |
"        dataset_idx = self.indices[idx]\r\n", | |
"        return self.dataset[dataset_idx], dataset_idx\r\n", | |
"\r\n", | |
"    def __getitems__(self, idxs):\r\n", | |
"        # Newer PyTorch versions give Subset a batched __getitems__ that the DataLoader prefers\r\n", | |
"        # over __getitem__; route batches through __getitem__ so the index is still returned.\r\n", | |
"        return [self[idx] for idx in idxs]" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "BhW3JO_fAq49", | |
"outputId": "a251051d-1181-42bf-819e-e3c8a10d9d22" | |
}, | |
"source": [ | |
"# Start from 100 random training samples.\r\n", | |
"indices = torch.randperm(len(trainset))[:100]\r\n", | |
"print(f\"Indices Shape: {indices.shape}\")\r\n", | |
"new_data = YourCustomSubset(trainset, indices)\r\n", | |
"# Same batch size as the loaders rebuilt inside the training loop; 400 would put\r\n", | |
"# the whole 100-sample subset into a single batch for the first epoch only.\r\n", | |
"trainloader = torch.utils.data.DataLoader(new_data, batch_size=4,\r\n", | |
"                                          shuffle=True, num_workers=2)" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Indices Shape: torch.Size([100])\n", | |
"0\n", | |
"1\n", | |
"2\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "EMfqR3xrA8mp", | |
"outputId": "e8f1ff51-a08e-440d-c3c2-08ddb028ab87" | |
}, | |
"source": [ | |
"for epoch in range(2):\r\n", | |
"    # Indices (as yielded by the loader) of samples that meet the selection criterion this epoch.\r\n", | |
"    best_ind = []\r\n", | |
"    for i, ((x, y), ind) in enumerate(trainloader):\r\n", | |
"        # Some training here\r\n", | |
"\r\n", | |
"        # Example criterion: keep samples whose label is > 5 (replace with any condition).\r\n", | |
"        best_ind.append(ind[y > 5])\r\n", | |
"\r\n", | |
"    # Next subset = first 20 selected indices + 80 random indices into trainset.\r\n", | |
"    # NOTE: this mixes the two, so the loader must yield indices into trainset.\r\n", | |
"    best_ind = torch.hstack(best_ind)\r\n", | |
"    required_ind = best_ind[:20]\r\n", | |
"    indices = torch.cat((torch.randperm(len(trainset))[:80], required_ind))\r\n", | |
"\r\n", | |
"    print(f\"Indices Shape for epoch {epoch}: {indices.shape}\")\r\n", | |
"    new_data = YourCustomSubset(trainset, indices)\r\n", | |
"    trainloader = torch.utils.data.DataLoader(\r\n", | |
"        new_data, batch_size=4, shuffle=True, num_workers=2\r\n", | |
"    )" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"82\n", | |
"23\n", | |
"11\n", | |
"59\n", | |
"93\n", | |
"42\n", | |
"45\n", | |
"18\n", | |
"34\n", | |
"79\n", | |
"65\n", | |
"40\n", | |
"56\n", | |
"57\n", | |
"27\n", | |
"24\n", | |
"31\n", | |
"90\n", | |
"22\n", | |
"0\n", | |
"61\n", | |
"68\n", | |
"73\n", | |
"14\n", | |
"39\n", | |
"16\n", | |
"60\n", | |
"94\n", | |
"71\n", | |
"49\n", | |
"97\n", | |
"85\n", | |
"12\n", | |
"6\n", | |
"91\n", | |
"41\n", | |
"83\n", | |
"66\n", | |
"52\n", | |
"99\n", | |
"69\n", | |
"1\n", | |
"98\n", | |
"76\n", | |
"84\n", | |
"30\n", | |
"21\n", | |
"15\n", | |
"13\n", | |
"74\n", | |
"86\n", | |
"47\n", | |
"26\n", | |
"53\n", | |
"19\n", | |
"62\n", | |
"92\n", | |
"2\n", | |
"5\n", | |
"63\n", | |
"77\n", | |
"87\n", | |
"80\n", | |
"9\n", | |
"46\n", | |
"72\n", | |
"25\n", | |
"4\n", | |
"20\n", | |
"8\n", | |
"95\n", | |
"67\n", | |
"51\n", | |
"36\n", | |
"50\n", | |
"81\n", | |
"28\n", | |
"29\n", | |
"58\n", | |
"48\n", | |
"43\n", | |
"44\n", | |
"38\n", | |
"75\n", | |
"89\n", | |
"54\n", | |
"88\n", | |
"7\n", | |
"96\n", | |
"70\n", | |
"33\n", | |
"17\n", | |
"35\n", | |
"3\n", | |
"10\n", | |
"55\n", | |
"64\n", | |
"78\n", | |
"32\n", | |
"37\n", | |
"Indices Shape for epoch 0: torch.Size([100])\n", | |
"68\n", | |
"59\n", | |
"53\n", | |
"44\n", | |
"88\n", | |
"5\n", | |
"19\n", | |
"38\n", | |
"67\n", | |
"20\n", | |
"8\n", | |
"84\n", | |
"94\n", | |
"22\n", | |
"78\n", | |
"63\n", | |
"86\n", | |
"24\n", | |
"92\n", | |
"75\n", | |
"37\n", | |
"91\n", | |
"0\n", | |
"29\n", | |
"98\n", | |
"25\n", | |
"48\n", | |
"10\n", | |
"90\n", | |
"79\n", | |
"2\n", | |
"45\n", | |
"15\n", | |
"65\n", | |
"14\n", | |
"85\n", | |
"83\n", | |
"35\n", | |
"16\n", | |
"95\n", | |
"73\n", | |
"40\n", | |
"50\n", | |
"33\n", | |
"80\n", | |
"18\n", | |
"60\n", | |
"93\n", | |
"49\n", | |
"77\n", | |
"46\n", | |
"28\n", | |
"74\n", | |
"7\n", | |
"64\n", | |
"47\n", | |
"4\n", | |
"1\n", | |
"81\n", | |
"54\n", | |
"3\n", | |
"12\n", | |
"99\n", | |
"56\n", | |
"21\n", | |
"27\n", | |
"82\n", | |
"62\n", | |
"23\n", | |
"71\n", | |
"13\n", | |
"70\n", | |
"51\n", | |
"96\n", | |
"30\n", | |
"17\n", | |
"57\n", | |
"32\n", | |
"34\n", | |
"66\n", | |
"76\n", | |
"52\n", | |
"72\n", | |
"42\n", | |
"11\n", | |
"26\n", | |
"43\n", | |
"55\n", | |
"9\n", | |
"89\n", | |
"87\n", | |
"41\n", | |
"36\n", | |
"39\n", | |
"31\n", | |
"97\n", | |
"69\n", | |
"58\n", | |
"6\n", | |
"61\n", | |
"Indices Shape for epoch 1: torch.Size([100])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bIPi04FCC8I4" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment