Skip to content

Instantly share code, notes, and snippets.

@maulvi
Last active July 18, 2023 11:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maulvi/e8e845d87a0a88abe89a3bc465c56e33 to your computer and use it in GitHub Desktop.
Save maulvi/e8e845d87a0a88abe89a3bc465c56e33 to your computer and use it in GitHub Desktop.
wget -q https://gist.github.com/maulvi/e8e845d87a0a88abe89a3bc465c56e33/raw/5342a5650e6450882b7e28158802bce4018d38ae/model_zoo.py -O /usr/local/lib/python3.10/dist-packages/insightface/model_zoo/model_zoo.py
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
# Install by downloading this file over insightface's bundled copy, e.g.:
#   wget -q <raw gist URL> -O /usr/local/lib/python3.10/dist-packages/insightface/model_zoo/model_zoo.py
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .scrfd import *
# Names exported by a star-import of this module: only the loader entry
# point is public; ModelRouter and find_onnx_file are internal helpers.
__all__ = [
    "get_model",
]
class ModelRouter:
    """Choose the insightface wrapper class for an ONNX model file.

    The choice is made purely from the loaded graph: the number of outputs
    and the spatial size of the first input.
    """

    # Execution providers passed to onnxruntime when the caller gives none,
    # listed in decreasing order of precedence.
    DEFAULT_PROVIDERS = ('CUDAExecutionProvider', 'CPUExecutionProvider')

    def __init__(self, onnx_file):
        """Store the path of the ONNX file to route.

        Args:
            onnx_file: Path to an ``.onnx`` model file.
        """
        self.onnx_file = onnx_file

    def get_model(self, providers=None):
        """Load the model with onnxruntime and wrap it in the matching class.

        Args:
            providers: Optional sequence of onnxruntime execution provider
                names. ``None`` uses ``DEFAULT_PROVIDERS``.

        Returns:
            ``SCRFD`` when the graph has at least five outputs; otherwise
            ``ArcFaceONNX`` when the first input has at least four
            dimensions and its dimensions 2 and 3 are both 112.

        Raises:
            RuntimeError: If neither routing rule matches.
        """
        if providers is None:
            providers = list(self.DEFAULT_PROVIDERS)
        session = onnxruntime.InferenceSession(self.onnx_file, providers=providers)
        input_shape = session.get_inputs()[0].shape
        outputs = session.get_outputs()
        # Routing rule: any graph with >= 5 outputs is treated as an SCRFD model.
        if len(outputs) >= 5:
            return SCRFD(model_file=self.onnx_file, session=session)
        # Check the rank before indexing so a low-rank input falls through to
        # the routing error instead of raising IndexError. Dynamic dimensions
        # (non-int entries) never equal 112 and also fall through.
        if len(input_shape) >= 4 and input_shape[2] == 112 and input_shape[3] == 112:
            return ArcFaceONNX(model_file=self.onnx_file, session=session)
        raise RuntimeError('error on model routing')
def find_onnx_file(dir_path):
    """Return the lexicographically last ``*.onnx`` file directly inside *dir_path*.

    Args:
        dir_path: Directory to search (non-recursive).

    Returns:
        Path of the selected ``.onnx`` file, or ``None`` when *dir_path*
        does not exist or contains no ``.onnx`` files.
    """
    if not os.path.exists(dir_path):
        return None
    # Escape the directory part so characters like '[', '*' or '?' in its
    # name are matched literally rather than treated as glob wildcards.
    pattern = os.path.join(glob.escape(os.fspath(dir_path)), '*.onnx')
    paths = glob.glob(pattern)
    if not paths:
        return None
    # Same selection as sorted(paths)[-1], without building a sorted copy.
    return max(paths)
def get_model(name, **kwargs):
    """Resolve a model name to an ONNX file and return its routed wrapper.

    Args:
        name: Either a path ending in ``.onnx``, or the name of a
            sub-directory of ``root`` holding one or more ``.onnx`` files
            (the lexicographically last one is chosen by ``find_onnx_file``).
        **kwargs: ``root`` -- directory that holds model sub-directories,
            default ``'~/.insightface/models'`` (``~`` is expanded). Other
            keyword arguments are currently ignored.

    Returns:
        The object produced by ``ModelRouter.get_model`` for the resolved
        file, or ``None`` when *name* is a directory name and no ``.onnx``
        file is found there.

    Raises:
        AssertionError: If the resolved path is not a regular file (this
            check is skipped when Python runs with ``-O``).
    """
    root = os.path.expanduser(kwargs.get('root', '~/.insightface/models'))
    if name.endswith('.onnx'):
        model_file = name
    else:
        model_file = find_onnx_file(os.path.join(root, name))
        if model_file is None:
            return None
    assert osp.isfile(model_file), 'model should be file'
    # Route the resolved file, not ``name``: for a directory-style name the
    # two differ, and onnxruntime needs the actual .onnx path.
    router = ModelRouter(model_file)
    return router.get_model()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment