@cinjon · Last active March 27, 2020
import os
import gin
from meta_dataset.data import config
from meta_dataset.data import dataset_spec as dataset_spec_lib
from meta_dataset.data import learning_spec
from meta_dataset.data import pipeline
import numpy as np
import tensorflow as tf
import torch
GIN_FILE_PATH = 'metadataset/meta_dataset/learn/gin/setups/data_config.gin'
ALL_DATASETS = [
    'aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012', 'omniglot',
    'quickdraw', 'vgg_flower'
]
gin.parse_config_file(GIN_FILE_PATH)
# Comment out to disable eager execution.
tf.enable_eager_execution()
# Conversion helpers. Meta-Dataset yields NHWC image arrays; PyTorch expects
# NCHW, so the transpose permutation is (0, 3, 1, 2). The np_* variants
# accept eager TF tensors (they call .numpy() first).
np_to_torch_labels = lambda a: torch.from_numpy(a.numpy()).long()
np_to_torch_imgs = lambda a: torch.from_numpy(
    np.transpose(a.numpy(), (0, 3, 1, 2)))
to_torch_labels = lambda a: torch.from_numpy(a).long()
to_torch_imgs = lambda a: torch.from_numpy(np.transpose(a, (0, 3, 1, 2)))
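# Sanity-check sketch for the shape convention above (dummy data, purely
# illustrative): a fake NHWC batch of two 84x84 RGB images should come out
# as an NCHW torch tensor.
_dummy_nhwc = np.zeros((2, 84, 84, 3), dtype=np.float32)
assert tuple(to_torch_imgs(_dummy_nhwc).shape) == (2, 3, 84, 84)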
def iterate_dataset(dataset, num_batches, batch_size):
  """Yields (support images, support labels, query images, query labels).

  In graph mode, yields num_batches single episodes (batch_size is ignored).
  In eager mode, stacks batch_size episodes into each yielded batch.
  """
  if not tf.executing_eagerly():
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
      for idx in range(num_batches):
        # episode is (support_images, support_labels, support_class_ids,
        #             query_images, query_labels, query_class_ids).
        episode, source_id = sess.run(next_element)
        yield (to_torch_imgs(episode[0]), to_torch_labels(episode[1]),
               to_torch_imgs(episode[3]), to_torch_labels(episode[4]))
  else:
    batch_count = 0
    curr_batch = []
    for idx, (episode, source_id) in enumerate(dataset):
      if batch_count == num_batches:
        break
      batch_entry = [
          np_to_torch_imgs(episode[0]), np_to_torch_labels(episode[1]),
          np_to_torch_imgs(episode[3]), np_to_torch_labels(episode[4])
      ]
      curr_batch.append(batch_entry)
      if len(curr_batch) == batch_size:
        data_support = torch.stack([k[0] for k in curr_batch])
        labels_support = torch.stack([k[1] for k in curr_batch])
        data_query = torch.stack([k[2] for k in curr_batch])
        labels_query = torch.stack([k[3] for k in curr_batch])
        curr_batch = []
        batch_count += 1
        yield data_support, labels_support, data_query, labels_query
def pytorch_loader(fixed=True,
                   train=False,
                   test=False,
                   valid=False,
                   dataset=None,
                   num_support=None,
                   num_ways=None,
                   num_query=None,
                   batch_size=16,
                   num_batches=2,
                   base_path=None):
  """PyTorch loader.

  We use the fixed ways and shots approach (the `fixed` flag is unused here).
  See the repo for the others.
  """
  print('Dataset: ', dataset)
  if not train and not test and not valid:
    raise ValueError('Exactly one of train, test, or valid must be True.')
  if train:
    split = learning_spec.Split.TRAIN
  elif test:
    split = learning_spec.Split.TEST
  elif valid:
    split = learning_spec.Split.VALID
  dataset_records_path = os.path.join(base_path, dataset)
  dataset_spec = [dataset_spec_lib.load_dataset_spec(dataset_records_path)]
  fixed_ways_shots = config.EpisodeDescriptionConfig(num_ways=num_ways,
                                                     num_support=num_support,
                                                     num_query=num_query)
  # The ontology lists must match dataset_spec_list in length; there is a
  # single spec here. (In the full multi-source setup, ilsvrc_2012 uses the
  # DAG ontology and omniglot the bilevel one.)
  dataset = pipeline.make_multisource_episode_pipeline(
      dataset_spec_list=dataset_spec,
      use_dag_ontology_list=[False],
      use_bilevel_ontology_list=[False],
      split=split,
      image_size=84,
      episode_descr_config=fixed_ways_shots)
  return iterate_dataset(dataset, num_batches, batch_size)
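# Minimal usage sketch. BASE_PATH is a placeholder for wherever the
# Meta-Dataset TFRecords were converted; the 5-way 1-shot numbers below are
# illustrative, not prescribed by the gist.
if __name__ == '__main__':
  BASE_PATH = '/path/to/meta_dataset/records'
  loader = pytorch_loader(train=True,
                          dataset='omniglot',
                          num_ways=5,
                          num_support=1,
                          num_query=15,
                          batch_size=4,
                          num_batches=2,
                          base_path=BASE_PATH)
  for data_support, labels_support, data_query, labels_query in loader:
    # In eager mode each tensor is stacked over batch_size episodes,
    # e.g. data_support is [4, 5, 3, 84, 84] for this 5-way 1-shot config.
    print(data_support.shape, labels_support.shape)
    print(data_query.shape, labels_query.shape)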