@beomjunshin-ben
Created December 18, 2018 16:35
tf.dataset tutorial
"""tf.dataset tutorial: compare orderings of map/shuffle/repeat/batch/prefetch."""
import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
class DatasetTutorial:
    def __init__(self):
        self.num_samples = 100   # toy dataset: integer ids 0..99
        self.batch_size = 10
        self.repeat = 2          # number of epochs to draw

    def _parse_function(self, image, label):
        # Identity map; stands in for real per-sample preprocessing.
        image *= 1
        label *= 1
        return image, label
    def ours(self):
        # map -> prefetch -> repeat -> shuffle(full) -> batch
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.batch(self.batch_size)

    def case1(self):
        # map -> repeat -> shuffle(full) -> batch -> prefetch
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case2(self):
        # map -> shuffle(full) -> repeat -> batch -> prefetch
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case3(self):
        # like case2, but the shuffle buffer holds only half of the samples
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 2)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case4(self):
        # like case2, but the shuffle buffer holds a quarter of the samples
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 4)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case5(self):
        # like case2, but the shuffle buffer holds a fifth of the samples
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 5)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case6(self):
        # shuffle(full) -> map -> batch -> shuffle(half, over batches) -> repeat -> prefetch
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 2)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case7(self):
        # map -> batch -> shuffle(over batches) -> repeat -> prefetch
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples // 2)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    def case8(self):
        # shuffle(full) -> map -> batch -> repeat -> prefetch
        self.dataset = self.dataset.shuffle(buffer_size=self.num_samples)
        self.dataset = self.dataset.map(self._parse_function)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.repeat(self.repeat)
        self.dataset = self.dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    def run(self, func):
        TF_SESSION_CONFIG = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True),
            log_device_placement=False,
            device_count={"GPU": 1})
        grid_size = 2
        fig1, axes1 = plt.subplots(nrows=grid_size, ncols=grid_size, figsize=(8, 8))
        fig2, axes2 = plt.subplots(nrows=grid_size, ncols=grid_size, figsize=(8, 8))
        # Run the same pipeline grid_size ** 2 times to see run-to-run variation.
        for n in range(grid_size ** 2):
            nrow = n % grid_size
            ncol = n // grid_size
            print(nrow, ncol)
            with tf.Session(config=TF_SESSION_CONFIG) as sess:
                images = tf.constant(np.arange(0, self.num_samples, 1))
                labels = tf.constant(np.arange(0, self.num_samples, 1))
                self.dataset = tf.data.Dataset.from_tensor_slices((images, labels))
                getattr(self, func)()
                iterator = self.dataset.make_one_shot_iterator()
                next_element = iterator.get_next()
                assert self.num_samples % self.batch_size == 0
                niter = self.num_samples // self.batch_size * self.repeat
                # footprints[s, i] == 1 iff sample id s appears in the batch drawn at step i.
                image_footprints = np.zeros(shape=(self.num_samples, niter))
                label_footprints = np.zeros(shape=(self.num_samples, niter))
                for i in range(niter):
                    image_batch, label_batch = sess.run(next_element)
                    image_footprints[image_batch, i] = 1
                    label_footprints[label_batch, i] = 1
                # Images and labels must stay paired regardless of the pipeline ordering.
                np.testing.assert_equal(image_footprints, label_footprints)
                sns.heatmap(image_footprints, ax=axes1[nrow, ncol])
                sns.heatmap(image_footprints.cumsum(axis=1), ax=axes2[nrow, ncol])
        fig1.tight_layout()
        fig1.savefig(f"{func}_1.png")
        fig2.tight_layout()
        fig2.savefig(f"{func}_2.png")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--method', default="ours", type=str,
                        help="one of: ours, case1 ... case8")
    parser.add_argument('--total_steps', default=None, type=int,
                        help="accepted but unused; steps are derived from the dataset size")
    args = parser.parse_args()
    dataset = DatasetTutorial()
    dataset.run(args.method)
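
A minimal usage sketch follows, assuming the script above is saved as tf_dataset_tutorial.py (the filename is not given in the gist). Each run executes the chosen pipeline grid_size ** 2 = 4 times and writes two heatmap grids to the working directory: <method>_1.png shows which sample ids appear in each batch step, and <method>_2.png shows their cumulative counts across steps.

# A minimal sketch, assuming the file name tf_dataset_tutorial.py (an assumption,
# not part of the gist). From a shell:
#   python tf_dataset_tutorial.py --method case3
# or programmatically:
from tf_dataset_tutorial import DatasetTutorial

tutorial = DatasetTutorial()
tutorial.run("case3")   # writes case3_1.png and case3_2.png to the working directory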
@wookayin

God-tier. ("갓갓")
