from collections import deque

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    Activation,
    Reshape,
    Flatten,
    Lambda,
    Dense,
)
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow.keras.backend as K
# Scipy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
# Open CV
import cv2
# Settings
IMG_SHAPE = (128, 128, 3)  # shape handed to the network's Input layer
FILTERS = 16               # filters in each hidden convolution
DEPTH = 0                  # extra encoder convolutions after 'encoder_conv_0'
KERNEL = 8                 # kernel size used by every convolution
BATCH_SIZE = 32            # maximum number of samples per training batch

# Keyword options shared by every convolution in the network.
_CONV_OPTIONS = dict(kernel_size=KERNEL, activation='relu', padding='same')

# Encoder: one convolution over the input image, then DEPTH more on top.
encoder_input = Input(shape=IMG_SHAPE)
encoder = Conv2D(FILTERS, name='encoder_conv_0', **_CONV_OPTIONS)(encoder_input)
for layer_number in range(1, DEPTH + 1):
    encoder = Conv2D(
        FILTERS,
        name=f'encoder_conv_{layer_number}',
        **_CONV_OPTIONS
    )(encoder)

# Attention: a single-channel map, soft-maxed over all of its pixels (so the
# map sums to 1 for each image) and reshaped back to image height x width.
attention_conv = Conv2D(1, name='attention_conv', **_CONV_OPTIONS)(encoder)
attention_flatten = Flatten(name='attention_flatten')(attention_conv)
attention_softmax = Activation(
    'softmax',
    name='attention_softmax',
)(attention_flatten)
attention_reshape = Reshape(
    IMG_SHAPE[:2] + (1,),
    name='attention_reshape',
)(attention_softmax)

# The input image weighted pixel-wise by the attention map (the single
# attention channel broadcasts across the image channels).
attention_output = Lambda(
    lambda tensors: tensors[0] * tensors[1],
    name='attention_output',
)([encoder_input, attention_reshape])

# Classifier heads:
#   classifier1 - the largest value found in the attention map
#   classifier2 - a sigmoid unit over the attention-weighted input
classifier1_flatten = Flatten(name='classifier1_flatten')(attention_reshape)
classifier2_flatten = Flatten(name='classifier2_flatten')(attention_output)
classifier1 = Lambda(
    lambda values: K.max(values, axis=-1),
    name='classifier1',
)(classifier1_flatten)
classifier2 = Dense(
    1,
    activation='sigmoid',
    name='classifier2',
)(classifier2_flatten)

# Decoder: two convolutions that turn the attention map into an image with
# IMG_SHAPE[2] channels.
decoder = Conv2D(FILTERS, name='decoder', **_CONV_OPTIONS)(attention_reshape)
decoder = Conv2D(
    IMG_SHAPE[2],
    name='decoder_output',
    **_CONV_OPTIONS
)(decoder)

model = Model(
    inputs=encoder_input,
    outputs=[classifier1, classifier2, decoder],
)
model.summary()

# One loss per output, in output order. Each classification loss is weighted
# 1000x relative to the decoder's mean-squared error.
model.compile(
    optimizer='adam',
    loss=['binary_crossentropy', 'binary_crossentropy', 'mse'],
    loss_weights=[1000, 1000, 1],
    metrics=['accuracy'],
)
# Training
def _next_motion_frame(capture, history):
    """Read one frame from *capture* and return its motion image.

    The frame is resized to the network input size, appended to *history*
    (a bounded deque of recent frames that drops its oldest entry when full)
    and turned into the absolute difference between the frame and the mean
    of *history*. The current frame is already part of the history when the
    mean is taken.

    Returns None when *capture* delivers no frame.
    """
    found, frame = capture.read()
    if not found:
        return None
    # cv2.resize takes dsize as (width, height), while IMG_SHAPE is indexed
    # (rows, columns, channels) like the frame arrays themselves.
    frame = cv2.resize(frame, (IMG_SHAPE[1], IMG_SHAPE[0]))
    history.append(frame)
    return np.abs(frame - np.array(history).mean(axis=0))


def load_samples(positive_path='1.mp4', negative_path='0.mp4', history_size=10):
    """Yield training batches of motion images from two videos, forever.

    On every iteration one frame is read from each video (while it still has
    frames), converted to a motion image and added to a growing sample pool:
    frames of the positive video are labelled 1, frames of the negative video
    are labelled 0. The pool is then shuffled and its first BATCH_SIZE
    entries are yielded, so early batches are smaller than BATCH_SIZE and,
    once both videos are exhausted, batches keep being drawn from the pool.

    NOTE: the pool keeps every frame read so far, so memory use grows with
    the length of the videos.

    Args:
        positive_path: video whose frames are labelled 1.
        negative_path: video whose frames are labelled 0.
        history_size: maximum number of recent frames averaged to form the
            background that is subtracted from each frame.

    Yields:
        A pair ``(inputs, [labels, labels, inputs])``: the batch of motion
        images, then one target per model output - the class label for both
        classifier heads and the motion images themselves for the decoder.

    Raises:
        ValueError: if *history_size* is smaller than 1.
        RuntimeError: if no frame can be read from either video.
    """
    if history_size < 1:
        raise ValueError('history_size must be at least 1')

    positive = cv2.VideoCapture(positive_path)
    negative = cv2.VideoCapture(negative_path)
    try:
        # Each history starts with one all-zero frame, which is pushed out
        # once history_size real frames have been seen.
        blank = np.zeros((IMG_SHAPE[0], IMG_SHAPE[1], 3))
        sources = (
            (positive, 1, deque([blank], maxlen=history_size)),
            (negative, 0, deque([blank], maxlen=history_size)),
        )
        samples, labels = [], []
        while True:
            for capture, label, history in sources:
                sample = _next_motion_frame(capture, history)
                if sample is not None:
                    samples.append(sample)
                    labels.append(label)
            if not samples:
                raise RuntimeError(
                    f'no frames could be read from {positive_path!r} '
                    f'or {negative_path!r}'
                )
            # Shuffling the whole pool and slicing its head draws a random
            # batch; both lists are permuted consistently.
            samples, labels = shuffle(samples, labels)
            batch = np.array(samples[:BATCH_SIZE])
            batch_labels = np.array(labels[:BATCH_SIZE])
            # The decoder target is the input itself, so the same array
            # serves as both input and third target.
            yield batch, [batch_labels, batch_labels, batch]
    finally:
        # Runs when the generator is closed or garbage-collected.
        positive.release()
        negative.release()
if __name__ == '__main__':
    # Resume from an earlier run when possible. Any failure to load the
    # checkpoint is only printed and training starts from the current weights.
    try:
        model.load_weights('checkpoint.h5')
    except Exception as load_error:
        print(load_error)

    # ModelCheckpoint writes 'checkpoint.h5' during training (with its
    # default settings: after every epoch).
    checkpoint_callback = ModelCheckpoint('checkpoint.h5')

    # NOTE(review): fit_generator is deprecated in TensorFlow 2.1+, where
    # Model.fit accepts generators directly - consider migrating.
    model.fit_generator(
        generator=load_samples(),
        steps_per_epoch=100,
        epochs=1000,
        callbacks=[checkpoint_callback],
    )

    # Final weights-only save once all epochs have completed.
    model.save_weights('checkpoint.h5')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment