@V0XNIHILI
Created April 8, 2023 13:29
Calculate the Omniglot Vinyals train, validation, and test splits from the Prototypical Networks code, for use with PyTorch
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate indices per split from the Vinyals Omniglot split\n",
"\n",
"As used in the Prototypical networks paper."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import urllib\n",
"\n",
"test_url = \"https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/test.txt\"\n",
"val_url = \"https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/val.txt\"\n",
"\n",
"with urllib.request.urlopen(test_url) as response:\n",
" test_html = response.read()\n",
"\n",
"with urllib.request.urlopen(val_url) as response:\n",
" val_html = response.read()\n",
"\n",
"# Parse text\n",
"test_entries = test_html.decode(\"utf-8\").splitlines()\n",
"val_entries = val_html.decode(\"utf-8\").splitlines()\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Only keep characters that end in rot000 and remove the rot000\n",
"test_entries = [x[:-6] for x in test_entries if x.endswith(\"rot000\")]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Only keep characters that end in rot000 and remove the rot000\n",
"val_entries = [x[:-6] for x in val_entries if x.endswith(\"rot000\")]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torchvision"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"background_set = torchvision.datasets.Omniglot('../datasets/data', background=True, download=True)\n",
"non_background_set = torchvision.datasets.Omniglot('../datasets/data', background=False, download=True)"
]
},
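{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity-check sketch (not in the original gist): torchvision's background\n",
"# split should hold 964 characters and the evaluation split 659, with 20 images\n",
"# each, which is where the 964*20 offset and the 1623*20 total used below come from.\n",
"print(len(background_set), len(non_background_set))\n",
"assert len(background_set) == 964 * 20\n",
"assert len(non_background_set) == 659 * 20\n",
"assert len(background_set) + len(non_background_set) == 1623 * 20"
]
},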
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from os.path import join"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"test_indices = []\n",
"\n",
"for i in range(len(non_background_set)):\n",
" image_name, character_class = non_background_set._flat_character_images[i]\n",
" image_path = join(non_background_set.target_folder, non_background_set._characters[character_class], image_name)\n",
"\n",
" for entry in test_entries:\n",
" if entry in image_path:\n",
" test_indices.append(i + 964*20)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"val_indices = []\n",
"\n",
"for i in range(len(background_set)):\n",
" image_name, character_class = background_set._flat_character_images[i]\n",
" image_path = join(background_set.target_folder, background_set._characters[character_class], image_name)\n",
"\n",
" for entry in val_entries:\n",
" if entry in image_path:\n",
" val_indices.append(i)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Train indices are all the remaining indices between 0 and 1623*20\n",
"train_indices = list(set(range(1623*20)) - set(val_indices) - set(test_indices))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Save the indices to a pickle file\n",
"\n",
"import pickle\n",
"\n",
"with open('omniglot_indices.pkl', 'wb') as f:\n",
" pickle.dump((train_indices, val_indices, test_indices), f)\n"
]
},
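{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check (sketch, not in the original gist): reload the pickled indices and\n",
"# verify that the three splits are pairwise disjoint and together cover the full\n",
"# 1623*20 index range of the concatenated dataset.\n",
"with open('omniglot_indices.pkl', 'rb') as f:\n",
"    loaded_train, loaded_val, loaded_test = pickle.load(f)\n",
"\n",
"assert set(loaded_train).isdisjoint(loaded_val)\n",
"assert set(loaded_train).isdisjoint(loaded_test)\n",
"assert set(loaded_val).isdisjoint(loaded_test)\n",
"assert len(loaded_train) + len(loaded_val) + len(loaded_test) == 1623 * 20\n",
"print(len(loaded_train), len(loaded_val), len(loaded_test))"
]
},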
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"all_splits_data = torch.utils.data.ConcatDataset([background_set, non_background_set])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Subset"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"train_set = Subset(all_splits_data, train_indices)\n",
"val_set = Subset(all_splits_data, val_indices)\n",
"test_set = Subset(all_splits_data, test_indices)"
]
}
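,
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added usage sketch (not in the original gist): each Subset behaves like a regular\n",
"# torch dataset and yields (PIL image, character-class index) pairs, since the\n",
"# Omniglot datasets above were created without a transform. Note that ConcatDataset\n",
"# does not remap labels, so background and evaluation classes share numeric label\n",
"# ranges; pass a transform (e.g. torchvision.transforms.ToTensor()) to the Omniglot\n",
"# constructors before wrapping the subsets in DataLoaders for batched training.\n",
"print(len(train_set), len(val_set), len(test_set))\n",
"\n",
"image, label = train_set[0]\n",
"print(image.size, label)"
]
}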
],
"metadata": {
"kernelspec": {
"display_name": "meta_learning_arena",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}