Skip to content

Instantly share code, notes, and snippets.

@rahulvigneswaran
Last active March 13, 2021 13:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rahulvigneswaran/767e869ed4f0b410ffad4501d2252451 to your computer and use it in GitHub Desktop.
Save rahulvigneswaran/767e869ed4f0b410ffad4501d2252451 to your computer and use it in GitHub Desktop.
Custom_Subset [PyTorchForum].ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Custom_Subset [PyTorchForum].ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyONn1ybLf40q8KsNLvnGA3Z",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"70b56c14e8764e1a8bed97f37fd01db6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_ab4117eece134a458a73c16d222523d3",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_2ec05e4ae3bb43a78c362e0920c60d0d",
"IPY_MODEL_1a16348b19234e3dbe5f05f6bd883a68"
]
}
},
"ab4117eece134a458a73c16d222523d3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"2ec05e4ae3bb43a78c362e0920c60d0d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_f50ab97ea82447abbf528c38d8e7eef5",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 170498071,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 170498071,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_52be1de3b42449d78a293483b276b2ab"
}
},
"1a16348b19234e3dbe5f05f6bd883a68": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_fa756216656b4158b947ee350282c9ff",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 170499072/? [00:05<00:00, 31798468.14it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_10fd55cae20a44d283c2ddd59652c73c"
}
},
"f50ab97ea82447abbf528c38d8e7eef5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"52be1de3b42449d78a293483b276b2ab": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"fa756216656b4158b947ee350282c9ff": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"10fd55cae20a44d283c2ddd59652c73c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rahulvigneswaran/767e869ed4f0b410ffad4501d2252451/custom_subset-pytorchforum.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 101,
"referenced_widgets": [
"70b56c14e8764e1a8bed97f37fd01db6",
"ab4117eece134a458a73c16d222523d3",
"2ec05e4ae3bb43a78c362e0920c60d0d",
"1a16348b19234e3dbe5f05f6bd883a68",
"f50ab97ea82447abbf528c38d8e7eef5",
"52be1de3b42449d78a293483b276b2ab",
"fa756216656b4158b947ee350282c9ff",
"10fd55cae20a44d283c2ddd59652c73c"
]
},
"id": "rbaxZcvj8XK0",
"outputId": "078c798b-e8d0-4ee0-a2e4-919da99fbca9"
},
"source": [
"import torch\r\n",
"import torchvision\r\n",
"import torchvision.transforms as transforms\r\n",
"\r\n",
"transform = transforms.Compose(\r\n",
" [transforms.ToTensor(),\r\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\r\n",
"\r\n",
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\r\n",
" download=True, transform=transform)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "70b56c14e8764e1a8bed97f37fd01db6",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n",
"Extracting ./data/cifar-10-python.tar.gz to ./data\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T4Tij92F-uvp"
},
"source": [
"class YourCustomSubset(torch.utils.data.Subset):\r\n",
" r\"\"\"\r\n",
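" Same as torch.utils.data.Subset, but each item is returned as (sample, idx), where idx is the\r\n",
" sample's position within this subset, so you can choose which samples to keep.\r\n",
" Note: idx is NOT an index into the underlying dataset; map it back with self.indices[idx]\r\n",
" before using it to build a new Subset of that dataset.\r\n",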
" \"\"\"\r\n",
"\r\n",
" def __getitem__(self, idx):\r\n",
" print(idx)\r\n",
" return self.dataset[self.indices[idx]], idx"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BhW3JO_fAq49",
"outputId": "a251051d-1181-42bf-819e-e3c8a10d9d22"
},
"source": [
"indices = torch.randperm(len(trainset))[:100]\r\n",
"print(f\"Indices Shape: {indices.shape}\")\r\n",
"new_data = YourCustomSubset(trainset, indices)\r\n",
"trainloader = torch.utils.data.DataLoader(new_data, batch_size=400,\r\n",
" shuffle=True, num_workers=2)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"Indices Shape: torch.Size([100])\n",
"0\n",
"1\n",
"2\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EMfqR3xrA8mp",
"outputId": "e8f1ff51-a08e-440d-c3c2-08ddb028ab87"
},
"source": [
"for epoch in range(2):\r\n",
" best_ind = []\r\n",
" for i, ((x, y), ind) in enumerate(trainloader):\r\n",
" \r\n",
" # Some training here\r\n",
"\r\n",
" best_ind.append(ind[y>5]) # Collect the positions (within new_data) of all samples with label > 5; use whatever condition you want here.\r\n",
"\r\n",
" best_ind = torch.hstack(best_ind)\r\n",
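" required_ind = best_ind[:20] # Keep only the first 20; any selection rule works. NOTE: these are positions in the old subset, not trainset indices; convert with new_data.indices[required_ind] first.\r\n",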
" indices = torch.randperm(len(trainset))[:80]\r\n",
" indices = torch.cat((indices,required_ind))\r\n",
"\r\n",
" print(f\"Indices Shape for epoch {epoch}: {indices.shape}\") \r\n",
" new_data = YourCustomSubset(trainset, indices)\r\n",
" trainloader = torch.utils.data.DataLoader(new_data, batch_size=4,\r\n",
" shuffle=True, num_workers=2)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"82\n",
"23\n",
"11\n",
"59\n",
"93\n",
"42\n",
"45\n",
"18\n",
"34\n",
"79\n",
"65\n",
"40\n",
"56\n",
"57\n",
"27\n",
"24\n",
"31\n",
"90\n",
"22\n",
"0\n",
"61\n",
"68\n",
"73\n",
"14\n",
"39\n",
"16\n",
"60\n",
"94\n",
"71\n",
"49\n",
"97\n",
"85\n",
"12\n",
"6\n",
"91\n",
"41\n",
"83\n",
"66\n",
"52\n",
"99\n",
"69\n",
"1\n",
"98\n",
"76\n",
"84\n",
"30\n",
"21\n",
"15\n",
"13\n",
"74\n",
"86\n",
"47\n",
"26\n",
"53\n",
"19\n",
"62\n",
"92\n",
"2\n",
"5\n",
"63\n",
"77\n",
"87\n",
"80\n",
"9\n",
"46\n",
"72\n",
"25\n",
"4\n",
"20\n",
"8\n",
"95\n",
"67\n",
"51\n",
"36\n",
"50\n",
"81\n",
"28\n",
"29\n",
"58\n",
"48\n",
"43\n",
"44\n",
"38\n",
"75\n",
"89\n",
"54\n",
"88\n",
"7\n",
"96\n",
"70\n",
"33\n",
"17\n",
"35\n",
"3\n",
"10\n",
"55\n",
"64\n",
"78\n",
"32\n",
"37\n",
"Indices Shape for epoch 0: torch.Size([100])\n",
"68\n",
"59\n",
"53\n",
"44\n",
"88\n",
"5\n",
"19\n",
"38\n",
"67\n",
"20\n",
"8\n",
"84\n",
"94\n",
"22\n",
"78\n",
"63\n",
"86\n",
"24\n",
"92\n",
"75\n",
"37\n",
"91\n",
"0\n",
"29\n",
"98\n",
"25\n",
"48\n",
"10\n",
"90\n",
"79\n",
"2\n",
"45\n",
"15\n",
"65\n",
"14\n",
"85\n",
"83\n",
"35\n",
"16\n",
"95\n",
"73\n",
"40\n",
"50\n",
"33\n",
"80\n",
"18\n",
"60\n",
"93\n",
"49\n",
"77\n",
"46\n",
"28\n",
"74\n",
"7\n",
"64\n",
"47\n",
"4\n",
"1\n",
"81\n",
"54\n",
"3\n",
"12\n",
"99\n",
"56\n",
"21\n",
"27\n",
"82\n",
"62\n",
"23\n",
"71\n",
"13\n",
"70\n",
"51\n",
"96\n",
"30\n",
"17\n",
"57\n",
"32\n",
"34\n",
"66\n",
"76\n",
"52\n",
"72\n",
"42\n",
"11\n",
"26\n",
"43\n",
"55\n",
"9\n",
"89\n",
"87\n",
"41\n",
"36\n",
"39\n",
"31\n",
"97\n",
"69\n",
"58\n",
"6\n",
"61\n",
"Indices Shape for epoch 1: torch.Size([100])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bIPi04FCC8I4"
},
"source": [
""
],
"execution_count": 8,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment