Last active
February 22, 2017 19:27
-
-
Save ffmpbgrnn/cfbf15472077a3d9ae25f61907db5eb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.python.ops import parsing_ops | |
from tensorflow.contrib.slim.python.slim.data import parallel_reader | |
import numpy as np | |
def main():
    """Read frame-level records from a TFRecord file and print a summary.

    Each record of "train-0.tfrecord" is parsed as a SequenceExample with
    context features 'video_id' and 'labels' and a sequence feature 'inc3'
    holding one byte string per frame. For every record the video id, the
    labels and the shape of the stacked per-frame uint8 matrix are printed.
    The loop ends when the single epoch is exhausted.
    """
    reader = tf.TFRecordReader
    data_sources = ["train-0.tfrecord"]
    _, data = parallel_reader.parallel_read(
        data_sources,
        reader_class=reader,
        num_epochs=1,
        num_readers=1,
        shuffle=False,
        capacity=256,
        min_after_dequeue=1)
    context_features, sequence_features = (
        parsing_ops.parse_single_sequence_example(
            data,
            context_features={
                'video_id': tf.VarLenFeature(tf.string),
                'labels': tf.VarLenFeature(tf.int64),
            },
            sequence_features={
                'inc3': tf.FixedLenSequenceFeature(1, tf.string)
            },
            example_name=""))

    with tf.Session() as sess:
        # num_epochs is tracked in a local variable, so locals must be
        # initialized before the queue runners start.
        sess.run(tf.initialize_local_variables())
        sess.run(tf.initialize_all_variables())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                # Fetch context and sequence features in ONE run call: every
                # sess.run dequeues a fresh record, so fetching them
                # separately would pair the id/labels of one video with the
                # frames of the next one.
                meta, sequences = sess.run(
                    [context_features, sequence_features])
                vid = meta['video_id'].values[0]
                labels = meta['labels'].values
                # Each row holds a single byte string; decode it to a uint8
                # row vector and stack the frames into a 2-D matrix.
                frame_feas = np.vstack([
                    np.frombuffer(row[0], dtype=np.uint8)[None, :]
                    for row in sequences['inc3']
                ])
                print(vid, labels)
                print(frame_feas.shape)
                # Do something here
        except tf.errors.OutOfRangeError:
            print('Finished extracting.')
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
python yt8m_parse.py video "/PATH/TO/YOUR/VIDEO/LEVEL/DIR/*.tfrecord"
python yt8m_parse.py frame "/PATH/TO/YOUR/FRAME/LEVEL/DIR/*.tfrecord"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Modified from https://groups.google.com/d/msg/youtube8m-users/yEDzH7EqUf8/EfW0WO3jAgAJ | |
# | |
import sys | |
import tensorflow as tf | |
from tensorflow.python.platform import gfile | |
def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
    """Convert byte-quantized feature values back to floats.

    The value is scaled by (max - min) / 255 and shifted by
    min + (max - min) / 512.

    Args:
      feat_vector: the quantized input (1-d vector).
      max_quantized_value: upper end of the quantization range.
      min_quantized_value: lower end of the quantization range.

    Returns:
      A float vector with the same shape as feat_vector.
    """
    assert max_quantized_value > min_quantized_value
    span = max_quantized_value - min_quantized_value
    step = span / 255.0
    offset = min_quantized_value + span / 512.0
    return feat_vector * step + offset
class YouTube8MFrameFeatureReader(object):
    """Builds the TensorFlow ops that read and decode YouTube-8M records.

    With sequence_data=True the named feature is read as a sequence of
    byte strings (one per frame) and dequantized to floats; otherwise it is
    read from the record context as a fixed-length float vector.
    """

    def __init__(self,
                 num_classes=4800,
                 feature_size=1024,
                 feature_name="inc3",
                 max_frames=300,
                 sequence_data=True):
        """Store the reader configuration.

        Args:
          num_classes: length of the dense label vector that is produced.
          feature_size: number of values per frame (or per video).
          feature_name: key of the feature inside the record.
          max_frames: upper bound applied to the reported frame count.
          sequence_data: True for frame-level records, False for
            video-level records.
        """
        self.num_classes = num_classes
        self.feature_size = feature_size
        self.feature_name = feature_name
        self.max_frames = max_frames
        self.sequence_data = sequence_data

    def prepare_reader(self,
                       filename_queue,
                       max_quantized_value=2,
                       min_quantized_value=-2):
        """Create the ops that read one record from filename_queue.

        Args:
          filename_queue: queue of TFRecord file names to read from.
          max_quantized_value: upper bound passed to Dequantize.
          min_quantized_value: lower bound passed to Dequantize.

        Returns:
          A tuple (video_id, video_matrix, labels, num_frames):
          video_id is the record's 'video_id' string, video_matrix the
          decoded features, labels a boolean vector of length num_classes,
          and num_frames the frame count capped at max_frames (the constant
          -1 when sequence_data is False).
        """
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        context_features = {
            "video_id": tf.FixedLenFeature([], tf.string),
            "labels": tf.VarLenFeature(tf.int64),
        }
        sequence_features = None
        if self.sequence_data:
            sequence_features = {
                self.feature_name: tf.FixedLenSequenceFeature(
                    [], dtype=tf.string),
            }
        else:
            context_features[self.feature_name] = tf.FixedLenFeature(
                self.feature_size, tf.float32)

        contexts, features = tf.parse_single_sequence_example(
            serialized_example,
            context_features=context_features,
            sequence_features=sequence_features)

        # Label ids are not guaranteed to be sorted or unique inside a
        # record; with index validation on, sparse_to_dense would reject
        # such records, so validation is disabled.
        labels = tf.cast(
            tf.sparse_to_dense(
                contexts["labels"].values, (self.num_classes,), 1,
                validate_indices=False),
            tf.bool)

        if self.sequence_data:
            # Each frame is a raw byte string: decode to uint8, convert to
            # float and lay the values out as [frames, feature_size].
            decoded_features = tf.reshape(
                tf.cast(
                    tf.decode_raw(features[self.feature_name], tf.uint8),
                    tf.float32),
                [-1, self.feature_size])
            num_frames = tf.minimum(
                tf.shape(decoded_features)[0], self.max_frames)
            video_matrix = Dequantize(decoded_features, max_quantized_value,
                                      min_quantized_value)
        else:
            video_matrix = contexts[self.feature_name]
            num_frames = tf.constant(-1)

        # NOTE: video_matrix is neither padded nor truncated to max_frames,
        # so it can hold more rows than num_frames reports.
        return contexts["video_id"], video_matrix, labels, num_frames
def main(level, files_pattern):
    """Print id, features and labels of every record matching files_pattern.

    Args:
      level: 'frame' for frame-level records (feature "inc3") or 'video'
        for video-level records (feature "mean_inc3").
      files_pattern: a glob pattern such as "train*.tfrecord", a directory
        (all "*.tfrecord" files inside it are read), or a list of such
        entries.

    Raises:
      ValueError: if level is not 'frame' or 'video', or if no file matches.
    """
    import os  # only needed to expand directory arguments

    # Reject unknown levels explicitly; otherwise `reader` would be unbound
    # and the failure would surface later as an UnboundLocalError.
    if level == 'frame':
        reader = YouTube8MFrameFeatureReader(feature_name="inc3")
    elif level == 'video':
        reader = YouTube8MFrameFeatureReader(
            feature_name="mean_inc3", sequence_data=False)
    else:
        raise ValueError(
            "level must be 'frame' or 'video', got %r" % (level,))

    # Accept a single pattern as well as a list of patterns; the latter is
    # what arrives when the shell has already expanded an unquoted glob.
    # type(u'') covers unicode strings on Python 2.
    if isinstance(files_pattern, (str, type(u''))):
        patterns = [files_pattern]
    else:
        patterns = list(files_pattern)

    data_files = []
    for pattern in patterns:
        if gfile.IsDirectory(pattern):
            pattern = os.path.join(pattern, "*.tfrecord")
        data_files.extend(gfile.Glob(pattern))
    if not data_files:
        raise ValueError("no files match %r" % (patterns,))

    filename_queue = tf.train.string_input_producer(
        data_files, num_epochs=1, shuffle=False)
    vals = reader.prepare_reader(filename_queue)

    with tf.Session() as sess:
        # num_epochs is tracked in a local variable, so locals must be
        # initialized before the queue runners start.
        sess.run(tf.initialize_local_variables())
        sess.run(tf.initialize_all_variables())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                vid, features, labels, _ = sess.run(vals)
                print(vid, features, labels)
        except tf.errors.OutOfRangeError:
            print('Finished extracting.')
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # Arguments:
    #   level: 'frame' or 'video'
    #   files_pattern: e.g. "train*.tfrecord"; several patterns or files
    #     may be given (an unquoted glob is expanded by the shell).
    if len(sys.argv) < 3:
        sys.exit("usage: %s {frame|video} FILES_PATTERN [FILES_PATTERN ...]"
                 % sys.argv[0])
    main(sys.argv[1], sys.argv[2:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I have downloaded a subset of the data inside the tfrecords dir, running:
python yt8m_parse.py video ./tfrecords
creates an error. Any help would be welcome :-P
Output: