-
-
Save Wheest/a05837c1ae257d177ef5431ef5ac4768 to your computer and use it in GitHub Desktop.
A TVM example that builds a small pruned CNN in PyTorch, converts its conv2d weights to a BSR sparse layout, and benchmarks it untuned or auto-tuned with Ansor.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import tvm | |
import argparse | |
from tvm import te, relay, auto_scheduler | |
from tvm.contrib.download import download_testdata | |
from tvm.contrib import graph_executor as graph_runtime | |
from tvm.contrib import graph_executor | |
from tvm.relay import data_dep_optimization as ddo | |
import logging | |
import os | |
import timeit | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import TensorDataset | |
from torchvision import transforms | |
import torch.nn.utils.prune as prune | |
from typing import Optional | |
# Inference batch size for the generated model input.
batch_size = 1
# Build target: x86-64 Linux with AVX2; adjust for other hosts.
target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"
dev = tvm.device(target)
# Fix the NumPy seed so any NumPy-side randomness is reproducible.
np.random.seed(42)
# Height/width of the square input image fed to the network.
INPUT_SIZE = 128
class WeeNet(nn.Module):
    """Minimal single-layer CNN: one bias-free 3x3 convolution followed by ReLU."""

    def __init__(self):
        super().__init__()
        # 3 input channels -> 32 output channels, 3x3 kernel, stride 1, no padding.
        self.layer1 = nn.Conv2d(
            3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0), bias=False
        )

    def forward(self, x):
        # Spatial dims shrink by 2 (valid convolution), then elementwise ReLU.
        return F.relu(self.layer1(x))
def sparsify_model(mod, params):
    """Rewrite the Relay module so eligible conv2d weights use a BSR sparse layout.

    NOTE(review): assumes `mod` is a Relay IRModule whose entry function is
    "main" and that the installed TVM version provides
    `data_dep_optimization.bsr_conv2d.convert` with these kwargs — confirm
    against the TVM release in use.

    Returns the rewritten IRModule and the (possibly converted) params dict.
    """
    block_rows = 1  # BSR block shape is (block_rows, 1)

    # Fold transposes feeding fully-connected layers before sparsification.
    func, params = ddo.simplify_fc_transpose.convert(mod["main"], params)

    # Convert conv2d weights to BSR when they pass the sparsity threshold.
    func, params = ddo.bsr_conv2d.convert(
        func,
        params,
        blocksize=(block_rows, 1),
        sparsity_threshold=0.05,
        layout="NCHW",
    )
    return tvm.IRModule.from_expr(func), params
def get_model():
    """Build WeeNet, prune it, and import it into Relay as a sparse model.

    Returns a tuple ``(mod, params, x, y)`` where ``mod``/``params`` are the
    sparsified Relay module and weights, ``x`` is the NumPy input array, and
    ``y`` is the traced PyTorch model's output for ``x`` (reference values).
    """
    net = WeeNet()
    shape = (batch_size, 3, INPUT_SIZE, INPUT_SIZE)
    example = torch.randint(100, shape, dtype=torch.float32)

    # Randomly zero 70% of the conv weights, then make the pruning permanent
    # so the zeros live in the actual weight tensor (not a mask).
    for layer in (net.layer1,):
        prune.random_unstructured(layer, name="weight", amount=0.7)
        prune.remove(layer, "weight")

    traced = torch.jit.trace(net, example).eval()
    reference = traced(example)

    mod, params = relay.frontend.from_pytorch(traced, [("data", shape)])
    mod, params = sparsify_model(mod, params)
    print(mod)

    return mod, params, example.numpy(), reference.detach().numpy()
def auto_schedule(trials=200):
    """Auto-tune the sparse model with Ansor, build it, and print benchmark timings.

    Args:
        trials: total number of measurement trials shared across all tuning tasks.
    """
    mod, params, x, y_target = get_model()

    # Start from a clean tuning log so stale records are never applied.
    log_file = "/tmp/log_file.json"
    if os.path.exists(log_file):
        os.remove(log_file)

    tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target)
    if len(tasks) == 0:
        print("Hey we do not have any tasks...")

    # Keep only conv2d tasks.  BUGFIX: the original collected non-conv2d tasks
    # and dropped them with list.remove(); removing a *weight* by value picks
    # the first equal weight, which misaligns tasks and weights whenever two
    # tasks share the same weight.  Filtering by index keeps pairs aligned.
    kept_tasks, kept_weights = [], []
    for idx, task in enumerate(tasks):
        print(
            "========== Task %d (workload key: %s) =========="
            % (idx, task.workload_key)
        )
        print(task.desc)
        print(task.compute_dag)
        if "conv2d" in task.desc:
            kept_tasks.append(task)
            kept_weights.append(task_weights[idx])
    tasks, task_weights = kept_tasks, kept_weights

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=trials,  # raise toward 20000 for best performance
        runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    # Sparse workloads need custom sketch rules so Ansor can schedule them.
    from tvm.topi.sparse.utils import sparse_sketch_rules

    search_policy = [
        auto_scheduler.SketchPolicy(
            task,
            program_cost_model=auto_scheduler.XGBModel(),
            init_search_callbacks=sparse_sketch_rules(),
        )
        for task in tasks
    ]
    tuner.tune(tune_option, search_policy=search_policy)

    # Show the best schedule found for the first task.
    sch, args = tasks[0].apply_best(log_file)
    print("Lowered TIR:")
    print(tvm.lower(sch, args, simple_mode=True))

    # Build the whole model with the best tuning records applied.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            lib = relay.build(mod, target=target, params=params)

    module = graph_executor.GraphModule(lib["default"](dev))
    data_tvm = tvm.nd.array(x.astype("float32"))
    module.set_input("data", data_tvm)

    print("Ansor:")
    print(module.benchmark(dev, repeat=3, min_repeat_ms=500))
    print()
def regular():
    """Build the sparse model with no auto-tuning and print benchmark timings."""
    mod, params, x, _ = get_model()

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)

    runtime = graph_executor.GraphModule(lib["default"](dev))
    runtime.set_input("data", tvm.nd.array(x.astype("float32")))
    runtime.run()
    # Fetch the output once so the graph is fully executed before timing.
    y = runtime.get_output(0).numpy()

    print("Untuned:")
    print(runtime.benchmark(dev, repeat=3, min_repeat_ms=500))
    print()
def main(args):
    """Dispatch to the untuned or Ansor-tuned benchmark based on ``args.mode``."""
    if args.mode != "regular":
        # "ansor" mode: tune with a budget of 100 measurement trials.
        auto_schedule(100)
    else:
        regular()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a model")
    # BUGFIX: for positional arguments argparse only applies `default` when
    # nargs="?" (or "*") is set; without it the default was ignored and
    # `mode` was mandatory.  nargs="?" makes `mode` genuinely optional while
    # keeping every existing invocation valid.
    parser.add_argument(
        "mode",
        nargs="?",
        default="regular",
        choices=["regular", "ansor"],
    )
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment