Skip to content

Instantly share code, notes, and snippets.

@staceysv
Last active April 10, 2020 17:39
Show Gist options
  • Save staceysv/e96f0e1a7b8fd7aafa6eb8937a95b37a to your computer and use it in GitHub Desktop.
Save staceysv/e96f0e1a7b8fd7aafa6eb8937a95b37a to your computer and use it in GitHub Desktop.
Organize mini-ImageNet files for MAML training.
#!/usr/bin/env python
import csv
from PIL import Image
import pickle
import os
# Side length, in pixels, of the square that every exported image is resized to.
img_size = 84
# CSV split files: a header row followed by (image_filename, class_name) rows.
test_csv_file = "../../Code/few-shot-ssl-public/fewshot/data/mini_imagenet_split/Ravi/test.csv"
train_csv_file = "../../Code/few-shot-ssl-public/fewshot/data/mini_imagenet_split/Ravi/train.csv"
val_csv_file = "../../Code/few-shot-ssl-public/fewshot/data/mini_imagenet_split/Ravi/val.csv"
# Pickled caches, one per split; each is read for its "image_data" entry.
test_pkl = "mini-imagenet-cache-test.pkl"
train_pkl = "mini-imagenet-cache-train.pkl"
val_pkl = "mini-imagenet-cache-val.pkl"
def pkl_to_raw(csv_file, pkl_file, dirname):
    """Unpack a pickled image cache into per-class image folders.

    Reads the ``"image_data"`` entry of ``pkl_file`` and walks the rows of
    ``csv_file`` (after its header) in order; the n-th data row is paired
    with the n-th entry of ``image_data``.  Each image is resized to
    ``img_size`` x ``img_size`` and saved as
    ``<dirname>/<class_name>/<image_filename>``.

    NOTE(review): this assumes the pickle stores images in exactly the same
    order as the CSV rows -- confirm against the tool that built the cache.

    Args:
        csv_file: Path to a CSV whose rows are ``image_filename, class_name``
            preceded by one header row.
        pkl_file: Path to the pickle holding an ``"image_data"`` sequence of
            arrays accepted by ``Image.fromarray``.
        dirname: Output directory; created (with parents) if missing.

    Raises:
        ValueError: If a non-empty CSV row does not have exactly two columns.
        IndexError: If the CSV has more data rows than the pickle has images.
    """
    if not os.path.isdir(dirname):
        # makedirs also handles nested output paths and a concurrent creator.
        os.makedirs(dirname, exist_ok=True)
        print("made: ", dirname)

    # SECURITY: pickle.load can execute arbitrary code -- only point this at
    # cache files from a trusted source.
    with open(pkl_file, "rb") as pkl_handle:
        imgs = pickle.load(pkl_handle)["image_data"]

    created_classes = set()
    count = 0
    # newline="" is what the csv module expects for correct quoted-field parsing.
    with open(csv_file, newline="") as csv_handle:
        rows = csv.reader(csv_handle)
        next(rows, None)  # skip the header row (no-op on an empty file)
        for row in rows:
            if not row:
                # Blank lines carry no image, so they must not advance the index.
                continue
            image_filename, class_name = row

            im = Image.fromarray(imgs[count])
            # resize as in maml source code
            im = im.resize((img_size, img_size), resample=Image.LANCZOS)

            # Make one subdirectory per label; tolerate re-runs and labels
            # whose rows are not contiguous in the CSV.
            class_dir = os.path.join(dirname, class_name)
            if class_name not in created_classes:
                os.makedirs(class_dir, exist_ok=True)
                created_classes.add(class_name)

            # save image file into correct class folder
            im.save(os.path.join(class_dir, image_filename))

            count += 1
            if count % 500 == 0:
                print(count)
# Export each split (train, then val, then test) into its own output folder.
for split_csv, split_pkl, split_dir in (
    (train_csv_file, train_pkl, "train"),
    (val_csv_file, val_pkl, "val"),
    (test_csv_file, test_pkl, "test"),
):
    pkl_to_raw(split_csv, split_pkl, split_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment