Created
September 28, 2017 00:16
-
-
Save geohot/3232a35f439ec17d389e87346fe30abe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque

import cv2
import numpy as np
from tqdm import tqdm
def test_frame_gen(limit=None, path="data/test.mp4"):
    """Yield sliding windows of three consecutive preprocessed video frames.

    Each decoded frame is preprocessed as follows:
      1. Averaged over its last (channel) axis and divided by 256.0.
      2. Cropped by 196 rows top and bottom and 232 columns left and right.
      3. Resized to 320x160 (OpenCV ``dsize`` is ``(width, height)``).

    Once three processed frames are available, every further frame yields one
    window. The window is an array of shape (160, 320, 3) whose last axis holds
    the three most recent frames, oldest first.

    Args:
        limit: Maximum number of windows to yield. ``None`` (the default)
            means read the whole video. ``0`` or a negative value yields
            nothing.
        path: Video file to read; defaults to ``"data/test.mp4"``.

    Yields:
        numpy.ndarray of shape (160, 320, 3).

    Raises:
        OSError: if OpenCV cannot open ``path``.

    NOTE(review): the fixed crop assumes frames are taller than 392 rows and
    wider than 464 columns; smaller frames give an empty crop, which
    ``cv2.resize`` rejects.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"could not open video: {path}")

    # Fixed-length window: appending a 4th frame silently drops the oldest.
    window = deque(maxlen=3)
    emitted = 0
    try:
        while limit is None or emitted < limit:
            ok, frame = cap.read()
            # read() reports end-of-stream / decode failure via ok=False.
            if not ok or frame is None:
                break
            gray = frame.mean(axis=2) / 256.0
            gray = gray[196:-196, 232:-232]
            window.append(cv2.resize(gray, (320, 160)))
            if len(window) == 3:
                # (3, H, W) frames -> (H, W, 3), same layout the original
                # produced with swapaxes(0, 2).swapaxes(0, 1).
                yield np.stack(window, axis=-1)
                emitted += 1
    finally:
        # Runs on exhaustion, on early return via `limit`, and when the
        # consumer closes or abandons the generator.
        cap.release()
if __name__ == "__main__":
    # Internal only: load the speednet model trained on chffr data. The
    # `tools.reporter` package is not included alongside this script.
    from tools.reporter.reporter import load_model_from_server

    # NOTE(review): the meaning of the second argument (160) is not visible
    # here; presumably it relates to the 160-pixel window height — confirm.
    model = load_model_from_server("55f76936-98f1-4f19-bda4-54752f67719a", 160)

    windows = list(tqdm(test_frame_gen()))
    if not windows:
        # Without this guard, np.array([]) reaches predict() and fails there
        # with an unrelated shape error.
        raise SystemExit("no 3-frame windows could be read from the test video")

    # Each window spans 3 frames, so N frames give N-2 windows. Repeating the
    # first and last window makes the prediction count equal the frame count.
    windows = windows[:1] + windows + windows[-1:]

    batch = np.array(windows)
    del windows  # drop the per-window list before inference to lower peak memory

    preds = model.predict(batch, batch_size=256, verbose=1)
    # Assumes preds is 2-D (one row per window); only column 0 is written,
    # one value per line.
    with open("test_pred.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(str(p[0]) for p in preds))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment