Skip to content

Instantly share code, notes, and snippets.

@RomanSteinberg
Last active September 6, 2019 12:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RomanSteinberg/f9405fa02f3fdf98d3c7e24550c5336b to your computer and use it in GitHub Desktop.
Save RomanSteinberg/f9405fa02f3fdf98d3c7e24550c5336b to your computer and use it in GitHub Desktop.
test_qt_vs_mx
import cv2
import sys
import numpy as np
import mxnet as mx
from PyQt5.QtWidgets import QApplication
# Create a Qt application object before any MXNet work is done. The instance
# is deliberately not bound to a name; only its construction matters here.
# NOTE(review): the gist is titled "test_qt_vs_mx", so this presumably
# reproduces a difference in MXNet output when Qt has been initialised
# (Qt's switch to the system locale may be relevant) -- confirm.
QApplication(sys.argv)  # comment this line out to run the same pipeline without Qt
def make_input(path='face.jpg'):
    """Load an image from disk and wrap it in an MXNet ``DataBatch``.

    The image is read with OpenCV, converted from BGR to RGB channel order,
    transposed so the channel axis comes first, and given a leading batch
    axis of size 1 before being converted to an MXNet NDArray.

    Args:
        path: Image file to read. Defaults to ``'face.jpg'``; the original
            author's note says an aligned image is expected.
            NOTE(review): presumably a 112x112 face crop, matching the
            (1, 3, 112, 112) shape the model is bound with -- confirm.

    Returns:
        An ``mx.io.DataBatch`` whose ``data`` holds the single NDArray.

    Raises:
        OSError: If OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the problem only shows up as an opaque
        # assertion failure inside cv2.cvtColor.
        raise OSError(
            'cv2.imread could not read {!r} '
            '(missing file, bad permissions or unsupported format)'.format(path)
        )

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Move the channel axis to the front, then prepend the batch axis.
    img = np.transpose(img, (2, 0, 1))
    input_blob = np.expand_dims(img, axis=0)

    data = mx.nd.array(input_blob)
    return mx.io.DataBatch(data=(data,))
def fit(model, data):
    """Run one inference-mode forward pass and print input/output previews.

    Despite the name, nothing is trained: ``model.forward`` is called with
    ``is_train=False``. The first ten values along the last axis at index
    ``[0, 0, 0]`` are printed for the first input array and for the first
    model output.

    Args:
        model: Object exposing ``forward(batch, is_train=...)`` and
            ``get_outputs()``; each output must provide ``asnumpy()``.
        data: Batch object whose ``data[0]`` is the input array.

    Returns:
        None. The previews are written to stdout.
    """
    # Same index for both previews: element [0, 0, 0, :10].
    preview = (0, 0, 0, slice(None, 10))

    print('------- Input (first 10 numbers) -------')
    first_input = data.data[0]
    print(first_input[preview])

    model.forward(data, is_train=False)
    first_output = model.get_outputs()[0]
    output_array = first_output.asnumpy()

    print('------- Result (first 10 numbers) -------')
    print(output_array[preview])
def get_model():
    """Build a CPU module whose output is an internal node of the checkpoint.

    Loads epoch 0 of the checkpoint saved under the ``model/model`` prefix,
    takes the internal output named ``_minusscalar0_output`` as the head of
    the graph, prints diagnostic information about that symbol, and returns
    an ``mx.mod.Module`` bound to one ``(1, 3, 112, 112)`` input called
    ``data`` with the checkpoint parameters applied.

    NOTE(review): the node name suggests a scalar-subtraction op near the
    start of the network -- confirm against the checkpoint's symbol file.

    Returns:
        The bound ``mx.mod.Module`` with parameters set.
    """
    checkpoint = mx.model.load_checkpoint('model/model', 0)
    full_symbol, arg_params, aux_params = checkpoint
    head = full_symbol.get_internals()['_minusscalar0_output']

    print('------- Output layer info -------')
    # Dump arguments, outputs, attributes and the debug string, in that order.
    for method_name in ('list_arguments', 'list_outputs', 'list_attr', 'debug_str'):
        print(getattr(head, method_name)())

    # label_names=None: the truncated graph has no label input.
    module = mx.mod.Module(symbol=head, context=mx.cpu(), label_names=None)
    input_shapes = [('data', (1, 3, 112, 112))]
    module.bind(data_shapes=input_shapes)
    module.set_params(arg_params, aux_params)
    return module
# Script body: build the model, prepare a single input batch, then run one
# forward pass that prints a slice of the input and of the output.
model = get_model()
db = make_input()
fit(model, db)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment