-
-
Save Wheest/033507e6319b06047bae33beab6c3958 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
import tvm | |
import time | |
import numpy as np | |
from tvm.contrib import graph_runtime | |
from tvm.relay import data_dep_optimization as ddo | |
import onnx | |
import itertools | |
import scipy.sparse as sp | |
# network definition | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import torch | |
import torch.nn as nn | |
class WeeNet(nn.Module):
    """Minimal network consisting of a single fully-connected layer.

    Args:
        in_size: number of input features per sample.
        num_units: number of output features produced by the layer.
    """

    def __init__(self, in_size, num_units):
        super().__init__()
        # The attribute name ``layer1`` is load-bearing: the script overwrites
        # ``state_dict['layer1.weight']`` with a hand-built sparse matrix.
        self.layer1 = nn.Linear(in_size, num_units)

    def forward(self, x):
        """Apply the fully-connected layer to ``x`` and return the result."""
        return self.layer1(x)
# Build a PyTorch reference model whose weight matrix is mostly zeros, take a
# reference output, and export it to ONNX for the TVM dense/sparse runs below.
import torch
from scipy import sparse
from scipy.sparse import csr_matrix

# Seeds torch only (e.g. the nn.Linear bias init); scipy is seeded separately below.
torch.manual_seed(0)

in_v = 100   # input features
out_s = 101  # output features
input_shape = (1, in_v)

# Deterministic ramp input 0, 1, ..., N-1. Built directly as float32 so the very
# same array can be fed to PyTorch and to the TVM graphs (the model is float32).
input_data = np.arange(int(np.prod(input_shape)), dtype="float32").reshape(input_shape)
x = torch.from_numpy(input_data)

model = WeeNet(in_v, out_s)

# nn.Linear stores its weight as (out_features, in_features); derive the shape
# from the sizes above instead of hard-coding it.
kernel_shape = (out_s, in_v)
# Random matrix with 20% non-zeros, drawn directly in the weight layout.
# random_state makes the weights reproducible across runs.
kernels = sparse.rand(*kernel_shape, density=0.2, random_state=0).toarray()
# Scale and round to small integers (0..10), which float32 represents exactly.
kernels = np.round(kernels * 10, 0)

state_dict = model.state_dict()
state_dict["layer1.weight"] = torch.from_numpy(kernels).float()
model.load_state_dict(state_dict, strict=True)

# Reference PyTorch output for the same input; no autograd graph is needed.
with torch.no_grad():
    py_out = model(x)

save_name = "fcnet.onnx"
# Must match the "input_1" key used when building the Relay shape dict.
input_names = ["input_1"]
torch.onnx.export(model, x, save_name, input_names=input_names)
from tvm import relay | |
def import_onnx(name, shape_dict):
    """Load an ONNX model from disk and translate it into a Relay module.

    Args:
        name: path of the ``.onnx`` file to read.
        shape_dict: mapping of graph-input name -> input shape, forwarded
            unchanged to ``relay.frontend.from_onnx``.

    Returns:
        ``(mod, params, shape_dict)``: the Relay module, its parameter dict,
        and the same ``shape_dict`` object that was passed in.
    """
    onnx_graph = onnx.load(name)
    relay_mod, relay_params = relay.frontend.from_onnx(onnx_graph, shape_dict)
    return relay_mod, relay_params, shape_dict
def run_relay_graph(mod, params, shape_dict, input_data, target, ctx):
    """Compile a Relay graph, run it once, benchmark it, and return its output.

    Args:
        mod: Relay module to compile (the sparse path passes the result of the
            ``ddo`` conversions, presumably a Relay function -- TODO confirm).
        params: dict of constant parameters bound at build time.
        shape_dict: input-name -> shape mapping; kept for interface
            compatibility (the input is bound positionally as input 0).
        input_data: array fed to graph input 0.
        target: TVM target string, e.g. ``"opencl"`` or ``"llvm"``.
        ctx: TVM device context matching ``target``.

    Returns:
        The graph's first output as a TVM NDArray (callers use ``.asnumpy()``).

    Side effect:
        Prints the mean and standard deviation of the measured run time in ms.
    """
    # PassContext replaces the deprecated relay.build_config.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    m = graph_runtime.GraphModule(lib["default"](ctx))
    m.set_input(0, input_data)
    m.run()
    tvm_output = m.get_output(0)

    # 5 repeats of 5 back-to-back runs each; results are seconds per run.
    ftimer = m.module.time_evaluator("run", ctx, repeat=5, number=5)
    prof_res = np.array(ftimer().results) * 1000  # seconds -> milliseconds
    # Report the measurement instead of silently discarding it.
    print(
        "Mean inference time (std dev): %.4f ms (%.4f ms)"
        % (np.mean(prof_res), np.std(prof_res))
    )
    return tvm_output
def run_dense(mod, params, shape_dict, input_data, target, ctx):
    """Run an unconverted (dense) Relay graph via ``run_relay_graph``.

    Every argument, including ``input_data``, is forwarded unchanged.

    Returns:
        Whatever ``run_relay_graph`` returns (the graph's first output).
    """
    return run_relay_graph(mod, params, shape_dict, input_data, target, ctx)
# Prefer OpenCL, but fall back to the CPU so the comparison can still run on
# machines (or TVM builds) without an OpenCL device.
target = "opencl"
ctx = tvm.cl(0)
if not ctx.exist:
    target = "llvm"
    ctx = tvm.cpu(0)

# Dense baseline compiled from the exported ONNX model.
mod, params, shape_dict = import_onnx(save_name, {"input_1": input_shape})
true_outs = run_relay_graph(mod, params, shape_dict, input_data, target, ctx)
# Validate the dense TVM baseline against PyTorch before treating it as ground
# truth; otherwise the sparse check below only proves dense/sparse agree.
# detach() makes this work whether or not py_out was built under no_grad.
np.testing.assert_allclose(
    py_out.detach().numpy(), true_outs.asnumpy(), rtol=1e-4, atol=1e-4
)

# Sparse model: import a fresh copy rather than reusing the dense mod/params,
# fold the FC weight transpose, then rewrite dense ops as BSR with 1x1 blocks.
# sparsity_threshold=0.0 is presumably the most permissive setting, so the FC
# weight is converted regardless of its sparsity (TODO confirm in ddo docs).
mod, params, shape_dict = import_onnx(save_name, {"input_1": input_shape})
modt, paramst = ddo.simplify_fc_transpose.convert(mod["main"], params)
mods, paramss = ddo.bsr_dense.convert(
    modt, paramst, blocksize=(1, 1), sparsity_threshold=0.0
)
out = run_relay_graph(mods, paramss, shape_dict, input_data, target, ctx)
np.testing.assert_allclose(true_outs.asnumpy(), out.asnumpy(), rtol=1e-5, atol=0)
print("done!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment