Skip to content

Instantly share code, notes, and snippets.

@geohot
Created September 28, 2017 00:16
Show Gist options
  • Save geohot/3232a35f439ec17d389e87346fe30abe to your computer and use it in GitHub Desktop.
Save geohot/3232a35f439ec17d389e87346fe30abe to your computer and use it in GitHub Desktop.
from collections import deque

import cv2
import numpy as np
from tqdm import tqdm
def test_frame_gen(limit=None, path="data/test.mp4"):
    """Yield sliding windows of three consecutive preprocessed video frames.

    Each frame read from ``path`` is preprocessed in four steps:
    - reduced to one channel by averaging its colour channels;
    - divided by 256.0;
    - cropped by 196 rows at top and bottom and 232 columns at left and right;
    - resized with ``cv2.resize(..., (320, 160))``.

    Once three frames are buffered, each new frame produces one window
    holding the three most recent frames, oldest first, stacked along the
    last axis.

    Args:
        limit: Maximum number of windows to yield. ``None`` reads the whole
            video; ``0`` or any negative value yields nothing.
        path: Video file to read. Defaults to the previously hard-coded
            ``"data/test.mp4"``.

    Yields:
        numpy.ndarray: the three most recent preprocessed frames stacked on
        the last axis (shape ``(160, 320, 3)``, since OpenCV's ``dsize`` is
        ``(width, height)``).
    """
    if limit is not None and limit <= 0:
        return
    # deque(maxlen=3) drops the oldest frame automatically instead of
    # re-slicing a list on every iteration.
    window = deque(maxlen=3)
    cap = cv2.VideoCapture(path)
    try:
        yielded = 0
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            frm = frame.mean(axis=2) / 256.0
            frm = frm[196:-196, 232:-232]
            frm = cv2.resize(frm, (320, 160))
            window.append(frm)
            if len(window) == 3:
                # Equivalent to np.array(window).swapaxes(0, 2).swapaxes(0, 1).
                yield np.stack(window, axis=-1)
                yielded += 1
                if limit is not None and yielded >= limit:
                    return
    finally:
        # Runs on normal exhaustion, early return, and generator close().
        cap.release()
if __name__ == "__main__":
    # Internal only: load the speednet model trained on chffr data.
    from tools.reporter.reporter import load_model_from_server

    model = load_model_from_server("55f76936-98f1-4f19-bda4-54752f67719a", 160)
    windows = list(tqdm(test_frame_gen()))
    if not windows:
        # Fewer than 3 frames yields no windows; fail clearly instead of
        # handing an empty array to the model.
        raise SystemExit("no 3-frame windows produced from the input video; nothing to predict")
    # Each window spans three frames, so N frames give N-2 windows. Repeating
    # the first and last windows restores one entry per input frame.
    windows = windows[:1] + windows + windows[-1:]
    preds = model.predict(np.array(windows), batch_size=256, verbose=1)
    # NOTE(review): assumes predict returns one row per window; only the first
    # column of each row is written, one value per line.
    with open("test_pred.txt", "w") as f:
        f.write("\n".join(str(p[0]) for p in preds))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment