Created
May 30, 2018 11:05
-
-
Save okwrtdsh/c4e21ee1be866827f1556748ef9688d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
inspired by | |
https://github.com/harvitronix/five-video-classification-methods/blob/master/data.py | |
""" | |
import functools
import operator
import os
import random
import threading
from glob import glob

import numpy as np
from keras.preprocessing.image import img_to_array, load_img
from keras.utils import to_categorical
# Seed numpy's global RNG and Python's `random` so that sampling and
# augmentation done through them are repeatable between runs.
np.random.seed(0)
random.seed(0)

# Frames are resized to ORIGIN_SHAPE when loaded (see process_image) and then
# cropped to TARGET_SHAPE (see UCF101DataSet.build_image_sequence).
ORIGIN_SHAPE = (128, 171, 3)
TARGET_SHAPE = (112, 112, 3)
def process_image(image, crop=(8, 120, 29, 141), flip=False):
    """Load a frame, scale it to float32 in [0, 1], optionally mirror it, crop it.

    Args:
        image: path of the frame file. It is resized on load to the first
            two entries of ORIGIN_SHAPE (passed to ``load_img`` as
            ``target_size``).
        crop: ``(row_start, row_stop, col_start, col_stop)`` bounds applied
            as ``array[row_start:row_stop, col_start:col_stop, :]``. The
            default keeps rows 8-119 and columns 29-140.
        flip: when True, the whole frame is mirrored left-right
            (``np.fliplr``) before the crop is taken.

    Returns:
        The cropped float32 array.
    """
    row_start, row_stop, col_start, col_stop = crop
    height, width, _channels = ORIGIN_SHAPE
    pixels = img_to_array(load_img(image, target_size=(height, width)))
    frame = (pixels / 255.).astype(np.float32)
    frame = np.fliplr(frame) if flip else frame
    return frame[row_start:row_stop, col_start:col_stop, :]
def chunks(l, n=16, w=0):
    """Yield consecutive windows of exactly ``n`` items from ``l``.

    Each window starts ``n - w`` items after the previous one, so
    neighbouring windows share ``w`` items. Trailing items that cannot fill a
    complete window are dropped, so every yielded window has length ``n``.

    Args:
        l (Sequence): sliceable, sized sequence (e.g. a list of frame paths).
        n (int): window length; must be positive.
        w (int): overlap between consecutive windows; must be smaller than
            ``n``.

    Yields:
        Slices ``l[i:i + n]``.

    Raises:
        ValueError: if ``n <= 0`` or ``w >= n`` (raised on first iteration,
            because this is a generator function).
    """
    step = n - w
    if n <= 0 or step <= 0:
        raise ValueError('chunks requires n > 0 and w < n (got n=%r, w=%r)'
                         % (n, w))
    # The last valid start is len(l) - n, so every slice is a full window.
    # If len(l) < n the range is empty.
    for start in range(0, len(l) - n + 1, step):
        yield l[start:start + n]
class ThreadsafeIterator:
    """Iterator wrapper whose ``next()`` calls are serialized by a lock.

    Only advancing the wrapped iterator is guarded; the values it returns
    are handed back unchanged.
    """

    def __init__(self, iterator):
        self.lock = threading.Lock()
        self.iterator = iterator

    def __iter__(self):
        return self

    def __next__(self):
        self.lock.acquire()
        try:
            return next(self.iterator)
        finally:
            self.lock.release()
def threadsafe_generator(func):
    """Decorator: wrap what ``func`` returns in a ThreadsafeIterator.

    Intended for generator functions, so that the resulting iterator can be
    advanced from several threads. ``functools.wraps`` preserves the
    decorated function's name, docstring and ``__wrapped__`` reference.
    """
    @functools.wraps(func)
    def gen(*a, **kw):
        return ThreadsafeIterator(func(*a, **kw))
    return gen
class UCF101DataSet(object):
    """Index of UCF-101 frame clips plus helpers that batch them for Keras.

    Files are read relative to the current working directory:
      * ``./ucfTrainTestlist/classInd.txt`` and the split lists
        ``trainlistNN.txt`` / ``testlistNN.txt`` (NN = ``num_split``);
      * ``./UCF-101/<ClassName>/<video name without extension>/*.png``
        extracted frames;
      * optionally ``data/sequences/<filename>-<seq_length>-<data_type>.npy``
        pre-extracted sequences, used for any ``data_type`` other than
        ``'images'``.
    """

    def __init__(
            self,
            seq_length=16,
            seq_overwrap=0,
            num_split=1,
            class_limit=None):
        """
        Args:
            seq_length (int): number of frames per clip.
            seq_overwrap (int): number of frames shared by consecutive clips
                of the same video (should be smaller than ``seq_length``).
            num_split (int): which train/test split list to load (1-3).
            class_limit (int): keep only classes whose 0-based id is below
                this value; None = no limit.

        Raises:
            ValueError: if ``num_split`` is not in 1..3.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if not 1 <= num_split <= 3:
            raise ValueError('must be 1 <= num_split <= 3')
        self.seq_length = seq_length
        self.seq_overwrap = seq_overwrap
        self.class_limit = class_limit
        self.num_split = num_split
        self.sequence_path = os.path.join('data', 'sequences')
        # max number of frames a video can have for us to use it
        self.max_frame_limit = 300
        self.load_classes()
        self.load_train_data()
        self.load_test_data()

    def get_data(self, file_path):
        """Parse a split list into per-video clip lists.

        Each non-blank line starts with a relative video path of the form
        ``<ClassName>/<video file>``; anything after the first whitespace is
        ignored. A video is skipped if its class is not in ``self.r_actions``
        (e.g. removed by ``class_limit``). It is also skipped if it has more
        than ``self.max_frame_limit`` frames.

        Returns:
            [(
                filename,
                cls_id,
                [['path/to/frame_xxx.png',],]
            ),]
        """
        datalist = []
        with open(file_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                video_path = fields[0]
                cls_id = self.r_actions.get(video_path.split('/')[0])
                if cls_id is None:
                    continue
                img_path, _ = os.path.splitext(
                    os.path.join('./UCF-101/', video_path))
                # glob() gives no ordering guarantee; sort so that each clip
                # holds consecutive frames. NOTE(review): lexical order equals
                # temporal order only if frame numbers are zero-padded.
                imgs = sorted(glob(os.path.join(img_path, '*.png')))
                # A video with exactly max_frame_limit frames is still usable.
                if len(imgs) > self.max_frame_limit:
                    continue
                datalist.append((
                    video_path.split('/')[-1],
                    cls_id,
                    list(chunks(imgs, self.seq_length, self.seq_overwrap))))
        return datalist

    def load_train_data(self):
        """Load the training list of split ``self.num_split``."""
        self.train_data = self.get_data(
            './ucfTrainTestlist/trainlist%02d.txt' % self.num_split)

    def load_test_data(self):
        """Load the test list of split ``self.num_split``."""
        self.test_data = self.get_data(
            './ucfTrainTestlist/testlist%02d.txt' % self.num_split)

    def load_classes(self):
        """Read classInd.txt and build the class id <-> name mappings.

        Each non-blank line is ``<1-based id> <class name>``; ids are stored
        0-based. Sets ``self.actions`` (id -> name), ``self.r_actions``
        (name -> id) and ``self.num_classes``.
        """
        actions = {}
        r_actions = {}
        with open('./ucfTrainTestlist/classInd.txt', 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                cls_id, cls_name = fields
                cls_id = int(cls_id) - 1
                if self.class_limit is not None and cls_id >= self.class_limit:
                    continue
                actions[cls_id] = cls_name
                r_actions[cls_name] = cls_id
        self.actions = actions
        self.r_actions = r_actions
        self.num_classes = len(actions)

    def _load_sequence_or_raise(self, data_type, filename):
        """Return the pre-extracted sequence, raising if its file is missing."""
        sequence = self.get_extracted_sequence(data_type, filename)
        if sequence is None:
            raise RuntimeError(
                "Can't find sequence. Did you generate them? (%s, %s)"
                % (filename, data_type))
        return sequence

    def get_all_sequences_in_memory(self, train_test, data_type):
        """Load a whole split into memory (in-memory mirror of frame_generator).

        Args:
            train_test (str): 'train' selects the training data; any other
                value selects the test data.
            data_type (str): 'images' builds every clip of every video from
                its frames (centered crop, no flip); any other value loads
                each video's pre-extracted sequence from disk.

        Returns:
            (X, y): ``np.array`` of samples and ``to_categorical`` labels
            with ``self.num_classes`` classes.

        Raises:
            RuntimeError: if a pre-extracted sequence file is missing.
        """
        data = self.train_data if train_test == 'train' else self.test_data
        print("Loading %d samples into memory for %sing." % (
            len(data), train_test))
        X, y = [], []
        for filename, cls_id, frames_list in data:
            if data_type == 'images':
                for frames in frames_list:
                    X.append(self.build_image_sequence(
                        frames, crop_type='fixed'))
                    y.append(cls_id)
            else:
                X.append(self._load_sequence_or_raise(data_type, filename))
                y.append(cls_id)
        return np.array(X), to_categorical(np.array(y), self.num_classes)

    # NOTE: decorating with @threadsafe_generator would let several threads
    # share the returned generator.
    # @threadsafe_generator
    def frame_generator(self, batch_size, train_test, data_type):
        """Yield ``(X, y)`` batches forever, sampling videos at random.

        Args:
            batch_size (int): samples per batch.
            train_test (str): 'train' selects the training data; any other
                value selects the test data.
            data_type (str): 'images' picks one random clip of the sampled
                video and applies a random crop / flip; any other value
                (e.g. 'features') loads the video's pre-extracted sequence.

        Yields:
            (X, y): ``np.array`` of samples and ``to_categorical`` labels.

        Raises:
            RuntimeError: if a pre-extracted sequence file is missing.
            IndexError: if a sampled video has no complete clip
                (``random.choice`` on an empty list).
        """
        data = self.train_data if train_test == 'train' else self.test_data
        print("Creating %s generator with %d samples." % (
            train_test, len(data)))
        while True:
            X, y = [], []
            for _ in range(batch_size):
                filename, cls_id, frames_list = random.choice(data)
                # Compare with ==, not `is`: string identity is not guaranteed.
                if data_type == "images":
                    frames = random.choice(frames_list)
                    X.append(self.build_image_sequence(
                        frames, crop_type='random'))
                else:
                    X.append(self._load_sequence_or_raise(data_type, filename))
                y.append(cls_id)
            yield np.array(X), to_categorical(np.array(y), self.num_classes)

    def build_image_sequence(self, frames, crop_type='random'):
        """Load and crop a clip of frames via process_image().

        Args:
            frames (list): frame file paths.
            crop_type (str): 'fixed' uses a centered crop without flip.
                'random' uses a uniformly random crop offset and a 50% chance
                of a horizontal flip, shared by all frames of the clip.

        Returns:
            list of per-frame arrays cropped to TARGET_SHAPE's first two
            dimensions.

        Raises:
            ValueError: for any other ``crop_type``.
        """
        if crop_type not in ('fixed', 'random'):
            raise ValueError("crop_type must be 'fixed' or 'random', got %r"
                             % (crop_type,))
        # process_image resizes to (ORIGIN_SHAPE[0], ORIGIN_SHAPE[1]) and
        # slices rows first, so index 0 is height and index 1 is width.
        origin_h, origin_w, _ = ORIGIN_SHAPE
        out_h, out_w, _ = TARGET_SHAPE
        if crop_type == 'fixed':
            top = (origin_h - out_h) // 2
            left = (origin_w - out_w) // 2
            flip = False
        else:
            # randint is inclusive, so every valid offset is equally likely.
            top = random.randint(0, origin_h - out_h)
            left = random.randint(0, origin_w - out_w)
            flip = random.random() < 0.5
        crop = (top, top + out_h, left, left + out_w)
        return [process_image(frame, crop=crop, flip=flip) for frame in frames]

    def get_extracted_sequence(self, data_type, filename):
        """Load ``<sequence_path>/<filename>-<seq_length>-<data_type>.npy``.

        Returns:
            The loaded array, or None if the file does not exist.
        """
        path = os.path.join(
            self.sequence_path,
            filename + '-' + str(self.seq_length) + '-' + data_type + '.npy')
        if os.path.isfile(path):
            return np.load(path)
        return None

    def get_frames_by_filename(self, filename, data_type):
        """Build a one-sample batch for the video named ``filename``.

        The training data is searched first, then the test data. For
        'images', one random clip is taken with a random crop / flip; for
        any other ``data_type`` the pre-extracted sequence is loaded.

        Returns:
            (X, y) with a leading batch dimension of 1.

        Raises:
            ValueError: if no sample has that filename.
            RuntimeError: if a pre-extracted sequence file is missing.
        """
        sample = next(
            (row for row in self.train_data if row[0] == filename), None)
        if sample is None:
            sample = next(
                (row for row in self.test_data if row[0] == filename), None)
        if sample is None:
            raise ValueError("Couldn't find sample: %s" % filename)
        filename, cls_id, frames_list = sample
        if data_type == "images":
            frames = random.choice(frames_list)
            X = [self.build_image_sequence(frames, crop_type='random')]
        else:
            X = [self._load_sequence_or_raise(data_type, filename)]
        return np.array(X), to_categorical(np.array([cls_id]), self.num_classes)

    def print_class_from_prediction(self, predictions, nb_to_return=5):
        """Print up to ``nb_to_return`` classes with the highest scores.

        Args:
            predictions: per-class scores indexable by class id.
            nb_to_return (int): maximum number of lines to print.

        Printing stops early at the first class whose score is exactly 0.0.
        """
        label_predictions = {
            cls_name: predictions[cls_id]
            for cls_id, cls_name in self.actions.items()}
        sorted_lps = sorted(
            label_predictions.items(),
            key=operator.itemgetter(1),
            reverse=True)
        for cls_name, score in sorted_lps[:max(nb_to_return, 0)]:
            if score == 0.0:
                break
            print("%s: %.2f" % (cls_name, score))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment