Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import argparse
from sys import platform
from models import *
from utils.datasets import *
from utils.utils import *
def export(*, img_size=(320, 192), cfg=None, weights=None, outfile=None):
    """Load a Darknet model, trace it with TorchScript, and save the trace.

    All arguments are optional and keyword-only, so the original zero-argument
    call ``export()`` behaves exactly as before.

    Args:
        img_size: Two-element spatial size. It is passed to ``Darknet`` and
            used as the last two dims of the dummy trace input, which
            therefore has shape ``(1, 3) + img_size``. Any sequence is
            accepted; it is converted to a tuple.
        cfg: Path to the model ``*.cfg`` file. Defaults to ``opt.cfg``.
        weights: Path to the weights file. A path ending in ``.pt`` is read
            as a PyTorch checkpoint whose ``'model'`` entry is loaded as the
            state dict. Any other path is read with ``load_darknet_weights``.
            Defaults to ``opt.weights``.
        outfile: Destination path for the traced module. Defaults to
            ``opt.outfile``.

    Note:
        ``opt`` is the module-level namespace created by the ``__main__``
        block. It is only consulted for arguments left as ``None``, so code
        that imports this module should pass ``cfg``, ``weights`` and
        ``outfile`` explicitly.
    """
    # Fall back to the command-line options only for what was not supplied.
    if cfg is None:
        cfg = opt.cfg
    if weights is None:
        weights = opt.weights
    if outfile is None:
        outfile = opt.outfile
    # Normalise so that ``(1, 3) + img_size`` below also works for lists.
    img_size = tuple(img_size)
    device = 'cpu'

    # Same grad-free context that the __main__ entry point uses. Callers that
    # import export() then get identical behaviour without wrapping it themselves.
    with torch.no_grad():
        # Initialize model
        model = Darknet(cfg, img_size)

        # Load weights. attempt_download is a project helper; presumably it
        # fetches the file when missing - TODO confirm.
        attempt_download(weights)
        if weights.endswith('.pt'):  # pytorch format
            # NOTE(review): torch.load unpickles the file; only load trusted weights.
            checkpoint = torch.load(weights, map_location=device)
            model.load_state_dict(checkpoint['model'])
        else:  # darknet format
            load_darknet_weights(model, weights)

        # Eval mode
        model.eval()
        # Fuse Conv2d + BatchNorm2d layers before tracing
        model.fuse()

        # Dummy input, e.g. (1, 3, 320, 192) for the default img_size.
        example = torch.zeros((1, 3) + img_size)
        # torch.jit.trace documents example_inputs as a tuple (or a single Tensor).
        traced = torch.jit.trace(model, (example,))
        torch.jit.save(traced, outfile)
if __name__ == '__main__':
    # Command-line entry point. It parses the options into the module-level
    # ``opt`` that export() reads, echoes them, then runs the export with
    # autograd disabled.
    cli = argparse.ArgumentParser()
    # (flag, default, help) for each string-valued option.
    option_specs = (
        ('--cfg', 'cfg/yolov3-spp.cfg', '*.cfg path'),
        ('--weights', 'weights/yolov3-spp-ultralytics.pt', 'weights path'),
        ('--outfile', 'yolov3-spp-traced.pt', 'exported file path'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    opt = cli.parse_args()
    print(opt)

    with torch.no_grad():
        export()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment