# from tvm.contrib.torch import optimize_torch
import contextlib
import tempfile

import onnx
import tvm
import tvm.tir.tensor_intrin  # registers the tensor intrinsics used by the tensorcore schedule rules
from tvm import meta_schedule as ms
from tvm import relay


def get_network(weight, batch_size, layout="NHWC", dtype="float32", use_sparse=False):
    """Load an ONNX model and lower it to a Relay module in NHWC layout with mixed precision.

    Note: `layout`, `dtype`, and `use_sparse` are currently unused; the layout
    is hard-coded to NHWC by ConvertLayout below.
    """
    input_shape = (batch_size, 3, 224, 224)
    onnx_model = onnx.load(weight)
    input_name = "input"
    shape_dict = {input_name: input_shape}
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
    # Convert layout-sensitive ops to NHWC, which the tensorcore schedules expect.
    desired_layouts = {
        'nn.conv2d': ['NHWC', 'default'],
        'image.resize2d': ['NHWC'],
        'nn.upsampling': ['NHWC'],
    }
    seq = tvm.transform.Sequential([
        relay.transform.RemoveUnusedFunctions(),
        relay.transform.ConvertLayout(desired_layouts),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    mod = tvm.IRModule.from_expr(mod["main"])
    # Graph-level cleanups: fold constants, dedupe subexpressions, and bind the
    # weights into the module so the later passes can see them.
    mod = relay.transform.FastMath()(mod)
    mod = relay.transform.EliminateCommonSubexpr()(mod)
    BindPass = tvm.relay.transform.function_pass(
        lambda fn, new_mod, ctx: tvm.relay.build_module.bind_params_by_name(fn, params),
        opt_level=1,
    )
    mod = BindPass(mod)
    mod = relay.transform.FoldConstant()(mod)
    mod = relay.transform.CombineParallelBatchMatmul()(mod)
    mod = relay.transform.FoldConstant()(mod)
    mod = relay.transform.InferType()(mod)
    # Cast the graph to float16 where safe; tensorcore kernels need fp16 inputs.
    mod = relay.transform.ToMixedPrecision()(mod)
    mod = relay.transform.EliminateCommonSubexpr()(mod)
    mod = relay.transform.FoldConstant()(mod)
    mod = relay.transform.CombineParallelBatchMatmul()(mod)
    mod = relay.transform.FoldConstant()(mod)
    return mod, params, input_shape


# weight = '/home/acer/rtmdet_m_syncbn_fast_8xb32-300e_coco_640_640_best_coco_crocs_precision_epoch_218.onnx'
weight = '/home/acer/tvm_experiment/resnet50.onnx'
work_dir = '/home/acer/test_meta_tensorcore'
batch_size = 1
layout = 'NHWC'
dtype = "float16"
use_sparse = False
mod, params, input_shape = get_network(
    weight,
    batch_size,
    layout,
    dtype=dtype,
    use_sparse=use_sparse,
)
# Tune into the given work directory if one is set, otherwise into a temp dir.
if work_dir:
    context_manager = contextlib.nullcontext(work_dir)
else:
    context_manager = tempfile.TemporaryDirectory()
target = tvm.target.Target("nvidia/rtx-3000")
# Restrict the search space to the tensorcore schedule rules, postprocessors,
# and mutators, steering candidates toward wmma (tensorcore) kernels.
space = ms.space_generator.PostOrderApply(
    sch_rules="cuda-tensorcore",
    postprocs="cuda-tensorcore",
    mutator_probs="cuda-tensorcore",
)
with context_manager as work_dir:  # pylint: disable=redefined-argument-from-local
    # Tune every task extracted from the Relay module; the best schedules are
    # collected into a JSON database under work_dir.
    # database = ms.database.Database.create(kind="json", work_dir=work_dir)
    # database = ms.tir_integration.tune_tir(
    database = ms.relay_integration.tune_relay(
        mod=mod,
        params=params,
        target=target,
        work_dir=work_dir,
        max_trials_global=25000,  # 22ms
        # max_trials_per_task=64,
        max_trials_per_task=256,
        num_trials_per_iter=64,
        builder='local',
        runner='local',
        database='json',
        cost_model='xgb',
        measure_callbacks='default',
        task_scheduler='gradient',
        # space='cuda',
        space=space,
        strategy="evolutionary",
        seed=None,
    )
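    # Optional sanity check (a sketch, not from the original script): the
    # database returned by tune_relay can report how many tuning records
    # were collected, via the standard ms.database.Database interface.
    print("tuning records:", len(database.get_all_tuning_records()))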
    # Compile with the tuned schedules looked up from the database.
    with database, tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        lib = relay.build(mod, target=target, params=params)
    lib.export_library('/home/acer/meta_resnet50.tar')
    # lib.export_library('/home/acer/rtmdet_m_syncbn_fast_8xb32-300e_coco_640_640_best_coco_crocs_precision_epoch_218_meta.tar')
    # executor_factory = ms.relay_integration.compile_relay(
    #     database=database,
    #     mod=mod,
    #     target=target,
    #     params=params,
    #     backend="graph",
    # )
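
# Usage sketch (an addition, assuming the standard graph-executor loading
# pattern): reload the exported library and run one inference. The input name
# "input" and input_shape mirror shape_dict in get_network; feeding float32
# assumes ToMixedPrecision left the graph input in float32.
import numpy as np
from tvm.contrib import graph_executor

dev = tvm.cuda(0)
loaded = tvm.runtime.load_module('/home/acer/meta_resnet50.tar')
gmod = graph_executor.GraphModule(loaded["default"](dev))
gmod.set_input("input", np.random.rand(*input_shape).astype("float32"))
gmod.run()
print("output shape:", gmod.get_output(0).numpy().shape)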