Skip to content

Instantly share code, notes, and snippets.

View ChunML's full-sized avatar
✍️
Landed new site at https://trungtran.io

Trung Tran ChunML

✍️
Landed new site at https://trungtran.io
View GitHub Profile
def one_shot_input_fn(filenames, labels):
    """Build a one-shot input pipeline over (filename, label) pairs.

    Each pair is passed through ``_parse_data`` and then grouped into
    batches of size 1. The pipeline is consumed through a one-shot
    iterator.

    Args:
        filenames: Image file paths, sliced element-wise into the dataset.
        labels: Labels sliced in lockstep with ``filenames``.

    Returns:
        A ``(img, label)`` pair of tensors; evaluating them in a session
        yields the next batch from the dataset.
    """
    pairs = tf.data.Dataset.from_tensor_slices((filenames, labels))
    parsed = pairs.map(_parse_data)
    single_batches = parsed.batch(1)
    one_shot = single_batches.make_one_shot_iterator()
    # The iterator yields a two-component element; unpack it so callers
    # receive the image and label tensors separately.
    image_batch, label_batch = one_shot.get_next()
    return image_batch, label_batch
# Per-channel offsets (three values, one per image channel). visualize_dataset
# adds them back onto each image before display.
# NOTE(review): these look like the usual ImageNet/VGG RGB channel means, and
# are presumably subtracted in _parse_data (whose loop body is not visible
# in this view) -- confirm against the full file.
means = [123.68, 116.779, 103.939]
# Load the image stored at `filename` and preprocess it. `label` is accepted
# alongside so this function can be handed to Dataset.map over
# (filename, label) pairs.
# NOTE(review): this definition is cut off in this view -- the body of the
# final `for` loop and any return statement are not visible, and the original
# indentation has been lost.
def _parse_data(filename, label, new_size=224):
# Read the raw file contents.
img_string = tf.read_file(filename)
# Decode the JPEG bytes into an image tensor.
img = tf.image.decode_jpeg(img_string)
# Resize to a new_size x new_size square (aspect ratio is not preserved).
img = tf.image.resize_images(img, (new_size, new_size))
# Pin the static shape to (new_size, new_size, 3). This assumes every input
# decodes to 3 channels -- TODO confirm no grayscale JPEGs are fed in.
img.set_shape([new_size, new_size, 3])
img = tf.to_float(img)
# Split along axis 2 (the channel axis, per the set_shape call above) into
# three separate tensors.
channels = tf.split(axis=2, num_or_size_splits=3, value=img)
# Per-channel loop whose body is missing here. Presumably it subtracts
# means[i] from channels[i] -- TODO confirm against the full file.
for i in range(3):
# Set up a 4x4 grid of subplots and convert each fetched image back into a
# displayable form.
# NOTE(review): this definition is cut off in this view -- `labels`, `axes`
# and the loop index `i` are not used in the visible part, and no drawing
# call is shown. The original indentation has also been lost.
def visualize_dataset(imgs, labels):
# `plt` is presumably matplotlib.pyplot -- confirm against the file's imports.
fig, axes = plt.subplots(ncols=4, nrows=4)
fig.set_size_inches(10, 10)
for i, img in enumerate(imgs):
# Add the module-level per-channel `means` back onto the image. This is an
# in-place add, so if `img` is a NumPy array the caller's array is modified.
img += means
# Clamp values to [0, 255], writing the result back into `img`.
np.clip(img, 0, 255, img)
# Take the first entry along the leading axis (presumably the size-1 batch
# dimension produced by .batch(1) -- confirm) and convert to uint8.
img = img[0].astype(np.uint8)
# `Image` is presumably PIL.Image -- confirm against the file's imports.
img = Image.fromarray(img)
# Build the pipeline once; `img` and `label` are graph tensors, and every
# sess.run on them below advances the one-shot iterator by one batch.
img, label = one_shot_input_fn(filenames, labels)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Collect 16 consecutive batches as evaluated values.
    res_imgs, res_labels = [], []
    for _ in range(16):
        res_img, res_label = sess.run([img, label])
        res_imgs.append(res_img)
        res_labels.append(res_label)
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
# Split (filenames, labels) into a training part and a validation part, build
# one dataset for each, and create a single iterator defined only by the
# training dataset's element types and shapes.
# NOTE(review): this definition is cut off in this view -- nothing visible
# binds the iterator to either dataset and no return statement is shown.
# The original indentation has also been lost.
def reinitializable_input_fn(filenames, labels, train_val_ratio=0.8):
num_files = len(filenames)
# Number of leading entries that go to the training split (rounded down).
num_train_files = int(num_files * train_val_ratio)
# The split is purely positional: the first num_train_files entries are
# training data, the remainder is validation data. Nothing here shuffles
# before splitting.
train_filenames = filenames[:num_train_files]
train_labels = labels[:num_train_files]
val_filenames = filenames[num_train_files:]
val_labels = labels[num_train_files:]
train_data = tf.data.Dataset.from_tensor_slices(
(train_filenames, train_labels))
# Training pipeline: parse, shuffle with a buffer of 1000 elements, repeat
# indefinitely, then group into batches of 4.
train_data = train_data.map(_parse_data).shuffle(1000).repeat().batch(4)
val_data = tf.data.Dataset.from_tensor_slices(
(val_filenames, val_labels))
# Validation pipeline: parse and batch one element at a time; no shuffle or
# repeat.
val_data = val_data.map(_parse_data).batch(1)
# The iterator is built from the training dataset's structure only.
# NOTE(review): presumably val_data shares that structure so the same
# iterator can be initialized from either dataset -- confirm.
iterator = tf.data.Iterator.from_structure(train_data.output_types,
train_data.output_shapes)
next_element = iterator.get_next()